Commit 7ea81099 authored by chenych

update llama4

parent 84987715
-# Copyright 2024 THUDM and the LlamaFactory team.
+# Copyright 2025 THUDM and the LlamaFactory team.
#
# This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
@@ -17,13 +17,15 @@
import asyncio
import os
+from collections.abc import AsyncGenerator, Generator
from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Optional
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
+from .sglang_engine import SGLangEngine
from .vllm_engine import VllmEngine
@@ -38,20 +40,21 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel:
-r"""
-General class for chat models. Backed by huggingface or vllm engines.
+r"""General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores().
"""
-def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
if model_args.infer_backend == EngineName.HF:
-self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.VLLM:
-self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+elif model_args.infer_backend == EngineName.SGLANG:
+self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
@@ -61,17 +64,15 @@ class ChatModel:
def chat(
self,
-messages: Sequence[Dict[str, str]],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
-) -> List["Response"]:
-r"""
-Gets a list of responses of the chat model.
-"""
+) -> list["Response"]:
+r"""Get a list of responses of the chat model."""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
)
@@ -79,32 +80,28 @@ class ChatModel:
async def achat(
self,
-messages: Sequence[Dict[str, str]],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
-) -> List["Response"]:
-r"""
-Asynchronously gets a list of responses of the chat model.
-"""
+) -> list["Response"]:
+r"""Asynchronously get a list of responses of the chat model."""
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
def stream_chat(
self,
-messages: Sequence[Dict[str, str]],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> Generator[str, None, None]:
-r"""
-Gets the response token-by-token of the chat model.
-"""
+r"""Get the response token-by-token of the chat model."""
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
while True:
try:
@@ -115,17 +112,15 @@ class ChatModel:
async def astream_chat(
self,
-messages: Sequence[Dict[str, str]],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
-r"""
-Asynchronously gets the response token-by-token of the chat model.
-"""
+r"""Asynchronously get the response token-by-token of the chat model."""
async for new_token in self.engine.stream_chat(
messages, system, tools, images, videos, audios, **input_kwargs
):
@@ -133,23 +128,19 @@ class ChatModel:
def get_scores(
self,
-batch_input: List[str],
+batch_input: list[str],
**input_kwargs,
-) -> List[float]:
-r"""
-Gets a list of scores of the reward model.
-"""
+) -> list[float]:
+r"""Get a list of scores of the reward model."""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
async def aget_scores(
self,
-batch_input: List[str],
+batch_input: list[str],
**input_kwargs,
-) -> List[float]:
-r"""
-Asynchronously gets a list of scores of the reward model.
-"""
+) -> list[float]:
+r"""Asynchronously get a list of scores of the reward model."""
return await self.engine.get_scores(batch_input, **input_kwargs)
...
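For reference, a minimal usage sketch of the ChatModel wrapper above (not part of this commit; the argument keys and template name are illustrative assumptions, see get_infer_args for the accepted options):

```python
# Minimal sketch: drive ChatModel from synchronous code with the HF backend.
from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "path/to/model",  # hypothetical local path
    "template": "llama3",                   # assumed template name
    "infer_backend": "huggingface",
})

messages = [{"role": "user", "content": "Hello!"}]

# Sync API: chat() schedules achat() on the background event loop.
for response in chat_model.chat(messages):
    print(response.response_text)

# Streaming API: stream_chat() yields the reply token by token.
for token in chat_model.stream_chat(messages):
    print(token, end="", flush=True)
```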
@@ -13,10 +13,10 @@
# limitations under the License.
import asyncio
-import concurrent.futures
import os
+from collections.abc import AsyncGenerator
from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer
@@ -76,15 +76,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
-generating_args: Dict[str, Any],
-messages: Sequence[Dict[str, str]],
+generating_args: dict[str, Any],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
-input_kwargs: Optional[Dict[str, Any]] = {},
-) -> Tuple[Dict[str, Any], int]:
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
+input_kwargs: Optional[dict[str, Any]] = {},
+) -> tuple[dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
@@ -130,7 +130,7 @@ class HuggingfaceEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if stop is not None:
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
@@ -217,15 +217,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
-generating_args: Dict[str, Any],
-messages: Sequence[Dict[str, str]],
+generating_args: dict[str, Any],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
-input_kwargs: Optional[Dict[str, Any]] = {},
-) -> List["Response"]:
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
+input_kwargs: Optional[dict[str, Any]] = {},
+) -> list["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model,
tokenizer,
@@ -272,14 +272,14 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
-generating_args: Dict[str, Any],
-messages: Sequence[Dict[str, str]],
+generating_args: dict[str, Any],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
-input_kwargs: Optional[Dict[str, Any]] = {},
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
+input_kwargs: Optional[dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model,
@@ -317,12 +317,12 @@ class HuggingfaceEngine(BaseEngine):
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
-batch_input: List[str],
-input_kwargs: Optional[Dict[str, Any]] = {},
-) -> List[float]:
+batch_input: list[str],
+input_kwargs: Optional[dict[str, Any]] = {},
+) -> list[float]:
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
-inputs: Dict[str, "torch.Tensor"] = tokenizer(
+inputs: dict[str, torch.Tensor] = tokenizer(
batch_input,
padding=True,
truncation=True,
@@ -330,25 +330,24 @@ class HuggingfaceEngine(BaseEngine):
return_tensors="pt",
add_special_tokens=False,
).to(device)
-values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
@override
async def chat(
self,
-messages: Sequence[Dict[str, str]],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
-) -> List["Response"]:
+) -> list["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
-loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
@@ -364,24 +363,22 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
async with self.semaphore:
-with concurrent.futures.ThreadPoolExecutor() as pool:
-return await loop.run_in_executor(pool, self._chat, *input_args)
+return await asyncio.to_thread(self._chat, *input_args)
@override
async def stream_chat(
self,
-messages: Sequence[Dict[str, str]],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
-loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
@@ -397,25 +394,22 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
async with self.semaphore:
-with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
try:
-yield await loop.run_in_executor(pool, stream)
+yield await asyncio.to_thread(stream)
except StopAsyncIteration:
break
@override
async def get_scores(
self,
-batch_input: List[str],
+batch_input: list[str],
**input_kwargs,
-) -> List[float]:
+) -> list[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
-loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore:
-with concurrent.futures.ThreadPoolExecutor() as pool:
-return await loop.run_in_executor(pool, self._get_scores, *input_args)
+return await asyncio.to_thread(self._get_scores, *input_args)
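The recurring change in this file replaces a per-call ThreadPoolExecutor plus loop.run_in_executor with asyncio.to_thread, which dispatches the blocking call to the event loop's default executor instead of building a new pool for every request. A standalone sketch of the two patterns (illustrative only, not taken from the repository):

```python
import asyncio
import concurrent.futures


def blocking_generate(prompt: str) -> str:
    # Stand-in for the blocking HuggingFace generate call.
    return prompt.upper()


async def old_style(prompt: str) -> str:
    # Before: build a fresh thread pool and dispatch to it explicitly.
    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        return await loop.run_in_executor(pool, blocking_generate, prompt)


async def new_style(prompt: str) -> str:
    # After: asyncio.to_thread (Python >= 3.9) reuses the default executor.
    return await asyncio.to_thread(blocking_generate, prompt)


print(asyncio.run(new_style("hello")))
```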
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import atexit
import json
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
import requests
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_sglang_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from .base_engine import BaseEngine, Response
if is_sglang_available():
from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore
if TYPE_CHECKING:
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class SGLangEngine(BaseEngine):
"""Inference engine for SGLang models.
This class wraps the SGLang engine to provide a consistent interface for text generation
that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
better interaction and performance. The engine launches a server process and communicates
with it via HTTP requests.
For more details on the SGLang HTTP server approach, see:
https://docs.sglang.ai/backend/send_request.html
"""
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.SGLANG
self.model_args = model_args
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
self.generating_args = generating_args.to_dict()
launch_cmd = [
"python3 -m sglang.launch_server",
f"--model-path {model_args.model_name_or_path}",
f"--dtype {model_args.infer_dtype}",
f"--context-length {model_args.sglang_maxlen}",
f"--mem-fraction-static {model_args.sglang_mem_fraction}",
f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
f"--download-dir {model_args.cache_dir}",
"--log-level error",
]
launch_cmd = " ".join(launch_cmd)
logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
try:
torch_gc()
self.server_process, port = launch_server_cmd(launch_cmd)
self.base_url = f"http://localhost:{port}"
atexit.register(self._cleanup_server)
logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
wait_for_server(self.base_url, timeout=300)
logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")
try:
response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
if response.status_code == 200:
model_info = response.json()
logger.info(f"SGLang server model info: {model_info}")
except Exception as e:
logger.debug(f"Note: could not get model info: {str(e)}")
except Exception as e:
logger.error(f"Failed to start SGLang server: {str(e)}")
self._cleanup_server() # make sure to clean up any started process
raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")
def _cleanup_server(self):
r"""Clean up the server process when the engine is destroyed."""
if hasattr(self, "server_process") and self.server_process:
try:
logger.info("Terminating SGLang server process")
terminate_process(self.server_process)
logger.info("SGLang server process terminated")
except Exception as e:
logger.warning(f"Error terminating SGLang server: {str(e)}")
async def _generate(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncIterator[dict[str, Any]]:
if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if num_return_sequences != 1:
raise NotImplementedError("SGLang only supports n=1.")
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
if max_new_tokens:
max_tokens = max_new_tokens
sampling_params = {
"temperature": temperature if temperature is not None else self.generating_args["temperature"],
"top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
"top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0
"stop": stop,
"stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
"max_new_tokens": max_tokens,
"repetition_penalty": (
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
)
or 1.0, # repetition_penalty must > 0
"skip_special_tokens": skip_special_tokens
if skip_special_tokens is not None
else self.generating_args["skip_special_tokens"],
}
def stream_request():
json_data = {
"input_ids": prompt_ids,
"sampling_params": sampling_params,
"stream": True,
}
response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
if response.status_code != 200:
raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
for chunk in response.iter_lines(decode_unicode=False):
chunk = str(chunk.decode("utf-8"))
if chunk == "data: [DONE]":
break
if chunk and chunk.startswith("data:"):
yield json.loads(chunk[5:].strip("\n"))
return await asyncio.to_thread(stream_request)
@override
async def chat(
self,
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> list["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
for request_output in generator:
final_output = request_output
results = [
Response(
response_text=final_output["text"],
response_length=final_output["meta_info"]["completion_tokens"],
prompt_length=final_output["meta_info"]["prompt_tokens"],
finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
)
]
return results
@override
async def stream_chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
for result in generator:
delta_text = result["text"][len(generated_text) :]
generated_text = result["text"]
yield delta_text
@override
async def get_scores(
self,
batch_input: list[str],
**input_kwargs,
) -> list[float]:
raise NotImplementedError("SGLang engine does not support `get_scores`.")
def __del__(self):
r"""Ensure server is cleaned up when object is deleted."""
self._cleanup_server()
try:
atexit.unregister(self._cleanup_server)
except Exception:
pass
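To make the HTTP protocol used by stream_request() above concrete, here is a standalone client sketch against an already running SGLang server (the port is a hypothetical example; the payload and SSE framing mirror the engine code above):

```python
import json

import requests

BASE_URL = "http://localhost:30000"  # assumed server address

payload = {
    "input_ids": [1, 2, 3],  # prompt token ids produced by the chat template
    "sampling_params": {"max_new_tokens": 32, "temperature": 0.7},
    "stream": True,
}

with requests.post(f"{BASE_URL}/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        if line == "data: [DONE]":
            break
        chunk = json.loads(line[5:].strip())
        print(chunk["text"])  # cumulative text generated so far
```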
@@ -13,7 +13,8 @@
# limitations under the License.
import uuid
-from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import override
@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
self.model_args = model_args
config = load_config(model_args)  # may download model from ms hub
if getattr(config, "quantization_config", None):  # gptq models should use float16
-quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
@@ -82,7 +83,7 @@ class VllmEngine(BaseEngine):
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
-engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
@@ -101,33 +102,26 @@ class VllmEngine(BaseEngine):
async def _generate(
self,
-messages: Sequence[Dict[str, str]],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}"
-mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
-if images is not None:
-mm_input_dict.update({"images": images, "imglens": [len(images)]})
-if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
-if videos is not None:
-mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
-if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
-if audios is not None:
-mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
-if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(
-messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
+messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
@@ -143,7 +137,7 @@ class VllmEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if length_penalty is not None:
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
@@ -185,8 +179,24 @@ class VllmEngine(BaseEngine):
images,
image_max_pixels=self.model_args.image_max_pixels,
image_min_pixels=self.model_args.image_min_pixels,
-)
+)["images"]
+}
+elif videos is not None:
+multi_modal_data = {
+"video": self.template.mm_plugin._regularize_videos(
+videos,
+image_max_pixels=self.model_args.video_max_pixels,
+image_min_pixels=self.model_args.video_min_pixels,
+video_fps=self.model_args.video_fps,
+video_maxlen=self.model_args.video_maxlen,
+)["videos"]
}
+elif audios is not None:
+audio_data = self.template.mm_plugin._regularize_audios(
+audios,
+sampling_rate=self.model_args.audio_sampling_rate,
+)
+multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
else:
multi_modal_data = None
@@ -201,14 +211,14 @@ class VllmEngine(BaseEngine):
@override
async def chat(
self,
-messages: Sequence[Dict[str, str]],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
-) -> List["Response"]:
+) -> list["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
async for request_output in generator:
@@ -230,12 +240,12 @@ class VllmEngine(BaseEngine):
@override
async def stream_chat(
self,
-messages: Sequence[Dict[str, str]],
+messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
-images: Optional[Sequence["ImageInput"]] = None,
-videos: Optional[Sequence["VideoInput"]] = None,
-audios: Optional[Sequence["AudioInput"]] = None,
+images: Optional[list["ImageInput"]] = None,
+videos: Optional[list["VideoInput"]] = None,
+audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
@@ -248,7 +258,7 @@ class VllmEngine(BaseEngine):
@override
async def get_scores(
self,
-batch_input: List[str],
+batch_input: list[str],
**input_kwargs,
-) -> List[float]:
-raise NotImplementedError("vLLM engine does not support get_scores.")
+) -> list[float]:
+raise NotImplementedError("vLLM engine does not support `get_scores`.")
@@ -13,7 +13,6 @@
# limitations under the License.
import os
-import random
import subprocess
import sys
from enum import Enum, unique
@@ -24,7 +23,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count, is_env_enabled, use_ray
+from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
@@ -92,7 +91,7 @@ def main():
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
+master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
...
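The old default picked a random port in 20001-29999, which can collide with a port that is already bound; find_available_port (imported from extras.misc) asks the OS for a free one instead. A typical implementation of such a helper looks like this (sketch only, the repository's version may differ):

```python
import socket


def find_available_port() -> int:
    """Bind to port 0 so the OS assigns a currently free TCP port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


print(find_available_port())
```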
@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
+"TEMPLATES",
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
-"SFTDataCollatorWith4DAttentionMask",
"Role",
-"split_dataset",
-"get_dataset",
-"TEMPLATES",
+"SFTDataCollatorWith4DAttentionMask",
"Template",
+"get_dataset",
"get_template_and_fix_tokenizer",
+"split_dataset",
]
-# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
+# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
@@ -16,7 +16,7 @@
# limitations under the License.
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Optional
import numpy as np
import torch
@@ -24,6 +24,7 @@ import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.misc import get_current_device
from ..extras.packages import is_pillow_available
@@ -38,9 +39,10 @@ if TYPE_CHECKING:
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
-r"""
-Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
-while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
+r"""Expand 2d attention mask to 4d attention mask.
+
+Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.
e.g.
```python
@@ -62,24 +64,37 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
```
where `o` equals to `0.0`, `x` equals to `min_dtype`.
"""
-bsz, seq_len = attention_mask_with_indices.size()
+_, seq_len = attention_mask_with_indices.size()
+# Move to compute device if the source is CPU.
+source_device = attention_mask_with_indices.device
+compute_device = get_current_device() if source_device.type == "cpu" else source_device
+if compute_device != source_device:
+attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
min_dtype = torch.finfo(dtype).min
-expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
-# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
-padding_mask = torch.where(expanded_mask != 0, 1, 0)
-# Create a block-diagonal mask.
-attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
-# Use the lower triangular mask to zero out the upper triangular part
-attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+# Create a non-padding mask.
+non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+# Create indices for comparison.
+indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
+indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
+# Create a lower triangular mask.
+tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
+attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
# Invert the attention mask.
-attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
+# Move back to original device if needed.
+if compute_device != source_device:
+attention_mask_4d = attention_mask_4d.to(source_device)
return attention_mask_4d
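A small worked example of the rewritten prepare_4d_attention_mask (mine, not from the commit; it assumes the function is importable from llamafactory.data.collator). The packed mask [1, 1, 2, 2, 0] holds two documents plus one pad position; a position may only attend to earlier or equal positions of its own document, and pads attend to nothing:

```python
import torch

from llamafactory.data.collator import prepare_4d_attention_mask

mask_with_indices = torch.tensor([[1, 1, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(mask_with_indices, torch.float32)

print(mask_4d.shape)       # torch.Size([1, 1, 5, 5])
print(mask_4d[0, 0] == 0)  # True where attention is allowed:
# [[ True, False, False, False, False],
#  [ True,  True, False, False, False],
#  [False, False,  True, False, False],
#  [False, False,  True,  True, False],
#  [False, False, False, False, False]]
```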
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
-r"""
-Data collator that supports VLMs.
+r"""Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
"""
@@ -91,7 +106,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
-def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
for feature in features:
@@ -166,7 +181,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
-features: Dict[str, "torch.Tensor"] = super().__call__(features)
+features: dict[str, torch.Tensor] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
rope_index_kwargs = {
@@ -175,9 +190,27 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": features["attention_mask"],
}
-if "second_per_grid_ts" in mm_inputs:
+if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
+if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
+if feature_attention_mask is not None:
+audio_feature_lengths = torch.sum(
+feature_attention_mask, dim=1
+)  # FIXME need to get video image lengths
+rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input
+delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
+# avoid conflict
+new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
+features["position_ids"], features["rope_deltas"] = (
+new_position_ids.clone(),
+rope_deltas - delta0,
+)  # avoid inplace operation FIXME
+else:  # for qwen2vl
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
@@ -198,15 +231,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
-r"""
-Data collator for 4d attention mask.
-"""
+r"""Data collator for 4d attention mask."""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
-def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@@ -220,13 +251,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-r"""
-Data collator for pairwise data.
-"""
+r"""Data collator for pairwise data."""
-def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-r"""
-Pads batched data to the longest sequence in the batch.
+def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+r"""Pad batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
@@ -249,11 +277,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-r"""
-Data collator for KTO data.
-"""
+r"""Data collator for KTO data."""
-def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
target_features = []
kl_features = []
kto_tags = []
...
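As a side note on the 2 * n layout described in the pairwise collator's docstring above, a schematic of how a batch of two preference pairs is flattened (field names follow the chosen_*/rejected_* convention used for pairwise features, simplified here):

```python
# Schematic only: the collator emits all chosen examples first, then all rejected ones,
# so n preference pairs become 2 * n rows that share a single padding pass.
pairs = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [1, 2, 9]},
    {"chosen_input_ids": [4, 5], "rejected_input_ids": [4, 6, 7, 8]},
]

flat = [{"input_ids": p["chosen_input_ids"]} for p in pairs] + [
    {"input_ids": p["rejected_input_ids"]} for p in pairs
]
print(flat)  # chosen rows 0..n-1, rejected rows n..2n-1
```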
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import os import os
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union from typing import TYPE_CHECKING, Any, Optional, Union
from ..extras import logging from ..extras import logging
from .data_utils import Role from .data_utils import Role
...@@ -26,8 +26,12 @@ if TYPE_CHECKING: ...@@ -26,8 +26,12 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments from ..hparams import DataArguments
from .mm_plugin import AudioInput, ImageInput, VideoInput
from .parser import DatasetAttr from .parser import DatasetAttr
MediaType = Union[ImageInput, VideoInput, AudioInput]
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -36,12 +40,12 @@ class DatasetConverter: ...@@ -36,12 +40,12 @@ class DatasetConverter:
dataset_attr: "DatasetAttr" dataset_attr: "DatasetAttr"
data_args: "DataArguments" data_args: "DataArguments"
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]: def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
r""" r"""Optionally concatenate media path to media dir when loading from local disk."""
Optionally concatenates media path to media dir when loading from local disk. if medias is None:
""" return None
if not isinstance(medias, list): elif not isinstance(medias, list):
medias = [medias] if medias is not None else [] medias = [medias]
elif len(medias) == 0: elif len(medias) == 0:
return None return None
else: else:
...@@ -57,16 +61,14 @@ class DatasetConverter: ...@@ -57,16 +61,14 @@ class DatasetConverter:
return medias return medias
@abstractmethod @abstractmethod
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
r""" r"""Convert a single example in the dataset to the standard format."""
Converts a single example in the dataset to the standard format.
"""
... ...
@dataclass @dataclass
class AlpacaDatasetConverter(DatasetConverter): class AlpacaDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
prompt = [] prompt = []
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list): if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
for old_prompt, old_response in example[self.dataset_attr.history]: for old_prompt, old_response in example[self.dataset_attr.history]:
...@@ -116,7 +118,7 @@ class AlpacaDatasetConverter(DatasetConverter): ...@@ -116,7 +118,7 @@ class AlpacaDatasetConverter(DatasetConverter):
@dataclass @dataclass
class SharegptDatasetConverter(DatasetConverter): class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
tag_mapping = { tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value, self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value, self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
...@@ -216,10 +218,8 @@ DATASET_CONVERTERS = { ...@@ -216,10 +218,8 @@ DATASET_CONVERTERS = {
} }
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None: def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
r""" r"""Register a new dataset converter."""
Register a new dataset converter.
"""
if name in DATASET_CONVERTERS: if name in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} already exists.") raise ValueError(f"Dataset converter {name} already exists.")
...@@ -227,9 +227,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver ...@@ -227,9 +227,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter": def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
r""" r"""Get a dataset converter."""
Gets a dataset converter.
"""
if name not in DATASET_CONVERTERS: if name not in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} not found.") raise ValueError(f"Dataset converter {name} not found.")
...@@ -242,17 +240,17 @@ def align_dataset( ...@@ -242,17 +240,17 @@ def align_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""Align the dataset to a specific format.
Aligned dataset: Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1) _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..." _system: "..."
_tools: "...", _tools: "..."
_images: [], _images: []
_videos: [], _videos: []
_audios: [], _audios: []
""" """
column_names = list(next(iter(dataset)).keys()) column_names = list(next(iter(dataset)).keys())
kwargs = {} kwargs = {}
if not data_args.streaming: if not data_args.streaming:
......
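As a reading aid for the `align_dataset` docstring above, one aligned row would look roughly like the sketch below. Only the field names come from the diff; the concrete contents are invented for illustration.

# A hedged sketch of a single example after alignment; values are placeholders.
aligned_example = {
    "_prompt": [{"role": "user", "content": "What is 2 + 2?"}],
    "_response": [{"role": "assistant", "content": "4"}],
    "_system": "You are a helpful assistant.",
    "_tools": "",
    "_images": [],
    "_videos": [],
    "_audios": [],
}
print(aligned_example["_prompt"][0]["role"])  # "user"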
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from enum import Enum, unique from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union from typing import TYPE_CHECKING, Optional, TypedDict, Union
from datasets import DatasetDict, concatenate_datasets, interleave_datasets from datasets import DatasetDict, concatenate_datasets, interleave_datasets
...@@ -29,7 +29,7 @@ if TYPE_CHECKING: ...@@ -29,7 +29,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] SLOTS = list[Union[str, set[str], dict[str, str]]]
@unique @unique
...@@ -43,15 +43,13 @@ class Role(str, Enum): ...@@ -43,15 +43,13 @@ class Role(str, Enum):
class DatasetModule(TypedDict): class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]] train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]] eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
def merge_dataset( def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""Merge multiple datasets to a unified dataset."""
Merges multiple datasets to a unified dataset.
"""
if len(all_datasets) == 1: if len(all_datasets) == 1:
return all_datasets[0] return all_datasets[0]
...@@ -78,14 +76,13 @@ def merge_dataset( ...@@ -78,14 +76,13 @@ def merge_dataset(
def split_dataset( def split_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]], dataset: Optional[Union["Dataset", "IterableDataset"]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]], eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments", data_args: "DataArguments",
seed: int, seed: int,
) -> "DatasetDict": ) -> "DatasetDict":
r""" r"""Split the dataset and returns a dataset dict containing train set and validation set.
Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset. Support both map dataset and iterable dataset.
""" """
if eval_dataset is not None and data_args.val_size > 1e-6: if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
...@@ -120,10 +117,8 @@ def split_dataset( ...@@ -120,10 +117,8 @@ def split_dataset(
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule": def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
r""" r"""Convert dataset or dataset dict to dataset module."""
Converts dataset or dataset dict to dataset module. dataset_module: DatasetModule = {}
"""
dataset_module: "DatasetModule" = {}
if isinstance(dataset, DatasetDict): # dataset dict if isinstance(dataset, DatasetDict): # dataset dict
if "train" in dataset: if "train" in dataset:
dataset_module["train_dataset"] = dataset["train"] dataset_module["train_dataset"] = dataset["train"]
......
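For readers unfamiliar with the `datasets` utilities imported in this file, a minimal sketch of the two merging modes (plain concatenation versus interleaving) follows. The toy datasets are placeholders and this is not the project's exact `merge_dataset` logic.

from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2"]})

merged_concat = concatenate_datasets([ds_a, ds_b])        # a1, a2, b1, b2
merged_mix = interleave_datasets([ds_a, ds_b], seed=42)    # examples alternate between sources
print(len(merged_concat), len(merged_mix))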
...@@ -16,7 +16,7 @@ import json ...@@ -16,7 +16,7 @@ import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional, Union from typing import Optional, Union
from typing_extensions import override from typing_extensions import override
...@@ -31,14 +31,11 @@ class Formatter(ABC): ...@@ -31,14 +31,11 @@ class Formatter(ABC):
@abstractmethod @abstractmethod
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
r""" r"""Forms a list of slots according to the inputs to encode."""
Forms a list of slots according to the inputs to encode.
"""
... ...
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
r""" r"""Extract a list of tuples from the response message if using tools.
Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments. Each tuple consists of function name and function arguments.
""" """
...@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter): ...@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
if thought: if thought:
content = content.replace(thought.group(0), "") content = content.replace(thought.group(0), "")
functions: List["FunctionCall"] = [] functions: list[FunctionCall] = []
try: try:
tool_calls = json.loads(content) tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call if not isinstance(tool_calls, list): # parallel function call
...@@ -141,5 +138,5 @@ class ToolFormatter(Formatter): ...@@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override @override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
return self.tool_utils.tool_extractor(content) return self.tool_utils.tool_extractor(content)
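The `extract` method above delegates to `self.tool_utils.tool_extractor`; a simplified, self-contained sketch of what such an extractor does is shown below. It is an assumption-level illustration, not the repository's implementation: parse the model output as JSON and collect (function name, argument string) tuples, falling back to the raw string when the content is not a tool call.

import json

def extract_tool_calls(content: str):
    try:
        tool_calls = json.loads(content)
    except json.JSONDecodeError:
        return content  # not a tool call, return the plain text unchanged
    if not isinstance(tool_calls, list):
        tool_calls = [tool_calls]  # single call -> wrap into a list (parallel calls come as a list)
    return [(call["name"], json.dumps(call.get("arguments", {}))) for call in tool_calls]

print(extract_tool_calls('{"name": "get_weather", "arguments": {"city": "Paris"}}'))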
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np import numpy as np
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
...@@ -54,9 +54,7 @@ def _load_single_dataset( ...@@ -54,9 +54,7 @@ def _load_single_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""Load a single dataset and aligns it to the standard format."""
Loads a single dataset and aligns it to the standard format.
"""
logger.info_rank0(f"Loading dataset {dataset_attr}...") logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
...@@ -133,10 +131,12 @@ def _load_single_dataset( ...@@ -133,10 +131,12 @@ def _load_single_dataset(
split=dataset_attr.split, split=dataset_attr.split,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token, token=model_args.hf_hub_token,
streaming=data_args.streaming,
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code, trust_remote_code=model_args.trust_remote_code,
streaming=data_args.streaming and dataset_attr.load_from != "file",
) )
if data_args.streaming and dataset_attr.load_from == "file":
dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
if dataset_attr.num_samples is not None and not data_args.streaming: if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples target_num = dataset_attr.num_samples
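The hunk above changes how streaming interacts with local files: instead of passing `streaming=True` to `load_dataset` for file-based sources, the dataset is loaded as a map-style dataset and then converted with `to_iterable_dataset`. A hedged, standalone sketch of that conversion; the JSON path and shard count are placeholders.

from datasets import load_dataset

dataset = load_dataset("json", data_files="data/alpaca_en_demo.json", split="train")
streamed = dataset.to_iterable_dataset(num_shards=4)  # shard count would mirror dataloader_num_workers
for example in streamed:
    print(example)  # iterate lazily, one example at a time
    break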
...@@ -158,16 +158,14 @@ def _load_single_dataset( ...@@ -158,16 +158,14 @@ def _load_single_dataset(
def _get_merged_dataset( def _get_merged_dataset(
dataset_names: Optional[Sequence[str]], dataset_names: Optional[list[str]],
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
merge: bool = True, merge: bool = True,
) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]: ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
r""" r"""Return the merged datasets in the standard format."""
Returns the merged datasets in the standard format.
"""
if dataset_names is None: if dataset_names is None:
return None return None
...@@ -192,9 +190,7 @@ def _get_dataset_processor( ...@@ -192,9 +190,7 @@ def _get_dataset_processor(
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
do_generate: bool = False, do_generate: bool = False,
) -> "DatasetProcessor": ) -> "DatasetProcessor":
r""" r"""Return the corresponding dataset processor."""
Returns the corresponding dataset processor.
"""
if stage == "pt": if stage == "pt":
dataset_processor_class = PretrainDatasetProcessor dataset_processor_class = PretrainDatasetProcessor
elif stage == "sft" and not do_generate: elif stage == "sft" and not do_generate:
...@@ -236,9 +232,7 @@ def _get_preprocessed_dataset( ...@@ -236,9 +232,7 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False, is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Optional[Union["Dataset", "IterableDataset"]]:
r""" r"""Preprocesses the dataset, including format checking and tokenization."""
Preprocesses the dataset, including format checking and tokenization.
"""
if dataset is None: if dataset is None:
return None return None
...@@ -284,9 +278,7 @@ def get_dataset( ...@@ -284,9 +278,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule": ) -> "DatasetModule":
r""" r"""Get the train dataset and optionally gets the evaluation dataset."""
Gets the train dataset and optionally gets the evaluation dataset.
"""
# Load tokenized dataset if path exists # Load tokenized dataset if path exists
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path): if has_tokenized_data(data_args.tokenized_path):
......

...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence from typing import Any, Literal, Optional
from transformers.utils import cached_file from huggingface_hub import hf_hub_download
from ..extras.constants import DATA_CONFIG from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope, use_openmind from ..extras.misc import use_modelscope, use_openmind
...@@ -25,9 +25,7 @@ from ..extras.misc import use_modelscope, use_openmind ...@@ -25,9 +25,7 @@ from ..extras.misc import use_modelscope, use_openmind
@dataclass @dataclass
class DatasetAttr: class DatasetAttr:
r""" r"""Dataset attributes."""
Dataset attributes.
"""
# basic configs # basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
...@@ -68,10 +66,10 @@ class DatasetAttr: ...@@ -68,10 +66,10 @@ class DatasetAttr:
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default)) setattr(self, key, obj.get(key, default))
def join(self, attr: Dict[str, Any]) -> None: def join(self, attr: dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca") self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False) self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr) self.set_attr("subset", attr)
...@@ -92,10 +90,8 @@ class DatasetAttr: ...@@ -92,10 +90,8 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"]) self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
r""" r"""Get the attributes of the datasets."""
Gets the attributes of the datasets.
"""
if dataset_names is None: if dataset_names is None:
dataset_names = [] dataset_names = []
...@@ -103,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - ...@@ -103,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None dataset_info = None
else: else:
if dataset_dir.startswith("REMOTE:"): if dataset_dir.startswith("REMOTE:"):
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
else: else:
config_path = os.path.join(dataset_dir, DATA_CONFIG) config_path = os.path.join(dataset_dir, DATA_CONFIG)
...@@ -116,7 +112,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - ...@@ -116,7 +112,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None dataset_info = None
dataset_list: List["DatasetAttr"] = [] dataset_list: list[DatasetAttr] = []
for name in dataset_names: for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope(): if use_modelscope():
......
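The switch from `cached_file` to `hf_hub_download` above only affects the `REMOTE:` case, where the dataset config is fetched from the Hub instead of read from a local directory. A hedged illustration; the repo id below is a placeholder, and `dataset_info.json` is what `DATA_CONFIG` is assumed to point at.

from huggingface_hub import hf_hub_download

dataset_dir = "REMOTE:username/my-datasets"  # hypothetical dataset repo
if dataset_dir.startswith("REMOTE:"):
    config_path = hf_hub_download(
        repo_id=dataset_dir[len("REMOTE:"):],  # strip the 7-character "REMOTE:" prefix
        filename="dataset_info.json",           # assumed value of DATA_CONFIG
        repo_type="dataset",
    )
    print(config_path)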
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .feedback import FeedbackDatasetProcessor from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor from .pretrain import PretrainDatasetProcessor
...@@ -9,9 +23,9 @@ from .unsupervised import UnsupervisedDatasetProcessor ...@@ -9,9 +23,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [ __all__ = [
"DatasetProcessor", "DatasetProcessor",
"FeedbackDatasetProcessor", "FeedbackDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"PairwiseDatasetProcessor", "PairwiseDatasetProcessor",
"PretrainDatasetProcessor", "PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor", "SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor", "UnsupervisedDatasetProcessor",
] ]
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
...@@ -30,15 +30,15 @@ logger = logging.get_logger(__name__) ...@@ -30,15 +30,15 @@ logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor): class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: list[dict[str, str]],
response: Sequence[Dict[str, str]], response: list[dict[str, str]],
kl_response: Sequence[Dict[str, str]], kl_response: list[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]: ) -> tuple[list[int], list[int], list[int], list[int], bool]:
if response[0]["content"]: # desired example if response[0]["content"]: # desired example
kto_tag = True kto_tag = True
messages = prompt + [response[0]] messages = prompt + [response[0]]
...@@ -82,9 +82,9 @@ class FeedbackDatasetProcessor(DatasetProcessor): ...@@ -82,9 +82,9 @@ class FeedbackDatasetProcessor(DatasetProcessor):
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs # Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
kl_response = examples["_response"][::-1] kl_response = [examples["_response"][-1]] + examples["_response"][:-1]
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
...@@ -121,7 +121,7 @@ class FeedbackDatasetProcessor(DatasetProcessor): ...@@ -121,7 +121,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
......
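A small worked example of the KL-response change in this file: reversing the batch can pair the middle example with its own response, while rotating the list by one position keeps every prompt paired with a different example's response (for batches larger than one). The strings below are placeholders for encoded responses.

responses = ["r0", "r1", "r2"]

reversed_kl = responses[::-1]                   # ['r2', 'r1', 'r0'] -> index 1 keeps its own response
rotated_kl = [responses[-1]] + responses[:-1]   # ['r2', 'r0', 'r1'] -> every index is mismatched
print(reversed_kl, rotated_kl)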
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
...@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__) ...@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor): class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: list[dict[str, str]],
response: Sequence[Dict[str, str]], response: list[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]: ) -> tuple[list[int], list[int], list[int], list[int]]:
chosen_messages = self.template.mm_plugin.process_messages( chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor prompt + [response[0]], images, videos, audios, self.processor
) )
...@@ -68,7 +68,7 @@ class PairwiseDatasetProcessor(DatasetProcessor): ...@@ -68,7 +68,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
...@@ -99,7 +99,7 @@ class PairwiseDatasetProcessor(DatasetProcessor): ...@@ -99,7 +99,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"])) valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
......
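A hedged sketch of the pairwise label construction shown above: prompt (source) tokens are masked with IGNORE_INDEX so only the chosen or rejected response tokens contribute to the loss. The token ids are arbitrary placeholders.

IGNORE_INDEX = -100

prompt_ids = [1, 2, 3]      # encoded prompt (source)
chosen_ids = [10, 11, 2]    # encoded chosen response + eos
rejected_ids = [20, 21, 2]  # encoded rejected response + eos

chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
print(chosen_labels, rejected_labels)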
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's transformers library. # This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
from dataclasses import dataclass from dataclasses import dataclass
from itertools import chain from itertools import chain
from typing import Any, Dict, List from typing import Any
from .processor_utils import DatasetProcessor from .processor_utils import DatasetProcessor
@dataclass @dataclass
class PretrainDatasetProcessor(DatasetProcessor): class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
...@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor): ...@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
return result return result
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import bisect import bisect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -27,9 +27,7 @@ if TYPE_CHECKING: ...@@ -27,9 +27,7 @@ if TYPE_CHECKING:
@dataclass @dataclass
class DatasetProcessor(ABC): class DatasetProcessor(ABC):
r""" r"""A class for data processors."""
A class for data processors.
"""
template: "Template" template: "Template"
tokenizer: "PreTrainedTokenizer" tokenizer: "PreTrainedTokenizer"
...@@ -37,32 +35,24 @@ class DatasetProcessor(ABC): ...@@ -37,32 +35,24 @@ class DatasetProcessor(ABC):
data_args: "DataArguments" data_args: "DataArguments"
@abstractmethod @abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
r""" r"""Build model inputs from the examples."""
Builds model inputs from the examples.
"""
... ...
@abstractmethod @abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
r""" r"""Print a data example to stdout."""
Print a data example to stdout.
"""
... ...
def search_for_fit(numbers: Sequence[int], capacity: int) -> int: def search_for_fit(numbers: list[int], capacity: int) -> int:
r""" r"""Find the index of largest number that fits into the knapsack with the given capacity."""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
index = bisect.bisect(numbers, capacity) index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1) return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
r""" r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers.sort() # sort numbers in ascending order for binary search numbers.sort() # sort numbers in ascending order for binary search
knapsacks = [] knapsacks = []
...@@ -83,10 +73,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: ...@@ -83,10 +73,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks return knapsacks
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
r""" r"""Compute the real sequence length after truncation by the cutoff_len."""
Computes the real sequence length after truncation by the cutoff_len.
"""
if target_len * 2 < cutoff_len: # truncate source if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target elif source_len * 2 < cutoff_len: # truncate target
......
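A worked toy example of the bin-packing helpers in this file: sequence lengths are sorted, and each "knapsack" is filled greedily with the largest remaining length that still fits, located via `bisect`. The loop below is a self-contained reconstruction of the idea under the assumption that every length fits the capacity, not necessarily the project's exact code.

import bisect

def search_for_fit(numbers, capacity):
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else index - 1

def greedy_knapsack(numbers, capacity):
    numbers.sort()  # ascending order, required for binary search
    knapsacks = []
    while numbers:
        current_knapsack, remaining = [], capacity
        while True:
            index = search_for_fit(numbers, remaining)
            if index == -1:  # nothing left that fits
                break
            remaining -= numbers[index]
            current_knapsack.append(numbers.pop(index))
        knapsacks.append(current_knapsack)
    return knapsacks

print(greedy_knapsack([3, 7, 2, 5, 4], capacity=8))  # [[7], [5, 3], [4, 2]]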
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
...@@ -32,14 +32,14 @@ logger = logging.get_logger(__name__) ...@@ -32,14 +32,14 @@ logger = logging.get_logger(__name__)
class SupervisedDatasetProcessor(DatasetProcessor): class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: list[dict[str, str]],
response: Sequence[Dict[str, str]], response: list[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor) messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids( input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor [], [], images, videos, audios, self.tokenizer, self.processor
...@@ -85,7 +85,7 @@ class SupervisedDatasetProcessor(DatasetProcessor): ...@@ -85,7 +85,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return input_ids, labels return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>` # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair. # for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
...@@ -114,7 +114,7 @@ class SupervisedDatasetProcessor(DatasetProcessor): ...@@ -114,7 +114,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
...@@ -124,7 +124,7 @@ class SupervisedDatasetProcessor(DatasetProcessor): ...@@ -124,7 +124,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
@dataclass @dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# TODO: use `position_ids` to achieve packing # TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>` # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>` # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
...@@ -165,7 +165,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): ...@@ -165,7 +165,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len) knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks: for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], [] packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos, packed_audios = [], [], [] packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
for i, length in enumerate(knapsack): for i, length in enumerate(knapsack):
index = length2indexes[length].pop() index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index] packed_input_ids += batch_input_ids[index]
...@@ -175,6 +175,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): ...@@ -175,6 +175,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_audios += batch_audios[index] packed_audios += batch_audios[index]
if self.data_args.neat_packing: if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
packed_position_ids += list(range(len(batch_input_ids[index])))
else: else:
packed_attention_masks += [1] * len(batch_input_ids[index]) packed_attention_masks += [1] * len(batch_input_ids[index])
...@@ -184,6 +185,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): ...@@ -184,6 +185,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_labels += [IGNORE_INDEX] * pad_length packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing: if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length packed_attention_masks += [0] * pad_length
packed_position_ids += [0] * pad_length
else: else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn packed_attention_masks += [1] * pad_length # more efficient flash_attn
...@@ -196,5 +198,6 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): ...@@ -196,5 +198,6 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs["images"].append(packed_images or None) model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None) model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None) model_inputs["audios"].append(packed_audios or None)
model_inputs["position_ids"].append(packed_position_ids or None)
return model_inputs return model_inputs
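A hedged illustration of the neat-packing extension above: each packed sequence gets its own attention-mask block id (1, 2, ...) and position ids restarting from 0, while padding gets mask 0 and position 0. The sequence lengths and cutoff are arbitrary.

IGNORE_INDEX = -100
cutoff_len = 8
sequence_lengths = [3, 2]  # two sequences packed into one row

packed_attention_masks, packed_position_ids = [], []
for i, seq_len in enumerate(sequence_lengths):
    packed_attention_masks += [i + 1] * seq_len   # block ids start from 1
    packed_position_ids += list(range(seq_len))   # positions restart per sequence

pad_length = cutoff_len - len(packed_attention_masks)
packed_attention_masks += [0] * pad_length
packed_position_ids += [0] * pad_length

print(packed_attention_masks)  # [1, 1, 1, 2, 2, 0, 0, 0]
print(packed_position_ids)     # [0, 1, 2, 0, 1, 0, 0, 0]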
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ..data_utils import Role from ..data_utils import Role
...@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__) ...@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor): class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: list[dict[str, str]],
response: Sequence[Dict[str, str]], response: list[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
if len(response) == 1: if len(response) == 1:
messages = prompt + response messages = prompt + response
else: else:
...@@ -56,7 +56,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor): ...@@ -56,7 +56,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
labels = labels[:target_len] labels = labels[:target_len]
return input_ids, labels return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>` # build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
...@@ -84,7 +84,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor): ...@@ -84,7 +84,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"])) print("label_ids:\n{}".format(example["labels"]))
......