Commit 2778a3d0 authored by luopl

update to v0.9.1_stable

parent e92143e3
@@ -20,7 +20,7 @@ from setuptools import find_packages, setup
 def get_version() -> str:
-    with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
+    with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
         file_content = f.read()
         pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
         (version,) = re.findall(pattern, file_content)
@@ -28,7 +28,7 @@ def get_version() -> str:
 def get_requires() -> List[str]:
-    with open("requirements.txt", "r", encoding="utf-8") as f:
+    with open("requirements.txt", encoding="utf-8") as f:
         file_content = f.read()
         lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
         return lines
@@ -54,13 +54,14 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.6.2"],
+    "vllm": ["vllm>=0.4.3,<0.6.4"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "dev": ["ruff", "pytest"],
+    "openmind": ["openmind"],
+    "dev": ["pre-commit", "ruff", "pytest"],
 }
@@ -71,7 +72,7 @@ def main():
         author="hiyouga",
         author_email="hiyouga" "@" "buaa.edu.cn",
         description="Easy-to-use LLM fine-tuning framework",
-        long_description=open("README.md", "r", encoding="utf-8").read(),
+        long_description=open("README.md", encoding="utf-8").read(),
         long_description_content_type="text/markdown",
         keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
         license="Apache 2.0 License",
...
@@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
 def main():
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
...
@@ -20,17 +20,17 @@ Level:
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.45.2
-    datasets>=2.16.0,<=2.21.0
-    accelerate>=0.30.1,<=0.34.2
+    transformers>=4.41.2,<=4.46.1
+    datasets>=2.16.0,<=3.1.0
+    accelerate>=0.34.0,<=1.0.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.45.2
+    transformers>=4.41.2,<=4.46.1
   packing:
-    transformers>=4.41.2,<=4.45.2
+    transformers>=4.41.2,<=4.46.1
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
@@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
 Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
+Use openmind: USE_OPENMIND_HUB=1
 """
 from .extras.env import VERSION
...
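The docstring above now documents a USE_OPENMIND_HUB switch alongside USE_MODELSCOPE_HUB. A minimal sketch of enabling it from Python, assuming the flag is read lazily from the environment the way the modelscope switch is:

    import os

    os.environ["USE_OPENMIND_HUB"] = "1"  # route model and dataset downloads through the openMind hub

    from llamafactory.train.tuner import run_exp  # safe to import afterwards; the flag is consulted when assets are resolved

Equivalently, the variable can be exported in the shell before invoking the llamafactory-cli entry point.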
@@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU mem
 def create_app(chat_model: "ChatModel") -> "FastAPI":
-    root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
+    root_path = os.getenv("FASTAPI_ROOT_PATH", "")
     app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
     app.add_middleware(
         CORSMiddleware,
@@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY", None)
+    api_key = os.getenv("API_KEY")
     security = HTTPBearer(auto_error=False)
     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         dependencies=[Depends(verify_api_key)],
     )
     async def list_models():
-        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
+        model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
         return ModelList(data=[model_card])
     @app.post(
@@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 def run_api() -> None:
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
@@ -21,7 +21,7 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 from ..data import Role as DataRole
-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 ROLE_MAPPING = {
     Role.USER: DataRole.USER.value,
     Role.ASSISTANT: DataRole.ASSISTANT.value,
@@ -69,8 +69,8 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
-    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
+    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -84,7 +84,7 @@ def _process_request(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
     input_messages = []
-    image = None
+    images = []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -111,7 +111,7 @@ def _process_request(
                     else: # web uri
                         image_stream = requests.get(image_url, stream=True).raw
-                    image = Image.open(image_stream).convert("RGB")
+                    images.append(Image.open(image_stream).convert("RGB"))
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
@@ -124,7 +124,7 @@ def _process_request(
     else:
         tools = None
-    return input_messages, system, tools, image
+    return input_messages, system, tools, images or None
 def _create_stream_chat_completion_chunk(
@@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
 async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -169,7 +169,7 @@ async def create_chat_completion_response(
             tool_calls = []
             for tool in result:
                 function = Function(name=tool[0], arguments=tool[1])
-                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+                tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
             response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
             finish_reason = Finish.TOOL
@@ -193,8 +193,8 @@ async def create_chat_completion_response(
 async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
-    score_id = "scoreval-{}".format(uuid.uuid4().hex)
+    score_id = f"scoreval-{uuid.uuid4().hex}"
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
...
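Since _process_request now collects a list of images, a single chat completion request may carry several image_url parts. A hedged sketch of such a payload against the OpenAI-compatible endpoint served by run_api (URLs and port are placeholders):

    import requests

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Compare these two pictures."},
                    {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
                    {"type": "image_url", "image_url": {"url": "https://example.com/b.png"}},
                ],
            }
        ],
    }
    print(requests.post("http://localhost:8000/v1/chat/completions", json=payload).json())

Each image_url item is fetched (or opened from a local path) and appended to the images list that is forwarded to ChatModel.achat.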
@@ -66,8 +66,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
@@ -81,8 +81,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
...
@@ -53,7 +53,7 @@ class ChatModel:
         elif model_args.infer_backend == "vllm":
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
-            raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
@@ -64,15 +64,15 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Gets a list of responses of the chat model.
         """
         task = asyncio.run_coroutine_threadsafe(
-            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
         )
         return task.result()
@@ -81,28 +81,28 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Asynchronously gets a list of responses of the chat model.
         """
-        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""
         Gets the response token-by-token of the chat model.
         """
-        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -115,14 +115,14 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
         Asynchronously gets the response token-by-token of the chat model.
         """
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
             yield new_token
     def get_scores(
...
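With the ChatModel API now accepting sequences, callers pass images=[...] and videos=[...] instead of a single image or video argument. A minimal sketch, assuming a vision-language checkpoint and template (names and file paths are placeholders):

    from llamafactory.chat import ChatModel

    chat_model = ChatModel(dict(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct", template="qwen2_vl"))
    messages = [{"role": "user", "content": "<image><image>What changed between these two frames?"}]
    responses = chat_model.chat(messages, images=["frame_1.png", "frame_2.png"])  # paths or PIL images
    print(responses[0].response_text)

When no placeholder tokens are present in the messages, the engines now prepend one <image> per supplied image rather than a single placeholder.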
@@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
 from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 class HuggingfaceEngine(BaseEngine):
@@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning("There is no current event loop, creating a new one.")
+            logger.warning_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
-        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
+        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
     @staticmethod
     def _process_args(
@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
-        if image is not None:
-            mm_input_dict.update({"images": [image], "imglens": [1]})
-            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
-        if video is not None:
-            mm_input_dict.update({"videos": [video], "vidlens": [1]})
-            if VIDEO_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
         messages = template.mm_plugin.process_messages(
             messages, mm_input_dict["images"], mm_input_dict["videos"], processor
@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
         if stop is not None:
-            logger.warning("Stop parameter is not supported by the huggingface engine yet.")
+            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
         generating_args = generating_args.copy()
         generating_args.update(
@@ -164,9 +164,13 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )
-        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
         for key, value in mm_inputs.items():
-            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
+                value = torch.stack(value) # assume they have same sizes
+            elif not isinstance(value, torch.Tensor):
+                value = torch.tensor(value)
             gen_kwargs[key] = value.to(model.device)
         return gen_kwargs, prompt_length
@@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
@@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
@@ -266,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -283,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:
@@ -297,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -314,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:
...
@@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
 from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
 from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 class VllmEngine(BaseEngine):
@@ -83,11 +83,13 @@ class VllmEngine(BaseEngine):
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
+        if isinstance(model_args.vllm_config, dict):
+            engine_args.update(model_args.vllm_config)
         if getattr(config, "is_yi_vl_derived_model", None):
             import vllm.model_executor.models.llava
-            logger.info("Detected Yi-VL model, applying projector patch.")
+            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
             vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
         self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
@@ -101,21 +103,28 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
-        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-        if image is not None:
-            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
-        paired_messages = messages + [{"role": "assistant", "content": ""}]
+        request_id = f"chatcmpl-{uuid.uuid4().hex}"
+        if images is not None:
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+        if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
+            image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
+        else:
+            image_str = self.template.mm_plugin.image_token or ""
+        paired_messages = [
+            {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
+            for message in messages
+        ] + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
-        use_beam_search: bool = self.generating_args["num_beams"] > 1
         temperature: Optional[float] = input_kwargs.pop("temperature", None)
         top_p: Optional[float] = input_kwargs.pop("top_p", None)
         top_k: Optional[float] = input_kwargs.pop("top_k", None)
@@ -126,6 +135,9 @@ class VllmEngine(BaseEngine):
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        if length_penalty is not None:
+            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
         if "max_new_tokens" in self.generating_args:
             max_tokens = self.generating_args["max_new_tokens"]
         elif "max_length" in self.generating_args:
@@ -149,27 +161,29 @@ class VllmEngine(BaseEngine):
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
             top_k=top_k if top_k is not None else self.generating_args["top_k"],
-            use_beam_search=use_beam_search,
-            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
             skip_special_tokens=True,
         )
-        if image is not None: # add image features
-            if not isinstance(image, (str, ImageObject)):
-                raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
-            if isinstance(image, str):
-                image = Image.open(image).convert("RGB")
-            multi_modal_data = {"image": image}
+        if images is not None: # add image features
+            image_data = []
+            for image in images:
+                if not isinstance(image, (str, ImageObject)):
+                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
+                if isinstance(image, str):
+                    image = Image.open(image).convert("RGB")
+                image_data.append(image)
+            multi_modal_data = {"image": image_data}
         else:
             multi_modal_data = None
         result_generator = self.model.generate(
-            inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
             sampling_params=sampling_params,
             request_id=request_id,
             lora_request=self.lora_request,
@@ -182,12 +196,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for request_output in generator:
             final_output = request_output
@@ -210,12 +224,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for result in generator:
             delta_text = result.outputs[0].text[len(generated_text) :]
             generated_text = result.outputs[0].text
...
@@ -22,8 +22,8 @@ from . import launcher
 from .api.app import run_api
 from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
+from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.logging import get_logger
 from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -47,7 +47,7 @@ USAGE = (
 WELCOME = (
     "-" * 58
     + "\n"
-    + "| Welcome to LLaMA Factory, version {}".format(VERSION)
+    + f"| Welcome to LLaMA Factory, version {VERSION}"
     + " " * (21 - len(VERSION))
     + "|\n|"
     + " " * 56
@@ -56,7 +56,7 @@ WELCOME = (
     + "-" * 58
 )
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 @unique
@@ -86,25 +86,26 @@ def main():
     elif command == Command.EXPORT:
         export_model()
     elif command == Command.TRAIN:
-        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+        force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
         if force_torchrun or get_device_count() > 1:
-            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
-            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
-            logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
+            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+            master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
+            logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
             process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-                ).format(
-                    nnodes=os.environ.get("NNODES", "1"),
-                    node_rank=os.environ.get("RANK", "0"),
-                    nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+                )
+                .format(
+                    nnodes=os.getenv("NNODES", "1"),
+                    node_rank=os.getenv("NODE_RANK", "0"),
+                    nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
                     master_addr=master_addr,
                     master_port=master_port,
                     file_name=launcher.__file__,
                    args=" ".join(sys.argv[1:]),
-                ),
-                shell=True,
+                )
+                .split()
             )
             sys.exit(process.returncode)
         else:
@@ -118,4 +119,4 @@ def main():
     elif command == Command.HELP:
         print(USAGE)
     else:
-        raise NotImplementedError("Unknown command: {}.".format(command))
+        raise NotImplementedError(f"Unknown command: {command}.")
@@ -16,7 +16,7 @@ import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role
@@ -29,45 +29,51 @@ if TYPE_CHECKING:
     from .parser import DatasetAttr
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 def _convert_images(
-    images: Sequence["ImageInput"],
+    images: Union["ImageInput", Sequence["ImageInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["ImageInput"]]:
     r"""
     Optionally concatenates image path to dataset dir when loading from local disk.
     """
-    if len(images) == 0:
+    if not isinstance(images, list):
+        images = [images]
+    elif len(images) == 0:
         return None
+    else:
+        images = images[:]
-    images = images[:]
     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(images)):
-            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
-                images[i] = os.path.join(data_args.dataset_dir, images[i])
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
+                images[i] = os.path.join(data_args.image_dir, images[i])
     return images
 def _convert_videos(
-    videos: Sequence["VideoInput"],
+    videos: Union["VideoInput", Sequence["VideoInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["VideoInput"]]:
     r"""
     Optionally concatenates video path to dataset dir when loading from local disk.
     """
-    if len(videos) == 0:
+    if not isinstance(videos, list):
+        videos = [videos]
+    elif len(videos) == 0:
         return None
+    else:
+        videos = videos[:]
-    videos = videos[:]
     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(videos)):
-            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
-                videos[i] = os.path.join(data_args.dataset_dir, videos[i])
+            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
+                videos[i] = os.path.join(data_args.image_dir, videos[i])
     return videos
@@ -161,7 +167,7 @@ def convert_sharegpt(
     broken_data = False
     for turn_idx, message in enumerate(messages):
         if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
-            logger.warning("Invalid role tag in {}.".format(messages))
+            logger.warning_rank0(f"Invalid role tag in {messages}.")
             broken_data = True
         aligned_messages.append(
@@ -171,7 +177,7 @@ def convert_sharegpt(
     if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
         dataset_attr.ranking and len(aligned_messages) % 2 == 0
     ):
-        logger.warning("Invalid message count in {}.".format(messages))
+        logger.warning_rank0(f"Invalid message count in {messages}.")
         broken_data = True
     if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
@@ -192,7 +198,7 @@ def convert_sharegpt(
             chosen[dataset_attr.role_tag] not in accept_tags[-1]
             or rejected[dataset_attr.role_tag] not in accept_tags[-1]
         ):
-            logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+            logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
             broken_data = True
         prompt = aligned_messages
@@ -205,7 +211,7 @@ def convert_sharegpt(
         response = aligned_messages[-1:]
     if broken_data:
-        logger.warning("Skipping this abnormal example.")
+        logger.warning_rank0("Skipping this abnormal example.")
         prompt, response = [], []
     convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
...
@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     processor: Optional["ProcessorMixin"] = None
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_videos.extend(videos)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
-            batch_seqlens.append(len(feature["input_ids"]))
+            batch_input_ids.append(feature["input_ids"])
         mm_inputs = self.template.mm_plugin.get_mm_inputs(
-            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
         )
         if "token_type_ids" in mm_inputs:
             token_type_ids = mm_inputs.pop("token_type_ids")
@@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
         features.update(mm_inputs)
+        if isinstance(features.get("pixel_values"), list): # for pixtral inputs
+            features = features.data # use default_collate() instead of BatchEncoding.to()
         return features
@@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
         for key in ("chosen", "rejected"):
             for feature in features:
                 target_feature = {
-                    "input_ids": feature["{}_input_ids".format(key)],
-                    "attention_mask": feature["{}_attention_mask".format(key)],
-                    "labels": feature["{}_labels".format(key)],
+                    "input_ids": feature[f"{key}_input_ids"],
+                    "attention_mask": feature[f"{key}_attention_mask"],
+                    "labels": feature[f"{key}_labels"],
                     "images": feature["images"],
                     "videos": feature["videos"],
                 }
...
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
-from ..extras.logging import get_logger
+from ..extras import logging
 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -56,12 +56,12 @@ def merge_dataset(
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
-            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+            logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
         return concatenate_datasets(all_datasets)
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
-            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+            logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
         return interleave_datasets(
             datasets=all_datasets,
@@ -70,7 +70,7 @@ def merge_dataset(
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
     else:
-        raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
+        raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
 def split_dataset(
...
@@ -83,14 +83,14 @@ class StringFormatter(Formatter):
             if isinstance(slot, str):
                 for name, value in kwargs.items():
                     if not isinstance(value, str):
-                        raise RuntimeError("Expected a string, got {}".format(value))
+                        raise RuntimeError(f"Expected a string, got {value}")
                     slot = slot.replace("{{" + name + "}}", value, 1)
                 elements.append(slot)
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
         return elements
@@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
                 functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
         except json.JSONDecodeError:
-            raise RuntimeError("Invalid JSON format in function message: {}".format(str([content]))) # flat string
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
         elements = []
         for name, arguments in functions:
@@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
         return elements
@@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
             tools = json.loads(content)
             return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
         except json.JSONDecodeError:
-            raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content]))) # flat string
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
     @override
     def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
...
@@ -20,8 +20,8 @@ import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
 from transformers.utils.versions import require_version
+from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
-from ..extras.logging import get_logger
 from ..extras.misc import has_tokenized_data
 from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from .template import Template
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 def _load_single_dataset(
@@ -51,9 +51,9 @@ def _load_single_dataset(
     r"""
     Loads a single dataset and aligns it to the standard format.
     """
-    logger.info("Loading dataset {}...".format(dataset_attr))
+    logger.info_rank0(f"Loading dataset {dataset_attr}...")
     data_path, data_name, data_dir, data_files = None, None, None, None
-    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
+    if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
         data_path = dataset_attr.dataset_name
         data_name = dataset_attr.subset
         data_dir = dataset_attr.folder
@@ -69,25 +69,24 @@ def _load_single_dataset(
         if os.path.isdir(local_path): # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
-                if data_path is None:
-                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
-                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
-                    raise ValueError("File types should be identical.")
         elif os.path.isfile(local_path): # is file
             data_files.append(local_path)
-            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
         else:
-            raise ValueError("File {} not found.".format(local_path))
+            raise ValueError(f"File {local_path} not found.")
+        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
         if data_path is None:
             raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+            raise ValueError("File types should be identical.")
     else:
-        raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
+        raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
     if dataset_attr.load_from == "ms_hub":
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import MsDataset
-        from modelscope.utils.config_ds import MS_DATASETS_CACHE
+        from modelscope import MsDataset # type: ignore
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
         cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
         dataset = MsDataset.load(
@@ -98,10 +97,27 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()
+    elif dataset_attr.load_from == "om_hub":
+        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
dataset = OmDataset.load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.om_hub_token,
streaming=data_args.streaming,
)
else: else:
dataset = load_dataset( dataset = load_dataset(
path=data_path, path=data_path,
...@@ -111,13 +127,10 @@ def _load_single_dataset( ...@@ -111,13 +127,10 @@ def _load_single_dataset(
split=dataset_attr.split, split=dataset_attr.split,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token, token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")), streaming=data_args.streaming,
trust_remote_code=True, trust_remote_code=True,
) )
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming: if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
...@@ -128,7 +141,7 @@ def _load_single_dataset( ...@@ -128,7 +141,7 @@ def _load_single_dataset(
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes) dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
if data_args.max_samples is not None: # truncate dataset if data_args.max_samples is not None: # truncate dataset
max_samples = min(data_args.max_samples, len(dataset)) max_samples = min(data_args.max_samples, len(dataset))
...@@ -224,9 +237,9 @@ def get_dataset( ...@@ -224,9 +237,9 @@ def get_dataset(
# Load tokenized dataset # Load tokenized dataset
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path): if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.") logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path) dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {} dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict: if "train" in dataset_dict:
...@@ -277,8 +290,8 @@ def get_dataset( ...@@ -277,8 +290,8 @@ def get_dataset(
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if training_args.should_save: if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path) dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0) sys.exit(0)
......
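The local-file branch of `_load_single_dataset` above no longer infers the builder type inside the directory loop; it now derives it once from the first collected file and then verifies every file maps to the same type. A minimal sketch of that check, using a reduced stand-in for `FILEEXT2TYPE` (the real mapping lives in `extras.constants` and may contain more entries):

```python
import os
from typing import Dict, List

# reduced, illustrative stand-in for llamafactory.extras.constants.FILEEXT2TYPE
FILEEXT2TYPE: Dict[str, str] = {"json": "json", "jsonl": "json", "csv": "csv", "parquet": "parquet", "txt": "text"}


def infer_builder_type(data_files: List[str]) -> str:
    """Infer the datasets builder type from file extensions and require all files to agree."""
    data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
    if data_path is None:
        raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

    if any(data_path != FILEEXT2TYPE.get(os.path.splitext(f)[-1][1:], None) for f in data_files):
        raise ValueError("File types should be identical.")

    return data_path


# e.g. infer_builder_type(["data/a.jsonl", "data/b.json"]) -> "json"
```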
...@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence ...@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
from transformers.utils import cached_file from transformers.utils import cached_file
from ..extras.constants import DATA_CONFIG from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope from ..extras.misc import use_modelscope, use_openmind
@dataclass @dataclass
...@@ -30,7 +30,7 @@ class DatasetAttr: ...@@ -30,7 +30,7 @@ class DatasetAttr:
""" """
# basic configs # basic configs
load_from: Literal["hf_hub", "ms_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
dataset_name: str dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca" formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False ranking: bool = False
...@@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - ...@@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
config_path = os.path.join(dataset_dir, DATA_CONFIG) config_path = os.path.join(dataset_dir, DATA_CONFIG)
try: try:
with open(config_path, "r") as f: with open(config_path) as f:
dataset_info = json.load(f) dataset_info = json.load(f)
except Exception as err: except Exception as err:
if len(dataset_names) != 0: if len(dataset_names) != 0:
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
dataset_info = None dataset_info = None
dataset_list: List["DatasetAttr"] = [] dataset_list: List["DatasetAttr"] = []
for name in dataset_names: for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE if dataset_info is None: # dataset_dir is ONLINE
load_from = "ms_hub" if use_modelscope() else "hf_hub" if use_modelscope():
load_from = "ms_hub"
elif use_openmind():
load_from = "om_hub"
else:
load_from = "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name) dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr) dataset_list.append(dataset_attr)
continue continue
if name not in dataset_info: if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
has_hf_url = "hf_hub_url" in dataset_info[name] has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name] has_ms_url = "ms_hub_url" in dataset_info[name]
has_om_url = "om_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url: if has_hf_url or has_ms_url or has_om_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url): if has_ms_url and (use_modelscope() or not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
elif has_om_url and (use_openmind() or not has_hf_url):
dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
else: else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]: elif "script_url" in dataset_info[name]:
......
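The parser hunk above adds openMind (`om_hub`) as a third dataset hub next to Hugging Face and ModelScope. A condensed sketch of the selection order it implements; `use_modelscope`/`use_openmind` are modeled here as plain booleans, since their actual implementation in `extras.misc` is not part of this hunk:

```python
from typing import Dict, Optional


def select_hub(info: Optional[Dict[str, str]], use_modelscope: bool, use_openmind: bool) -> str:
    """Mirror the hub-priority logic in get_dataset_list: ModelScope, then openMind, then Hugging Face."""
    if info is None:  # dataset_dir is ONLINE and no dataset_info.json entry exists
        if use_modelscope:
            return "ms_hub"
        if use_openmind:
            return "om_hub"
        return "hf_hub"

    has_hf = "hf_hub_url" in info
    has_ms = "ms_hub_url" in info
    has_om = "om_hub_url" in info
    if has_ms and (use_modelscope or not has_hf):
        return "ms_hub"
    if has_om and (use_openmind or not has_hf):
        return "om_hub"
    if has_hf:
        return "hf_hub"
    raise ValueError("No hub URL found for this dataset.")  # real code falls through to script/file handling


# e.g. select_hub({"hf_hub_url": "x", "om_hub_url": "y"}, use_modelscope=False, use_openmind=True) -> "om_hub"
```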
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen from .processor_utils import infer_seqlen
...@@ -28,7 +28,7 @@ if TYPE_CHECKING: ...@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_feedback_example( def _encode_feedback_example(
...@@ -94,7 +94,9 @@ def preprocess_feedback_dataset( ...@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
...@@ -123,6 +125,6 @@ def preprocess_feedback_dataset( ...@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0: if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.") logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs return model_inputs
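This processor (like loader.py above) switches from `from ..extras.logging import get_logger` plus plain `logger.warning` to the new `extras.logging` module with `info_rank0`/`warning_rank0` helpers. The module's implementation is not shown in this commit; a plausible minimal sketch of such a rank-zero wrapper, assuming it simply suppresses messages on non-main processes:

```python
import logging
import os


class RankZeroLogger:
    """Thin wrapper whose *_rank0 methods only emit on the main process (a sketch, not the project's code)."""

    def __init__(self, name: str) -> None:
        self._logger = logging.getLogger(name)

    def _is_rank0(self) -> bool:
        # assumption: LOCAL_RANK is set by torchrun/accelerate; single-process runs default to rank 0
        return int(os.getenv("LOCAL_RANK", "0")) == 0

    def info_rank0(self, msg: str) -> None:
        if self._is_rank0():
            self._logger.info(msg)

    def warning_rank0(self, msg: str) -> None:
        if self._is_rank0():
            self._logger.warning(msg)


def get_logger(name: str) -> RankZeroLogger:
    return RankZeroLogger(name)
```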
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen from .processor_utils import infer_seqlen
...@@ -28,7 +28,7 @@ if TYPE_CHECKING: ...@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_pairwise_example( def _encode_pairwise_example(
...@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset( ...@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
...@@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr ...@@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))) print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"])) print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False))) print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"])) print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False))) print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"])) print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False))) print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import greedy_knapsack, infer_seqlen from .processor_utils import greedy_knapsack, infer_seqlen
...@@ -28,7 +28,7 @@ if TYPE_CHECKING: ...@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_supervised_example( def _encode_supervised_example(
...@@ -99,7 +99,9 @@ def preprocess_supervised_dataset( ...@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
...@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset( ...@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
length2indexes = defaultdict(list) length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
...@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset( ...@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
) )
length = len(input_ids) length = len(input_ids)
if length > data_args.cutoff_len: if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len)) logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
else: else:
lengths.append(length) lengths.append(length)
length2indexes[length].append(valid_num) length2indexes[length].append(valid_num)
...@@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: " ...@@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"])) print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))) print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")