Commit 317a82e2 authored by chenych

Add QWQ-32B

parent 37b0ad9f
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Sequence

from openai import OpenAI
from transformers.utils.versions import require_version


require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
    grade_to_score = {"A": 4, "B": 3, "C": 2}
    total_score, total_hour = 0, 0
    for grade, hour in zip(grades, hours):
        total_score += grade_to_score[grade] * hour
        total_hour += hour

    return round(total_score / total_hour, 2)


def main():
    client = OpenAI(
        api_key="{}".format(os.environ.get("API_KEY", "0")),
        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
    )
    tools = [
        {
            "type": "function",
            "function": {
                "name": "calculate_gpa",
                "description": "Calculate the Grade Point Average (GPA) based on grades and credit hours",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "grades": {"type": "array", "items": {"type": "string"}, "description": "The grades"},
                        "hours": {"type": "array", "items": {"type": "integer"}, "description": "The credit hours"},
                    },
                    "required": ["grades", "hours"],
                },
            },
        }
    ]
    tool_map = {"calculate_gpa": calculate_gpa}

    messages = []
    messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    if result.choices[0].message.tool_calls is None:
        raise ValueError("Cannot retrieve function call from the response.")

    messages.append(result.choices[0].message)
    tool_call = result.choices[0].message.tool_calls[0].function
    print(tool_call)
    # Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa')
    name, arguments = tool_call.name, json.loads(tool_call.arguments)
    tool_result = tool_map[name](**arguments)
    messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    print(result.choices[0].message.content)
    # Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42.


if __name__ == "__main__":
    main()
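
As a quick sanity check on the arithmetic the demo expects, calculate_gpa can also be called directly without any API server; a minimal sketch using the same grades and hours as the prompt above:

# Standalone check of the GPA math used in the example:
# (4*3 + 4*4 + 3*3 + 2*2) / (3 + 4 + 3 + 2) = 41 / 12 ≈ 3.42
print(calculate_gpa(["A", "A", "B", "C"], [3, 4, 3, 2]))  # 3.42
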
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
 # limitations under the License.
 import json
+from typing import Optional
 import fire
 from transformers import Seq2SeqTrainingArguments
@@ -45,14 +46,16 @@ def vllm_infer(
     top_k: int = 50,
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
+    seed: Optional[int] = None,
     pipeline_parallel_size: int = 1,
-    image_resolution: int = 512 * 512,
+    image_max_pixels: int = 768 * 768,
+    image_min_pixels: int = 32 * 32,
 ):
     r"""
     Performs batch generation using vLLM engine, which supports tensor parallelism.
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.6.5")
+    check_version("vllm>=0.4.3,<=0.7.3")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
@@ -86,7 +89,9 @@ def vllm_infer(
     for sample in dataset_module["train_dataset"]:
         if sample["images"]:
             multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
+                "image": template_obj.mm_plugin._regularize_images(
+                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )
             }
         else:
             multi_modal_data = None
@@ -105,6 +110,7 @@ def vllm_infer(
         stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
         max_tokens=generating_args.max_new_tokens,
         skip_special_tokens=False,
+        seed=seed,
     )
     if model_args.adapter_name_or_path is not None:
         lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ def get_requires() -> List[str]:
 def get_console_scripts() -> List[str]:
     console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
-    if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
+    if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
         console_scripts.append("lmf = llamafactory.cli:main")
     return console_scripts
@@ -44,9 +44,9 @@ def get_console_scripts() -> List[str]:
 extra_require = {
     "torch": ["torch>=1.13.1"],
-    "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
+    "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
+    "deepspeed": ["deepspeed>=0.10.0,<=0.16.2"],
     "liger-kernel": ["liger-kernel"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
@@ -54,7 +54,7 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.6.5"],
+    "vllm": ["vllm>=0.4.3,<=0.7.3"],
     "galore": ["galore-torch"],
     "apollo": ["apollo-torch"],
     "badam": ["badam>=1.2.1"],
@@ -69,7 +69,6 @@ extra_require = {
         "msgpack",
         "referencing",
         "jsonschema_specifications",
-        "librosa",
     ],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],
@@ -92,7 +91,7 @@ def main():
         url="https://github.com/hiyouga/LLaMA-Factory",
         package_dir={"": "src"},
         packages=find_packages("src"),
-        python_requires=">=3.8.0",
+        python_requires=">=3.9.0",
         install_requires=get_requires(),
         extras_require=extra_require,
         entry_points={"console_scripts": get_console_scripts()},
@@ -104,10 +103,10 @@ def main():
             "License :: OSI Approved :: Apache Software License",
             "Operating System :: OS Independent",
             "Programming Language :: Python :: 3",
-            "Programming Language :: Python :: 3.8",
             "Programming Language :: Python :: 3.9",
             "Programming Language :: Python :: 3.10",
             "Programming Language :: Python :: 3.11",
+            "Programming Language :: Python :: 3.12",
             "Topic :: Scientific/Engineering :: Artificial Intelligence",
         ],
     )
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,17 +20,17 @@ Level:
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.46.1
-    datasets>=2.16.0,<=3.1.0
-    accelerate>=0.34.0,<=1.0.1
+    transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
+    datasets>=2.16.0,<=3.2.0
+    accelerate>=0.34.0,<=1.2.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.46.1
+    transformers>=4.41.2,<4.48.0
   packing:
-    transformers>=4.43.0,<=4.46.1
+    transformers>=4.43.0
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 from ..data import Role as DataRole
 from ..extras import logging
+from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.misc import is_env_enabled
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
@@ -70,7 +72,8 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
 ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
-    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
+    if is_env_enabled("API_VERBOSE", "1"):
+        logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -99,10 +102,12 @@ def _process_request(
             content = json.dumps(tool_calls, ensure_ascii=False)
             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
         elif isinstance(message.content, list):
+            text_content = ""
             for input_item in message.content:
                 if input_item.type == "text":
-                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
+                    text_content += input_item.text
                 else:
+                    text_content += IMAGE_PLACEHOLDER
                     image_url = input_item.image_url.url
                     if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
                         image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
@@ -112,6 +117,8 @@ def _process_request(
                        image_stream = requests.get(image_url, stream=True).raw
                    images.append(Image.open(image_stream).convert("RGB"))
+            input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
     from vllm import AsyncLLMEngine
     from ..data import Template
-    from ..data.mm_plugin import ImageInput, VideoInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -68,6 +68,7 @@ class BaseEngine(ABC):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
@@ -83,6 +84,7 @@ class BaseEngine(ABC):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
...

@@ -27,7 +27,7 @@ from .vllm_engine import VllmEngine
 if TYPE_CHECKING:
-    from ..data.mm_plugin import ImageInput, VideoInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
     from .base_engine import BaseEngine, Response
@@ -66,13 +66,14 @@ class ChatModel:
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Gets a list of responses of the chat model.
         """
         task = asyncio.run_coroutine_threadsafe(
-            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
+            self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
         )
         return task.result()
@@ -83,12 +84,13 @@ class ChatModel:
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Asynchronously gets a list of responses of the chat model.
         """
-        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
 
     def stream_chat(
         self,
@@ -97,12 +99,13 @@ class ChatModel:
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""
         Gets the response token-by-token of the chat model.
         """
-        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -117,12 +120,15 @@ class ChatModel:
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
         Asynchronously gets the response token-by-token of the chat model.
         """
-        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
+        async for new_token in self.engine.stream_chat(
+            messages, system, tools, images, videos, audios, **input_kwargs
+        ):
             yield new_token
 
     def get_scores(
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from trl import PreTrainedModelWrapper
     from ..data import Template
-    from ..data.mm_plugin import ImageInput, VideoInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -81,9 +81,10 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
-        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
+        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
         if images is not None:
             mm_input_dict.update({"images": images, "imglens": [len(images)]})
             if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
@@ -94,14 +95,25 @@ class HuggingfaceEngine(BaseEngine):
             if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                 messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+        if audios is not None:
+            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
+            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
         messages = template.mm_plugin.process_messages(
-            messages, mm_input_dict["images"], mm_input_dict["videos"], processor
+            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or generating_args["default_system"]
         prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
         prompt_ids, _ = template.mm_plugin.process_token_ids(
-            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
+            prompt_ids,
+            None,
+            mm_input_dict["images"],
+            mm_input_dict["videos"],
+            mm_input_dict["audios"],
+            tokenizer,
+            processor,
         )
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
@@ -114,6 +126,7 @@ class HuggingfaceEngine(BaseEngine):
         num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
         repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
         length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
@@ -133,6 +146,9 @@ class HuggingfaceEngine(BaseEngine):
                 if repetition_penalty is not None
                 else generating_args["repetition_penalty"],
                 length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
+                skip_special_tokens=skip_special_tokens
+                if skip_special_tokens is not None
+                else generating_args["skip_special_tokens"],
                 eos_token_id=template.get_stop_token_ids(tokenizer),
                 pad_token_id=tokenizer.pad_token_id,
             )
@@ -166,9 +182,11 @@ class HuggingfaceEngine(BaseEngine):
         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
         for key, value in mm_inputs.items():
-            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
+            if isinstance(value, list) and isinstance(value[0], torch.Tensor):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
-            elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+            elif (
+                isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
+            ):  # for minicpmv inputs
                 value = torch.stack([torch.stack(v) for v in value])
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
@@ -176,12 +194,18 @@ class HuggingfaceEngine(BaseEngine):
             if torch.is_floating_point(value):  # cast data dtype for paligemma
                 value = value.to(model.dtype)
-            gen_kwargs[key] = value.to(model.device)
+            if key == "second_per_grid_ts":  # qwen2.5vl special case
+                gen_kwargs[key] = value.tolist()
+            else:
+                gen_kwargs[key] = value.to(model.device)
 
         if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
             gen_kwargs["input_ids"] = inputs
-            del gen_kwargs["image_sizes"]
             gen_kwargs["tokenizer"] = tokenizer
+            if "audio_feature_lens" in mm_inputs:
+                gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
+
+        gen_kwargs.pop("image_sizes", None)
 
         return gen_kwargs, prompt_length
@@ -198,6 +222,7 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
@@ -211,6 +236,7 @@ class HuggingfaceEngine(BaseEngine):
             tools,
             images,
             videos,
+            audios,
             input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
@@ -219,7 +245,9 @@ class HuggingfaceEngine(BaseEngine):
         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
-            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+            response_ids,
+            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
+            clean_up_tokenization_spaces=True,
         )
         results = []
         for i in range(len(response)):
@@ -249,6 +277,7 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
@@ -262,10 +291,13 @@ class HuggingfaceEngine(BaseEngine):
             tools,
             images,
             videos,
+            audios,
             input_kwargs,
         )
         streamer = TextIteratorStreamer(
-            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+            tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
         )
         gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
@@ -309,6 +341,7 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -326,6 +359,7 @@ class HuggingfaceEngine(BaseEngine):
             tools,
             images,
             videos,
+            audios,
             input_kwargs,
         )
         async with self.semaphore:
@@ -340,6 +374,7 @@ class HuggingfaceEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -357,6 +392,7 @@ class HuggingfaceEngine(BaseEngine):
             tools,
             images,
             videos,
+            audios,
             input_kwargs,
         )
         async with self.semaphore:
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,27 +19,22 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
-from ..extras.packages import is_pillow_available, is_vllm_available
+from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
 
-if is_pillow_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageObject
-
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
 
 if TYPE_CHECKING:
-    from ..data.mm_plugin import ImageInput, VideoInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -54,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
             quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
@@ -109,10 +105,11 @@ class VllmEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
-        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
+        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
         if images is not None:
             mm_input_dict.update({"images": images, "imglens": [len(images)]})
             if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
@@ -123,8 +120,13 @@ class VllmEngine(BaseEngine):
             if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                 messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+        if audios is not None:
+            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
+            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
         messages = self.template.mm_plugin.process_messages(
-            messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
+            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
@@ -137,6 +139,7 @@ class VllmEngine(BaseEngine):
         num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
         repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
         length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
@@ -170,19 +173,19 @@ class VllmEngine(BaseEngine):
             stop=stop,
             stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
             max_tokens=max_tokens,
-            skip_special_tokens=self.generating_args["skip_special_tokens"],
+            skip_special_tokens=skip_special_tokens
+            if skip_special_tokens is not None
+            else self.generating_args["skip_special_tokens"],
         )
 
         if images is not None:  # add image features
-            multi_modal_data = {"image": []}
-            for image in images:
-                if not isinstance(image, (str, ImageObject)):
-                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
-
-                if isinstance(image, str):
-                    image = Image.open(image).convert("RGB")
-
-                multi_modal_data["image"].append(image)
+            multi_modal_data = {
+                "image": self.template.mm_plugin._regularize_images(
+                    images,
+                    image_max_pixels=self.model_args.image_max_pixels,
+                    image_min_pixels=self.model_args.image_min_pixels,
+                )
+            }
         else:
             multi_modal_data = None
@@ -202,10 +205,11 @@ class VllmEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
-        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
         async for request_output in generator:
             final_output = request_output
@@ -230,10 +234,11 @@ class VllmEngine(BaseEngine):
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
-        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
         async for result in generator:
             delta_text = result.outputs[0].text[len(generated_text) :]
             generated_text = result.outputs[0].text
...

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count, use_ray
+from .extras.misc import get_device_count, is_env_enabled, use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -86,7 +86,7 @@ def main():
     elif command == Command.EXPORT:
         export_model()
     elif command == Command.TRAIN:
-        force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+        force_torchrun = is_env_enabled("FORCE_TORCHRUN")
         if force_torchrun or (get_device_count() > 1 and not use_ray()):
            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
            master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
...
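
For context, is_env_enabled itself is not part of this diff. A minimal sketch of the behavior the call sites above (API_VERBOSE, FORCE_TORCHRUN) appear to rely on; the body below is an assumption rather than the repository's verbatim helper, mirroring the truthy values accepted by the ENABLE_SHORT_CONSOLE check in setup.py:

import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # Assumed sketch: treat "true", "y" and "1" (case-insensitive) as enabled.
    return os.getenv(env_var, default).lower() in ["true", "y", "1"]
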
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .mm_plugin import ImageInput, VideoInput
    from .parser import DatasetAttr


logger = logging.get_logger(__name__)


def _convert_images(
    images: Union["ImageInput", Sequence["ImageInput"]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
    r"""
    Optionally concatenates image path to dataset dir when loading from local disk.
    """
    if not isinstance(images, list):
        images = [images]
    elif len(images) == 0:
        return None
    else:
        images = images[:]

    if dataset_attr.load_from in ["script", "file"]:
        for i in range(len(images)):
            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
                images[i] = os.path.join(data_args.image_dir, images[i])

    return images


def _convert_videos(
    videos: Union["VideoInput", Sequence["VideoInput"]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
    r"""
    Optionally concatenates video path to dataset dir when loading from local disk.
    """
    if not isinstance(videos, list):
        videos = [videos]
    elif len(videos) == 0:
        return None
    else:
        videos = videos[:]

    if dataset_attr.load_from in ["script", "file"]:
        for i in range(len(videos)):
            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
                videos[i] = os.path.join(data_args.image_dir, videos[i])

    return videos


def convert_alpaca(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts alpaca format dataset to the standard format.
    """
    prompt = []
    if dataset_attr.history and isinstance(example[dataset_attr.history], list):
        for old_prompt, old_response in example[dataset_attr.history]:
            prompt.append({"role": Role.USER.value, "content": old_prompt})
            prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

    query = []
    if dataset_attr.prompt and example[dataset_attr.prompt]:
        query.append(example[dataset_attr.prompt])

    if dataset_attr.query and example[dataset_attr.query]:
        query.append(example[dataset_attr.query])

    prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], str)
        and isinstance(example[dataset_attr.rejected], str)
    ):  # pairwise example
        response = [
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
        ]
    elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
    else:  # unsupervised
        response = []

    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": example[dataset_attr.system] if dataset_attr.system else "",
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
    }
    return output
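
For orientation, a small sketch of what convert_alpaca yields for a plain instruction/response record, assuming the usual alpaca column mapping (prompt="instruction", query="input", response="output"); the record values are invented for illustration:

# Hypothetical alpaca-style record (invented values).
example = {"instruction": "Translate to French.", "input": "Good morning", "output": "Bonjour"}

# convert_alpaca(example, dataset_attr, data_args) then returns the standard format:
# {
#     "_prompt": [{"role": "user", "content": "Translate to French.\nGood morning"}],  # "prompt\nquery"
#     "_response": [{"role": "assistant", "content": "Bonjour"}],
#     "_system": "",
#     "_tools": "",
#     "_images": None,
#     "_videos": None,
# }
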
def convert_sharegpt(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts sharegpt format dataset to the standard format.
    """
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
        dataset_attr.observation_tag: Role.OBSERVATION.value,
        dataset_attr.function_tag: Role.FUNCTION.value,
        dataset_attr.system_tag: Role.SYSTEM.value,
    }
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
    messages = example[dataset_attr.messages]
    if (
        dataset_attr.system_tag
        and len(messages) != 0
        and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
    ):
        system = messages[0][dataset_attr.content_tag]
        messages = messages[1:]
    else:
        system = example[dataset_attr.system] if dataset_attr.system else ""

    aligned_messages = []
    broken_data = False
    for turn_idx, message in enumerate(messages):
        if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
            logger.warning_rank0(f"Invalid role tag in {messages}.")
            broken_data = True

        aligned_messages.append(
            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
        )

    if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
        dataset_attr.ranking and len(aligned_messages) % 2 == 0
    ):
        logger.warning_rank0(f"Invalid message count in {messages}.")
        broken_data = True

    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], dict)
        and isinstance(example[dataset_attr.rejected], dict)
    ):  # pairwise example
        chosen = example[dataset_attr.chosen]
        rejected = example[dataset_attr.rejected]
        if (
            chosen[dataset_attr.role_tag] not in accept_tags[-1]
            or rejected[dataset_attr.role_tag] not in accept_tags[-1]
        ):
            logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
            broken_data = True

        prompt = aligned_messages
        response = [
            {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
            {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
        ]
    else:  # normal example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]

    if broken_data:
        logger.warning_rank0("Skipping this abnormal example.")
        prompt, response = [], []

    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": system,
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
    }
    return output


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _tools: "...",
        _images: [],
        _videos: [],
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
    else:
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    return dataset.map(
        convert_func,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
@@ -18,11 +18,12 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
+import numpy as np
 import torch
 import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
-from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
 from ..extras.packages import is_pillow_available
@@ -80,7 +81,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
     Data collator that supports VLMs.
-    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
+    Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
     """
     template: Optional["Template"] = None
@@ -91,26 +92,54 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             raise ValueError("Template is required for MultiModalDataCollator.")
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
+        batch_images, batch_videos, batch_audios = [], [], []
+        batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
+            audios = feature.pop("audios", None) or []
             batch_images.extend(images)
             batch_videos.extend(videos)
+            batch_audios.extend(audios)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
+            batch_audlens.append(len(audios))
             batch_input_ids.append(feature["input_ids"])
+        fake_input_ids = []
         if (
-            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+            self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
         ):  # avoid process hanging in zero3/fsdp case
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
-            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
-            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
-            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
-                fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
-            )
+            fake_messages = self.template.mm_plugin.process_messages(
+                fake_messages, fake_images, [], [], self.processor
+            )
+            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                _fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
+            )
+            fake_input_ids.extend(_fake_input_ids)
+            batch_images = fake_images
+            batch_imglens[0] = 1
+        if (
+            self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
+        ):  # avoid process hanging in zero3/fsdp case
+            fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
+            fake_audios = [np.zeros(1600)]
+            fake_messages = self.template.mm_plugin.process_messages(
+                fake_messages, [], [], fake_audios, self.processor
+            )
+            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                _fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
+            )
+            fake_input_ids.extend(_fake_input_ids)
+            batch_audios = fake_audios
+            batch_audlens[0] = 1
+        if len(fake_input_ids) != 0:
             if self.tokenizer.padding_side == "right":
                 features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                 features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
@@ -120,12 +149,17 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
                 features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
-            batch_images = fake_images
-            batch_imglens[0] = 1
             batch_input_ids[0] = features[0]["input_ids"]
         mm_inputs = self.template.mm_plugin.get_mm_inputs(
-            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
+            batch_images,
+            batch_videos,
+            batch_audios,
+            batch_imglens,
+            batch_vidlens,
+            batch_audlens,
+            batch_input_ids,
+            self.processor,
         )
         if "token_type_ids" in mm_inputs:
             token_type_ids = mm_inputs.pop("token_type_ids")
@@ -135,12 +169,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
         if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
-            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
-                input_ids=features["input_ids"],
-                image_grid_thw=mm_inputs.get("image_grid_thw", None),
-                video_grid_thw=mm_inputs.get("video_grid_thw", None),
-                attention_mask=features["attention_mask"],
-            )
+            rope_index_kwargs = {
+                "input_ids": features["input_ids"],
+                "image_grid_thw": mm_inputs.get("image_grid_thw"),
+                "video_grid_thw": mm_inputs.get("video_grid_thw"),
+                "attention_mask": features["attention_mask"],
+            }
+            if "second_per_grid_ts" in mm_inputs:
+                rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
@@ -149,8 +187,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
         features.update(mm_inputs)
-        if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
-            features = features.data  # use default_collate() instead of BatchEncoding.to()
         if "image_bound" in features:  # for minicpmv inputs
             bsz, seq_length = features["input_ids"].shape
@@ -204,6 +240,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                     "labels": feature[f"{key}_labels"],
                     "images": feature["images"],
                     "videos": feature["videos"],
+                    "audios": feature["audios"],
                 }
                 concatenated_features.append(target_feature)
@@ -227,6 +264,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "labels": feature["labels"],
                 "images": feature["images"],
                 "videos": feature["videos"],
+                "audios": feature["audios"],
             }
             kl_feature = {
                 "input_ids": feature["kl_input_ids"],
@@ -234,6 +272,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "labels": feature["kl_labels"],
                 "images": feature["images"],
                 "videos": feature["videos"],
+                "audios": feature["audios"],
             }
             target_features.append(target_feature)
             kl_features.append(kl_feature)
@@ -244,6 +283,8 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
         batch["kl_input_ids"] = kl_batch["input_ids"]
         batch["kl_attention_mask"] = kl_batch["attention_mask"]
         batch["kl_labels"] = kl_batch["labels"]
+        if "cross_attention_mask" in kl_batch:  # for mllama inputs.
+            batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
         if "token_type_ids" in kl_batch:
             batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
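# Minimal, self-contained sketch (not part of the commit) of the zero3/fsdp workaround shown in the
# diff above: when a batch carries no real multimodal input, a few "fake" token ids are attached to
# the first feature with zero attention and ignored labels, so the batch still exercises the
# multimodal path without affecting the loss. IGNORE_INDEX is assumed to be -100, as in transformers.
from typing import Dict, List

IGNORE_INDEX = -100

def attach_fake_input_ids(
    feature: Dict[str, List[int]], fake_input_ids: List[int], padding_side: str = "right"
) -> Dict[str, List[int]]:
    """Append (or prepend) fake ids that are masked out of both attention and loss."""
    pad_mask = [0] * len(fake_input_ids)
    pad_labels = [IGNORE_INDEX] * len(fake_input_ids)
    if padding_side == "right":
        feature["input_ids"] = feature["input_ids"] + fake_input_ids
        feature["attention_mask"] = feature["attention_mask"] + pad_mask
        feature["labels"] = feature["labels"] + pad_labels
    else:
        feature["input_ids"] = fake_input_ids + feature["input_ids"]
        feature["attention_mask"] = pad_mask + feature["attention_mask"]
        feature["labels"] = pad_labels + feature["labels"]
    return feature

# Example:
# attach_fake_input_ids({"input_ids": [1, 2], "attention_mask": [1, 1], "labels": [-100, 2]}, [9, 9])
# -> {"input_ids": [1, 2, 9, 9], "attention_mask": [1, 1, 0, 0], "labels": [-100, 2, -100, -100]}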
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
from ..extras import logging
from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .parser import DatasetAttr
logger = logging.get_logger(__name__)
@dataclass
class DatasetConverter:
dataset_attr: "DatasetAttr"
data_args: "DataArguments"
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
r"""
Optionally concatenates media path to media dir when loading from local disk.
"""
if not isinstance(medias, list):
medias = [medias] if medias is not None else []
elif len(medias) == 0:
return None
else:
medias = medias[:]
if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
for i in range(len(medias)):
if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
medias[i] = os.path.join(self.data_args.media_dir, medias[i])
else:
logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
return medias
@abstractmethod
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
r"""
Converts a single example in the dataset to the standard format.
"""
...
@dataclass
class AlpacaDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
prompt = []
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
for old_prompt, old_response in example[self.dataset_attr.history]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
query = []
if self.dataset_attr.prompt and example[self.dataset_attr.prompt]:
query.append(example[self.dataset_attr.prompt])
if self.dataset_attr.query and example[self.dataset_attr.query]:
query.append(example[self.dataset_attr.query])
prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
if example[self.dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
self.dataset_attr.ranking
and isinstance(example[self.dataset_attr.chosen], str)
and isinstance(example[self.dataset_attr.rejected], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]},
{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]},
]
elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
else: # unsupervised
response = []
output = {
"_prompt": prompt,
"_response": response,
"_system": example[self.dataset_attr.system] if self.dataset_attr.system else "",
"_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
"_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
"_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
"_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
}
return output
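# Illustrative sketch (not part of the commit): a hypothetical Alpaca-style record and the aligned
# output AlpacaDatasetConverter is expected to produce, assuming the default column names
# (instruction / input / output); note that prompt and query are joined with a newline.
_example_alpaca_record = {
    "instruction": "Translate to French.",
    "input": "Good morning",
    "output": "Bonjour",
}
_expected_alpaca_output = {
    "_prompt": [{"role": "user", "content": "Translate to French.\nGood morning"}],
    "_response": [{"role": "assistant", "content": "Bonjour"}],
    "_system": "",
    "_tools": "",
    "_images": None,
    "_videos": None,
    "_audios": None,
}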
@dataclass
class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
self.dataset_attr.observation_tag: Role.OBSERVATION.value,
self.dataset_attr.function_tag: Role.FUNCTION.value,
self.dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
messages = example[self.dataset_attr.messages]
if (
self.dataset_attr.system_tag
and len(messages) != 0
and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
):
system = messages[0][self.dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[self.dataset_attr.system] if self.dataset_attr.system else ""
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True
break
aligned_messages.append(
{
"role": tag_mapping[message[self.dataset_attr.role_tag]],
"content": message[self.dataset_attr.content_tag],
}
)
if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True
if broken_data:
logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], []
elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if example[self.dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
self.dataset_attr.ranking
and isinstance(example[self.dataset_attr.chosen], dict)
and isinstance(example[self.dataset_attr.rejected], dict)
): # pairwise example
chosen = example[self.dataset_attr.chosen]
rejected = example[self.dataset_attr.rejected]
if (
chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True
prompt = aligned_messages
response = [
{
"role": tag_mapping[chosen[self.dataset_attr.role_tag]],
"content": chosen[self.dataset_attr.content_tag],
},
{
"role": tag_mapping[rejected[self.dataset_attr.role_tag]],
"content": rejected[self.dataset_attr.content_tag],
},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
"_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
"_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
"_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
}
return output
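# Illustrative sketch (not part of the commit): a hypothetical pairwise (ranking) record and the
# aligned output SharegptDatasetConverter is expected to produce when ranking is enabled and the
# default tags ("conversations", "from"/"value", "human"/"gpt", "chosen"/"rejected") are used.
_example_pairwise_record = {
    "conversations": [{"from": "human", "value": "Name a prime number."}],
    "chosen": {"from": "gpt", "value": "7"},
    "rejected": {"from": "gpt", "value": "8"},
}
_expected_pairwise_output = {
    "_prompt": [{"role": "user", "content": "Name a prime number."}],
    "_response": [
        {"role": "assistant", "content": "7"},
        {"role": "assistant", "content": "8"},
    ],
    "_system": "",
    "_tools": "",
    "_images": None,
    "_videos": None,
    "_audios": None,
}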
DATASET_CONVERTERS = {
"alpaca": AlpacaDatasetConverter,
"sharegpt": SharegptDatasetConverter,
}
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
r"""
Register a new dataset converter.
"""
if name in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} already exists.")
DATASET_CONVERTERS[name] = dataset_converter
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
r"""
Gets a dataset converter.
"""
if name not in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} not found.")
return DATASET_CONVERTERS[name](dataset_attr, data_args)
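# Usage sketch (assumption, not part of the commit): how a project-specific converter could be
# plugged into the registry above. The "qa" name and the "question"/"answer" columns are
# hypothetical and only serve as illustration.
@dataclass
class _ExampleQADatasetConverter(DatasetConverter):
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        return {
            "_prompt": [{"role": Role.USER.value, "content": example["question"]}],
            "_response": [{"role": Role.ASSISTANT.value, "content": example["answer"]}],
            "_system": "",
            "_tools": "",
            "_images": None,
            "_videos": None,
            "_audios": None,
        }

# register_dataset_converter("qa", _ExampleQADatasetConverter)
# get_dataset_converter("qa", dataset_attr, data_args) would then return an instance of it.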
def align_dataset(
dataset: Union["Dataset", "IterableDataset"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
_audios: [],
"""
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Converting format of dataset",
)
dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
return dataset.map(
dataset_converter,
batched=False,
remove_columns=column_names,
**kwargs,
)
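# Usage sketch (assumption, not part of the commit): align_dataset is normally driven by the training
# pipeline, but the same mapping can be exercised directly. SimpleNamespace stands in for
# DatasetAttr / DataArguments here purely for illustration; the toy record is hypothetical.
from types import SimpleNamespace

_toy_attr = SimpleNamespace(
    prompt="instruction", query="input", response="output", history=None, kto_tag=None,
    ranking=False, system=None, tools=None, images=None, videos=None, audios=None,
)
_toy_args = SimpleNamespace(media_dir="")
# get_dataset_converter("alpaca", _toy_attr, _toy_args)({"instruction": "Say hi.", "input": "", "output": "Hi!"})
# -> {"_prompt": [{"role": "user", "content": "Say hi."}],
#     "_response": [{"role": "assistant", "content": "Hi!"}], "_system": "", "_tools": "",
#     "_images": None, "_videos": None, "_audios": None}
# The same callable is what align_dataset hands to dataset.map(..., remove_columns=column_names).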
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...