Commit 7ea81099 authored by chenych

update llama4

parent 84987715
# Copyright 2024 THUDM and the LlamaFactory team.
# Copyright 2025 THUDM and the LlamaFactory team.
#
# This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
......@@ -17,13 +17,15 @@
import asyncio
import os
from collections.abc import AsyncGenerator, Generator
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from typing import TYPE_CHECKING, Any, Optional
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .sglang_engine import SGLangEngine
from .vllm_engine import VllmEngine
......@@ -38,20 +40,21 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel:
r"""
General class for chat models. Backed by huggingface or vllm engines.
r"""General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores().
"""
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
if model_args.infer_backend == EngineName.HF:
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.VLLM:
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.SGLANG:
self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
......@@ -61,17 +64,15 @@ class ChatModel:
def chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Get a list of responses of the chat model."""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
)
......@@ -79,32 +80,28 @@ class ChatModel:
async def achat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Asynchronously get a list of responses of the chat model."""
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
r"""Get the response token-by-token of the chat model."""
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
while True:
try:
......@@ -115,17 +112,15 @@ class ChatModel:
async def astream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
r"""Asynchronously get the response token-by-token of the chat model."""
async for new_token in self.engine.stream_chat(
messages, system, tools, images, videos, audios, **input_kwargs
):
......@@ -133,23 +128,19 @@ class ChatModel:
def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Get a list of scores of the reward model."""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
async def aget_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Asynchronously gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Asynchronously get a list of scores of the reward model."""
return await self.engine.get_scores(batch_input, **input_kwargs)
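For context, a minimal usage sketch of the `ChatModel` API shown above. The import path follows the package layout, but the model path, template and backend values are placeholder assumptions.

```python
# Hypothetical usage sketch; model path, template and backend are assumptions.
from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "path/to/model",  # placeholder
    "template": "llama3",                   # placeholder
    "infer_backend": "huggingface",
})

messages = [{"role": "user", "content": "What is the capital of France?"}]

# Synchronous, blocking call.
responses = chat_model.chat(messages)
print(responses[0].response_text)

# Token-by-token streaming.
for token in chat_model.stream_chat(messages):
    print(token, end="", flush=True)
```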
......
......@@ -13,10 +13,10 @@
# limitations under the License.
import asyncio
import concurrent.futures
import os
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer
......@@ -76,15 +76,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
generating_args: dict[str, Any],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
input_kwargs: Optional[dict[str, Any]] = {},
) -> tuple[dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
......@@ -130,7 +130,7 @@ class HuggingfaceEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if stop is not None:
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
......@@ -217,15 +217,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
generating_args: dict[str, Any],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
input_kwargs: Optional[dict[str, Any]] = {},
) -> list["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model,
tokenizer,
......@@ -272,14 +272,14 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
generating_args: dict[str, Any],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
input_kwargs: Optional[dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model,
......@@ -317,12 +317,12 @@ class HuggingfaceEngine(BaseEngine):
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: List[str],
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List[float]:
batch_input: list[str],
input_kwargs: Optional[dict[str, Any]] = {},
) -> list[float]:
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs: Dict[str, "torch.Tensor"] = tokenizer(
inputs: dict[str, torch.Tensor] = tokenizer(
batch_input,
padding=True,
truncation=True,
......@@ -330,25 +330,24 @@ class HuggingfaceEngine(BaseEngine):
return_tensors="pt",
add_special_tokens=False,
).to(device)
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
) -> list["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
......@@ -364,24 +363,22 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
return await asyncio.to_thread(self._chat, *input_args)
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
......@@ -397,25 +394,22 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
try:
yield await loop.run_in_executor(pool, stream)
yield await asyncio.to_thread(stream)
except StopAsyncIteration:
break
@override
async def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
) -> list[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._get_scores, *input_args)
return await asyncio.to_thread(self._get_scores, *input_args)
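The hunks above replace the per-call `ThreadPoolExecutor` plus `run_in_executor` pattern with `asyncio.to_thread`. A standalone sketch of the equivalence, independent of the engine classes (the blocking function is a stand-in):

```python
import asyncio
import concurrent.futures
import time


def blocking_generate(prompt: str) -> str:
    """Stand-in for a blocking call such as model.generate()."""
    time.sleep(0.1)
    return f"echo: {prompt}"


async def old_style(prompt: str) -> str:
    # Before: spin up a thread pool per call and dispatch explicitly.
    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        return await loop.run_in_executor(pool, blocking_generate, prompt)


async def new_style(prompt: str) -> str:
    # After: asyncio.to_thread (Python 3.9+) runs the callable in the default
    # executor and awaits the result, with far less boilerplate.
    return await asyncio.to_thread(blocking_generate, prompt)


if __name__ == "__main__":
    print(asyncio.run(old_style("hello")))
    print(asyncio.run(new_style("hello")))
```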
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import atexit
import json
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
import requests
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_sglang_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from .base_engine import BaseEngine, Response
if is_sglang_available():
from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore
if TYPE_CHECKING:
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class SGLangEngine(BaseEngine):
"""Inference engine for SGLang models.
This class wraps the SGLang engine to provide a consistent interface for text generation
that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
better interaction and performance. The engine launches a server process and communicates
with it via HTTP requests.
For more details on the SGLang HTTP server approach, see:
https://docs.sglang.ai/backend/send_request.html
"""
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.SGLANG
self.model_args = model_args
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
self.generating_args = generating_args.to_dict()
launch_cmd = [
"python3 -m sglang.launch_server",
f"--model-path {model_args.model_name_or_path}",
f"--dtype {model_args.infer_dtype}",
f"--context-length {model_args.sglang_maxlen}",
f"--mem-fraction-static {model_args.sglang_mem_fraction}",
f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
f"--download-dir {model_args.cache_dir}",
"--log-level error",
]
launch_cmd = " ".join(launch_cmd)
logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
try:
torch_gc()
self.server_process, port = launch_server_cmd(launch_cmd)
self.base_url = f"http://localhost:{port}"
atexit.register(self._cleanup_server)
logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
wait_for_server(self.base_url, timeout=300)
logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")
try:
response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
if response.status_code == 200:
model_info = response.json()
logger.info(f"SGLang server model info: {model_info}")
except Exception as e:
logger.debug(f"Note: could not get model info: {str(e)}")
except Exception as e:
logger.error(f"Failed to start SGLang server: {str(e)}")
self._cleanup_server() # make sure to clean up any started process
raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")
def _cleanup_server(self):
r"""Clean up the server process when the engine is destroyed."""
if hasattr(self, "server_process") and self.server_process:
try:
logger.info("Terminating SGLang server process")
terminate_process(self.server_process)
logger.info("SGLang server process terminated")
except Exception as e:
logger.warning(f"Error terminating SGLang server: {str(e)}")
async def _generate(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncIterator[dict[str, Any]]:
if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if num_return_sequences != 1:
raise NotImplementedError("SGLang only supports n=1.")
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
if max_new_tokens:
max_tokens = max_new_tokens
sampling_params = {
"temperature": temperature if temperature is not None else self.generating_args["temperature"],
"top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
"top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0
"stop": stop,
"stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
"max_new_tokens": max_tokens,
"repetition_penalty": (
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
)
or 1.0, # repetition_penalty must be > 0
"skip_special_tokens": skip_special_tokens
if skip_special_tokens is not None
else self.generating_args["skip_special_tokens"],
}
def stream_request():
json_data = {
"input_ids": prompt_ids,
"sampling_params": sampling_params,
"stream": True,
}
response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
if response.status_code != 200:
raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
for chunk in response.iter_lines(decode_unicode=False):
chunk = str(chunk.decode("utf-8"))
if chunk == "data: [DONE]":
break
if chunk and chunk.startswith("data:"):
yield json.loads(chunk[5:].strip("\n"))
return await asyncio.to_thread(stream_request)
@override
async def chat(
self,
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> list["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
for request_output in generator:
final_output = request_output
results = [
Response(
response_text=final_output["text"],
response_length=final_output["meta_info"]["completion_tokens"],
prompt_length=final_output["meta_info"]["prompt_tokens"],
finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
)
]
return results
@override
async def stream_chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
for result in generator:
delta_text = result["text"][len(generated_text) :]
generated_text = result["text"]
yield delta_text
@override
async def get_scores(
self,
batch_input: list[str],
**input_kwargs,
) -> list[float]:
raise NotImplementedError("SGLang engine does not support `get_scores`.")
def __del__(self):
r"""Ensure server is cleaned up when object is deleted."""
self._cleanup_server()
try:
atexit.unregister(self._cleanup_server)
except Exception:
pass
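For reference, a self-contained sketch of the server-sent-events parsing that `stream_request()` above performs against the SGLang `/generate` endpoint. The payload fields mirror the code above, but a running server at the given address is assumed.

```python
import json

import requests

# Assumes an SGLang server is already listening here (see launch_cmd above).
BASE_URL = "http://localhost:30000"


def stream_generate(prompt_ids: list[int], sampling_params: dict) -> str:
    """Stream tokens from the SGLang /generate endpoint and return the final text."""
    payload = {"input_ids": prompt_ids, "sampling_params": sampling_params, "stream": True}
    response = requests.post(f"{BASE_URL}/generate", json=payload, stream=True)
    response.raise_for_status()

    text = ""
    for raw_line in response.iter_lines(decode_unicode=False):
        line = raw_line.decode("utf-8")
        if line == "data: [DONE]":  # server signals end of stream
            break
        if line.startswith("data:"):
            chunk = json.loads(line[5:].strip())
            text = chunk["text"]  # cumulative text, as consumed in stream_chat() above
    return text
```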
......@@ -13,7 +13,8 @@
# limitations under the License.
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import override
......@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
self.model_args = model_args
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
......@@ -82,7 +83,7 @@ class VllmEngine(BaseEngine):
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
......@@ -101,33 +102,26 @@ class VllmEngine(BaseEngine):
async def _generate(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}"
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if videos is not None:
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if audios is not None:
mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
......@@ -143,7 +137,7 @@ class VllmEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if length_penalty is not None:
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
......@@ -185,8 +179,24 @@ class VllmEngine(BaseEngine):
images,
image_max_pixels=self.model_args.image_max_pixels,
image_min_pixels=self.model_args.image_min_pixels,
)
)["images"]
}
elif videos is not None:
multi_modal_data = {
"video": self.template.mm_plugin._regularize_videos(
videos,
image_max_pixels=self.model_args.video_max_pixels,
image_min_pixels=self.model_args.video_min_pixels,
video_fps=self.model_args.video_fps,
video_maxlen=self.model_args.video_maxlen,
)["videos"]
}
elif audios is not None:
audio_data = self.template.mm_plugin._regularize_audios(
audios,
sampling_rate=self.model_args.audio_sampling_rate,
)
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
else:
multi_modal_data = None
......@@ -201,14 +211,14 @@ class VllmEngine(BaseEngine):
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
) -> list["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
async for request_output in generator:
......@@ -230,12 +240,12 @@ class VllmEngine(BaseEngine):
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
......@@ -248,7 +258,7 @@ class VllmEngine(BaseEngine):
@override
async def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")
) -> list[float]:
raise NotImplementedError("vLLM engine does not support `get_scores`.")
......@@ -13,7 +13,6 @@
# limitations under the License.
import os
import random
import subprocess
import sys
from enum import Enum, unique
......@@ -24,7 +23,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import get_device_count, is_env_enabled, use_ray
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
......@@ -92,7 +91,7 @@ def main():
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
......
......@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"TEMPLATES",
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"SFTDataCollatorWith4DAttentionMask",
"Template",
"get_dataset",
"get_template_and_fix_tokenizer",
"split_dataset",
]
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
......@@ -16,7 +16,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
from typing import TYPE_CHECKING, Any, Literal, Optional
import numpy as np
import torch
......@@ -24,6 +24,7 @@ import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.misc import get_current_device
from ..extras.packages import is_pillow_available
......@@ -38,9 +39,10 @@ if TYPE_CHECKING:
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
r"""Expand 2d attention mask to 4d attention mask.
Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
handle packed sequences, and transform the mask to lower triangular form to prevent future peeking.
e.g.
```python
......@@ -62,24 +64,37 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
```
where `o` equals `0.0` and `x` equals `min_dtype`.
"""
bsz, seq_len = attention_mask_with_indices.size()
_, seq_len = attention_mask_with_indices.size()
# Move to compute device if the source is CPU.
source_device = attention_mask_with_indices.device
compute_device = get_current_device() if source_device.type == "cpu" else source_device
if compute_device != source_device:
attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
min_dtype = torch.finfo(dtype).min
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
padding_mask = torch.where(expanded_mask != 0, 1, 0)
# Create a block-diagonal mask.
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
# Use the lower triangular mask to zero out the upper triangular part
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
# Create a non-padding mask.
non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
# Create indices for comparison.
indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, seq_len]
indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3) # [bsz, 1, seq_len, 1]
# Create a lower triangular mask.
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
# Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
# Move back to original device if needed.
if compute_device != source_device:
attention_mask_4d = attention_mask_4d.to(source_device)
return attention_mask_4d
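A small worked example of the reconstruction above, using a packed sequence with indices `[1, 1, 2, 2, 0]` (two documents plus one padding position). It applies the same index-equality, non-padding and lower-triangular conditions as the function.

```python
import torch

# Packed sequence: tokens of document 1, then document 2, then one pad (index 0).
attention_mask_with_indices = torch.tensor([[1, 1, 2, 2, 0]])
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)    # [bsz, 1, 1, seq_len]
indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
tril_mask = torch.tril(torch.ones((5, 5), dtype=torch.bool))

mask_4d = (indices == indices_t) & non_padding & tril_mask
mask_4d = torch.where(mask_4d, torch.tensor(0.0, dtype=dtype), min_dtype)

# Positions within the same document attend causally; document 2 cannot see
# document 1, and nothing attends to the padding column.
print(mask_4d[0, 0])
```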
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
r"""Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
"""
......@@ -91,7 +106,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
for feature in features:
......@@ -166,7 +181,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features: dict[str, torch.Tensor] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
rope_index_kwargs = {
......@@ -175,9 +190,27 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": features["attention_mask"],
}
if "second_per_grid_ts" in mm_inputs:
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
if "video_second_per_grid" in mm_inputs: # for qwen2omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(
feature_attention_mask, dim=1
) # FIXME need to get video image lengths
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
# avoid conflict
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
features["position_ids"], features["rope_deltas"] = (
new_position_ids.clone(),
rope_deltas - delta0,
) # avoid inplace operation FIXME
else: # for qwen2vl
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
......@@ -198,15 +231,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
r"""Data collator for 4d attention mask."""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
......@@ -220,13 +251,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
r"""Data collator for pairwise data."""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
r"""
Pads batched data to the longest sequence in the batch.
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
r"""Pad batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
......@@ -249,11 +277,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
r"""Data collator for KTO data."""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
target_features = []
kl_features = []
kto_tags = []
......
......@@ -15,7 +15,7 @@
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from ..extras import logging
from .data_utils import Role
......@@ -26,8 +26,12 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .mm_plugin import AudioInput, ImageInput, VideoInput
from .parser import DatasetAttr
MediaType = Union[ImageInput, VideoInput, AudioInput]
logger = logging.get_logger(__name__)
......@@ -36,12 +40,12 @@ class DatasetConverter:
dataset_attr: "DatasetAttr"
data_args: "DataArguments"
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
r"""
Optionally concatenates media path to media dir when loading from local disk.
"""
if not isinstance(medias, list):
medias = [medias] if medias is not None else []
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
r"""Optionally concatenate media path to media dir when loading from local disk."""
if medias is None:
return None
elif not isinstance(medias, list):
medias = [medias]
elif len(medias) == 0:
return None
else:
......@@ -57,16 +61,14 @@ class DatasetConverter:
return medias
@abstractmethod
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
r"""
Converts a single example in the dataset to the standard format.
"""
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
r"""Convert a single example in the dataset to the standard format."""
...
@dataclass
class AlpacaDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
prompt = []
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
for old_prompt, old_response in example[self.dataset_attr.history]:
......@@ -116,7 +118,7 @@ class AlpacaDatasetConverter(DatasetConverter):
@dataclass
class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
......@@ -216,10 +218,8 @@ DATASET_CONVERTERS = {
}
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
r"""
Register a new dataset converter.
"""
def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
r"""Register a new dataset converter."""
if name in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} already exists.")
......@@ -227,9 +227,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
r"""
Gets a dataset converter.
"""
r"""Get a dataset converter."""
if name not in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} not found.")
......@@ -242,17 +240,17 @@ def align_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
r"""Align the dataset to a specific format.
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
_audios: [],
_tools: "..."
_images: []
_videos: []
_audios: []
"""
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
from typing import TYPE_CHECKING, Optional, TypedDict, Union
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
......@@ -29,7 +29,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
SLOTS = list[Union[str, set[str], dict[str, str]]]
@unique
......@@ -43,15 +43,13 @@ class Role(str, Enum):
class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets to a unified dataset.
"""
r"""Merge multiple datasets to a unified dataset."""
if len(all_datasets) == 1:
return all_datasets[0]
......@@ -78,14 +76,13 @@ def merge_dataset(
def split_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments",
seed: int,
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set and validation set.
r"""Split the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
Support both map dataset and iterable dataset.
"""
if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
......@@ -120,10 +117,8 @@ def split_dataset(
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
r"""
Converts dataset or dataset dict to dataset module.
"""
dataset_module: "DatasetModule" = {}
r"""Convert dataset or dataset dict to dataset module."""
dataset_module: DatasetModule = {}
if isinstance(dataset, DatasetDict): # dataset dict
if "train" in dataset:
dataset_module["train_dataset"] = dataset["train"]
......
......@@ -16,7 +16,7 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Union
from typing import Optional, Union
from typing_extensions import override
......@@ -31,14 +31,11 @@ class Formatter(ABC):
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
r"""
Forms a list of slots according to the inputs to encode.
"""
r"""Forms a list of slots according to the inputs to encode."""
...
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extract a list of tuples from the response message if using tools.
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments.
"""
......@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
if thought:
content = content.replace(thought.group(0), "")
functions: List["FunctionCall"] = []
functions: list[FunctionCall] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
......@@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
return self.tool_utils.tool_extractor(content)
......@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import load_dataset, load_from_disk
......@@ -54,9 +54,7 @@ def _load_single_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Loads a single dataset and aligns it to the standard format.
"""
r"""Load a single dataset and aligns it to the standard format."""
logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
......@@ -133,10 +131,12 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=data_args.streaming,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
streaming=data_args.streaming and dataset_attr.load_from != "file",
)
if data_args.streaming and dataset_attr.load_from == "file":
dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
......@@ -158,16 +158,14 @@ def _load_single_dataset(
def _get_merged_dataset(
dataset_names: Optional[Sequence[str]],
dataset_names: Optional[list[str]],
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
merge: bool = True,
) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
r"""
Returns the merged datasets in the standard format.
"""
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
r"""Return the merged datasets in the standard format."""
if dataset_names is None:
return None
......@@ -192,9 +190,7 @@ def _get_dataset_processor(
processor: Optional["ProcessorMixin"],
do_generate: bool = False,
) -> "DatasetProcessor":
r"""
Returns the corresponding dataset processor.
"""
r"""Return the corresponding dataset processor."""
if stage == "pt":
dataset_processor_class = PretrainDatasetProcessor
elif stage == "sft" and not do_generate:
......@@ -236,9 +232,7 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
r"""Preprocesses the dataset, including format checking and tokenization."""
if dataset is None:
return None
......@@ -284,9 +278,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
r"""
Gets the train dataset and optionally gets the evaluation dataset.
"""
r"""Get the train dataset and optionally gets the evaluation dataset."""
# Load tokenized dataset if path exists
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
......
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
import numpy as np
import torch
......@@ -51,28 +68,65 @@ if TYPE_CHECKING:
path: Optional[str]
bytes: Optional[bytes]
ImageInput = Union[str, bytes, EncodedImage, ImageObject]
VideoInput = str
AudioInput = Union[str, NDArray]
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
VideoInput = Union[str, BinaryIO]
AudioInput = Union[str, BinaryIO, NDArray]
class MMProcessor(ProcessorMixin):
patch_size: int
image_seq_length: int
num_additional_image_tokens: int
vision_feature_select_strategy: Literal["default", "full"]
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
pass
def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
r"""Get paligemma token type ids for computing loss.
It is slightly different from the original token type ids, where the prompt part is 0.
Returns:
batch_token_type_ids: shape (batch_size, sequence_length)
batch_token_type_ids: shape (batch_size, seq_length)
"""
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
image_seqlen = imglen * getattr(processor, "image_seqlen")
image_seqlen = imglen * processor.image_seq_length
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
return batch_token_type_ids
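A tiny worked example of `_get_paligemma_token_type_ids` above: with one image, an `image_seq_length` of 4 and a total sequence length of 7, the image tokens get type 0 and the text tokens get type 1. The processor is replaced by a simple stand-in here.

```python
from types import SimpleNamespace

# Stand-in for the HF processor; only image_seq_length is needed here.
processor = SimpleNamespace(image_seq_length=4)

imglens, seqlens = [1], [7]
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
    image_seqlen = imglen * processor.image_seq_length
    batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))

print(batch_token_type_ids)  # [[0, 0, 0, 0, 1, 1, 1]]
```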
def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
r"""Get gemma3 token type ids for computing loss.
Returns:
batch_token_type_ids: shape (batch_size, seq_length)
"""
image_token_id: int = getattr(processor, "image_token_id")
batch_token_type_ids = []
for token_ids in batch_ids:
token_ids = np.array(token_ids)
token_type_ids = np.zeros_like(token_ids)
token_type_ids[token_ids == image_token_id] = 1
batch_token_type_ids.append(token_type_ids.tolist())
return batch_token_type_ids
def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
r"""Make nested list of images."""
batch_images = []
for imglen in imglens:
batch_images.append(images[:imglen])
images = images[imglen:]
return batch_images
@dataclass
class MMPluginMixin:
image_token: Optional[str]
......@@ -82,16 +136,17 @@ class MMPluginMixin:
def _validate_input(
self,
processor: Optional["ProcessorMixin"],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["MMProcessor"],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> None:
r"""
Validates if this model accepts the input modalities.
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
r"""Validate if this model accepts the input modalities."""
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseImageProcessor = getattr(
processor, "video_processor", getattr(processor, "image_processor", None)
)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
if len(images) != 0 and self.image_token is None:
raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used."
......@@ -113,15 +168,16 @@ class MMPluginMixin:
if self.image_token is not None and image_processor is None:
raise ValueError("Image processor was not found, please check and update your processor config.")
if self.video_token is not None and video_processor is None:
raise ValueError("Video processor was not found, please check and update your processor config.")
if self.audio_token is not None and feature_extractor is None:
raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
def _preprocess_image(
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
) -> "ImageObject":
r"""
Pre-processes a single image.
"""
r"""Pre-process a single image."""
if (image.width * image.height) > image_max_pixels:
resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
......@@ -139,10 +195,8 @@ class MMPluginMixin:
def _get_video_sample_indices(
self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
) -> List[int]:
r"""
Computes video sample indices according to fps.
"""
) -> list[int]:
r"""Compute video sample indices according to fps."""
total_frames = video_stream.frames
if total_frames == 0: # infinite video
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
......@@ -151,13 +205,11 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including reading and pre-processing.
"""
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
r"""Regularize images to avoid error. Including reading and pre-processing."""
results = []
for image in images:
if isinstance(image, str):
if isinstance(image, (str, BinaryIO)):
image = Image.open(image)
elif isinstance(image, bytes):
image = Image.open(BytesIO(image))
......@@ -172,53 +224,52 @@ class MMPluginMixin:
results.append(self._preprocess_image(image, **kwargs))
return results
return {"images": results}
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = []
frames: list[ImageObject] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
frames = self._regularize_images(frames, **kwargs)
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
return results
return {"videos": results}
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
r"""
Regularizes audios to avoid error. Including reading and resampling.
"""
results = []
def _regularize_audios(
self, audios: list["AudioInput"], sampling_rate: float, **kwargs
) -> dict[str, Union[list["NDArray"], list[float]]]:
r"""Regularizes audios to avoid error. Including reading and resampling."""
results, sampling_rates = [], []
for audio in audios:
if isinstance(audio, str):
audio = librosa.load(audio, sr=sampling_rate)[0]
if isinstance(audio, (str, BinaryIO)):
audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
if not isinstance(audio, np.ndarray):
raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
results.append(audio)
sampling_rates.append(sampling_rate)
return results
return {"audios": results, "sampling_rates": sampling_rates}
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
imglens: Optional[list[int]] = None,
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs.
Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W)
......@@ -226,44 +277,67 @@ class MMPluginMixin:
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
where num_patches == torch.prod(image_grid_thw)
Returns: (mllama)
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)["images"]
if imglens is not None: # if imglens are provided, make batched images
images = _make_batched_images(images, imglens)
image_processor_kwargs = {}
if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor
image_processor_kwargs.update(
{
"do_pan_and_scan": True,
"pan_and_scan_min_crop_size": 256,
"pan_and_scan_max_num_crops": 4,
"pan_and_scan_min_ratio_to_activate": 1.2,
}
)
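# Pan-and-scan (gemma3): besides the resized full image, the processor may emit up to
# `pan_and_scan_max_num_crops` extra crops per image (activation presumably depends on the
# aspect-ratio threshold above); the per-image crop counts come back as `num_crops`.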
mm_inputs.update(image_processor(images, return_tensors="pt"))
mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
if len(videos) != 0:
video_processor: BaseImageProcessor = getattr(
processor, "video_processor", getattr(processor, "image_processor", None)
)
videos = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
)["videos"]
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
else: # for llava_next_video
mm_inputs.update(video_processor(videos, return_tensors="pt"))
if len(audios) != 0:
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
audios = self._regularize_audios(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
)
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)["audios"]
mm_inputs.update(
feature_extractor(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
return_attention_mask=True,
padding="max_length",
return_tensors="pt",
......@@ -278,83 +352,95 @@ class MMPluginMixin:
class BasePlugin(MMPluginMixin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
Pre-processes input messages before tokenization for VLMs.
"""
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
r"""Pre-process input messages before tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
input_ids: list[int],
labels: Optional[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
r"""
Pre-processes token ids after tokenization for VLMs.
"""
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
r"""Pre-process token ids after tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return input_ids, labels
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
Builds batched multimodal inputs for VLMs.
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
r"""Build batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
audios: a list of audio inputs, shape (num_audios,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
audlens: number of audios in each sample, shape (batch_size,)
batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self._validate_input(processor, images, videos, audios)
return {}
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class LlavaPlugin(BasePlugin):
class Gemma3Plugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
boi_token: str = getattr(processor, "boi_token")
full_image_sequence: str = getattr(processor, "full_image_sequence")
image_str = full_image_sequence if self.expand_mm_tokens else boi_token
do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
if do_pan_and_scan:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
if do_pan_and_scan:
image_placeholder_str = (
"Here is the original image {{image}} and here are some crops to help you see better "
+ " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
)
else:
image_placeholder_str = "{{image}}"
content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{image}}", image_str)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
......@@ -364,17 +450,127 @@ class LlavaPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("num_crops", None)
mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
return mm_inputs
@dataclass
class Llama4Plugin(BasePlugin):
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
num_patches_per_chunk = int(
(image_height // processor.patch_size)
* (image_width // processor.patch_size)
// processor.downsample_ratio
)
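# e.g. (hypothetical numbers) a 336x336 chunk with patch_size=14 and downsample_ratio=4
# yields (336 // 14) * (336 // 14) // 4 = 144 patch tokens per chunk.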
aspect_ratios = mm_inputs.pop("aspect_ratios")
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
placeholder_count = content.count(IMAGE_PLACEHOLDER)
if self.expand_mm_tokens:
prompt_splits = content.split(IMAGE_PLACEHOLDER)
new_content = []
for local_image_index, split_part in enumerate(prompt_splits):
new_content.append(split_part)
if local_image_index < placeholder_count:
tokens_for_this_image = processor._prompt_split_image(
aspect_ratios[num_image_tokens], num_patches_per_chunk
)
num_image_tokens += 1
new_content.append(tokens_for_this_image)
content = "".join(new_content)
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("aspect_ratios", None)
return mm_inputs
@dataclass
class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
image_seqlen = (height // processor.patch_size) * (
width // processor.patch_size
) + processor.num_additional_image_tokens
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen = 1
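# The -1 under the "default" strategy presumably drops the CLS feature; e.g. (hypothetical)
# a 336x336 image with patch_size=14 and num_additional_image_tokens=1 gives 24 * 24 + 1 - 1 = 576.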
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@dataclass
......@@ -382,15 +578,16 @@ class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
......@@ -402,7 +599,7 @@ class LlavaNextPlugin(BasePlugin):
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy", "default") == "default":
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen = 1
......@@ -417,47 +614,34 @@ class LlavaNextPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy", "default") == "default":
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen = 1
......@@ -467,11 +651,11 @@ class LlavaNextVideoPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", self.image_token)
if "pixel_values_videos" in mm_inputs:
if self.expand_mm_tokens:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
if "pixel_values_videos" in mm_inputs:
one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
else:
......@@ -480,8 +664,8 @@ class LlavaNextVideoPlugin(BasePlugin):
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
num_video_tokens += 1
message["content"] = content.replace("{{video}}", self.video_token)
......@@ -493,37 +677,22 @@ class LlavaNextVideoPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class MiniCPMVPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {}
audio_inputs = {}
if len(images) != 0 and len(videos) != 0:
......@@ -614,21 +783,20 @@ class MiniCPMVPlugin(BasePlugin):
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
)["images"]
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
new_images = []
......@@ -651,15 +819,15 @@ class MiniCPMVPlugin(BasePlugin):
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
)["videos"]
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
mm_inputs.update(video_inputs)
if len(audios) != 0:
audios = self._regularize_audios(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
)
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)["audios"]
if "valid_audio_nums_ls" in kwargs:
valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
audios_ls = []
......@@ -673,7 +841,7 @@ class MiniCPMVPlugin(BasePlugin):
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
audios_ls,
chunk_input=True,
sampling_rate=16000,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
......@@ -685,15 +853,15 @@ class MiniCPMVPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
# image bound
image_bounds_list = []
......@@ -756,12 +924,12 @@ class MllamaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
......@@ -775,61 +943,24 @@ class MllamaPlugin(BasePlugin):
return messages
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
imglens: List[int],
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns:
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) > 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
batch_images = []
for image_length in imglens:
batch_images.append(images[:image_length])
images = images[image_length:]
mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
return mm_inputs
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
if mm_inputs:
num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
image_token_id: int = getattr(processor, "image_token_id")
max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
]
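# get_cross_attention_token_mask returns, per sample, the token spans that should attend to
# each image (located via image_token_id); presumably it is later combined with num_tiles and
# max_image_tiles to build the dense cross-attention mask.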
......@@ -850,22 +981,22 @@ class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
content = content.replace(IMAGE_PLACEHOLDER, "", 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", "")
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
......@@ -875,36 +1006,36 @@ class PaliGemmaPlugin(BasePlugin):
@override
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
input_ids: list[int],
labels: Optional[list[int]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
self._validate_input(processor, images, videos, audios)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids
input_ids = [image_token_id] * num_images * image_seqlen + input_ids
if labels is not None:
labels = [IGNORE_INDEX] * image_seqlen + labels
labels = [IGNORE_INDEX] * num_images * image_seqlen + labels
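# e.g. (hypothetical) 2 images with image_seq_length=256 prepend 512 image tokens to input_ids
# and 512 IGNORE_INDEX entries to labels, keeping the image prefix out of the loss.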
return input_ids, labels
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
......@@ -917,37 +1048,39 @@ class PixtralPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token")
image_end_token = getattr(processor, "image_end_token")
num_image_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
# BC for transformers < 4.49.0
if isinstance(mm_inputs["image_sizes"], list):
image_sizes = iter(mm_inputs["image_sizes"][0])
else:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
image_break_token: str = getattr(processor, "image_break_token")
image_end_token: str = getattr(processor, "image_end_token")
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens:
height, width = next(image_sizes)
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
num_height_tokens = height // processor.patch_size
num_width_tokens = width // processor.patch_size
replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
else:
replace_str = image_token
replace_str = self.image_token
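# In the expanded branch above, every row of num_width_tokens image tokens ends with
# image_break_token and the final token of the image becomes image_end_token, mirroring
# Pixtral's two-dimensional patch grid.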
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1
......@@ -962,17 +1095,21 @@ class PixtralPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
# see https://github.com/huggingface/transformers/pull/35122
# since transformers 4.49.0, `image_sizes` is a mandatory input parameter for the Pixtral vision encoder's forward pass.
# it can be passed into `LlavaForConditionalGeneration` as a parameter.
if not is_transformers_version_greater_than("4.49.0"):
mm_inputs.pop("image_sizes", None)
return mm_inputs
......@@ -982,21 +1119,22 @@ class Qwen2AudioPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token")
num_audio_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs([], [], audios, processor)
if "feature_attention_mask" in mm_inputs:
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
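# feature_attention_mask row sums count the valid audio feature frames per clip; the plugin
# presumably derives the number of audio placeholder tokens from these lengths below.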
num_audio_tokens = 0
for message in messages:
content = message["content"]
while AUDIO_PLACEHOLDER in content:
......@@ -1022,15 +1160,15 @@ class Qwen2AudioPlugin(BasePlugin):
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
......@@ -1056,14 +1194,14 @@ class Qwen2VLPlugin(BasePlugin):
@override
def _regularize_videos(
self, videos: Sequence["VideoInput"], **kwargs
) -> Tuple[List[List["ImageObject"]], List[float]]:
self, videos: list["VideoInput"], **kwargs
) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
results, fps_per_video = [], []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = []
frames: list[ImageObject] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
......@@ -1072,59 +1210,61 @@ class Qwen2VLPlugin(BasePlugin):
if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
frames.append(frames[-1])
frames = self._regularize_images(frames, **kwargs)
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
if video_stream.duration is None:
fps_per_video.append(2.0)
else:
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
return results, fps_per_video
return {"videos": results, "fps_per_video": fps_per_video}
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
)["images"]
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos, fps_per_video = self._regularize_videos(
video_data = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
mm_inputs["fps_per_video"] = fps_per_video
mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names:
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
return mm_inputs
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
......@@ -1167,26 +1307,188 @@ class Qwen2VLPlugin(BasePlugin):
return messages
class Qwen2OmniPlugin(Qwen2VLPlugin):
@override
def get_mm_inputs(
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)["images"]
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
video_dict = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
mm_inputs["video_second_per_grid"] = torch.tensor(
[temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
)
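# Same seconds-per-temporal-grid quantity as in Qwen2VLPlugin, kept as a tensor because
# process_messages below multiplies it by the hardcoded position_id_per_seconds (25) to align
# video frames with audio chunks.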
if len(audios) != 0:
audios = self._regularize_audios(
audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)["audios"]
mm_inputs.update(
feature_extractor(
audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
return_attention_mask=True,
padding="max_length",
return_tensors="pt",
)
)
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
return mm_inputs
@override
def process_messages(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", [])
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
else:
mm_inputs = {}
return mm_inputs
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
audio_lengths = None  # filled below if audio features are present
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
# get length or size from mm_inputs
if "feature_attention_mask" in mm_inputs:
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
audio_lengths = (input_lengths - 2) // 2 + 1
if mm_inputs.get("image_grid_thw", None) is not None:
image_grid_thw = mm_inputs["image_grid_thw"]
merge_length = processor.image_processor.merge_size**2
if mm_inputs.get("video_grid_thw", None) is not None:
video_grid_thw = mm_inputs["video_grid_thw"]
merge_length = processor.image_processor.merge_size**2
if use_audio_in_video:
if audio_lengths is None:
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
if not mm_inputs.get("video_grid_thw", None):
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
positions_list = []
for i, message in enumerate(messages): # get multimodal index when use_audio
positions = []
for special_token in [self.audio_token, self.image_token, self.video_token]:
start = 0
while True:
pos = message["content"].find(special_token, start)
if pos == -1:
break
positions.append((pos, special_token))
start = pos + len(special_token)
positions.sort(key=lambda x: x[0])
positions_list.append(positions)
for message in messages:
content = message["content"]
# separate with audio-video
while IMAGE_PLACEHOLDER in content:
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
content = content.replace(
IMAGE_PLACEHOLDER,
f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
1,
)
num_image_tokens += 1
if not use_audio_in_video:
while AUDIO_PLACEHOLDER in content:
audio_token_replace_length = audio_lengths[num_audio_tokens]
content = content.replace(
AUDIO_PLACEHOLDER,
f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
1,
)
num_audio_tokens += 1
# TODO handle video_input and use_audio_in_video
while VIDEO_PLACEHOLDER in content:
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
)
num_video_tokens += 1
else: # if using the audio track of the video, handle video and audio tokens together
while VIDEO_PLACEHOLDER in content:
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
video_t_index = (
torch.arange(video_grid_thw[num_video_tokens][0])
.view(-1, 1, 1)
.expand(
-1,
video_grid_thw[num_video_tokens][1] // processor.image_processor.merge_size,
video_grid_thw[num_video_tokens][2] // processor.image_processor.merge_size,
)
.flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens]
* 25 # FIXME hardcode of position_id_per_seconds=25
).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
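# With position_id_per_seconds=25 and t_ntoken_per_chunk=50, each chunk presumably spans about
# two seconds; the loop below interleaves the video and audio tokens of each chunk inside
# <|vision_bos|><|audio_bos|> ... <|audio_eos|><|vision_eos|> delimiters.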
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
if video_chunk_index is not None:
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
if audio_chunk_index is not None:
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
num_audio_tokens += 1
num_video_tokens += 1
message["content"] = content
if len(audios) != num_audio_tokens:
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@dataclass
......@@ -1194,33 +1496,33 @@ class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
num_frames = 0
has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs
if has_images or has_videos:
if self.expand_mm_tokens:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values_images" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
num_frames = 1
if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
if "pixel_values_videos" in mm_inputs:
one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim
if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
image_seqlen = (height // processor.patch_size) * (
width // processor.patch_size
) + processor.num_additional_image_tokens
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy", "default") == "default":
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen, video_seqlen = 1, 1
......@@ -1246,24 +1548,11 @@ class VideoLlavaPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
PLUGINS = {
"base": BasePlugin,
"gemma3": Gemma3Plugin,
"llama4": Llama4Plugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
......@@ -1272,15 +1561,14 @@ PLUGINS = {
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_audio": Qwen2AudioPlugin,
"qwen2_omni": Qwen2OmniPlugin,
"qwen2_vl": Qwen2VLPlugin,
"video_llava": VideoLlavaPlugin,
}
def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
r"""
Registers a multimodal plugin.
"""
def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
r"""Register a multimodal plugin."""
if name in PLUGINS:
raise ValueError(f"Multimodal plugin {name} already exists.")
......@@ -1293,9 +1581,7 @@ def get_mm_plugin(
video_token: Optional[str] = None,
audio_token: Optional[str] = None,
) -> "BasePlugin":
r"""
Gets plugin for multimodal inputs.
"""
r"""Get plugin for multimodal inputs."""
if name not in PLUGINS:
raise ValueError(f"Multimodal plugin `{name}` not found.")
......
......@@ -15,9 +15,9 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence
from typing import Any, Literal, Optional
from transformers.utils import cached_file
from huggingface_hub import hf_hub_download
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope, use_openmind
......@@ -25,9 +25,7 @@ from ..extras.misc import use_modelscope, use_openmind
@dataclass
class DatasetAttr:
r"""
Dataset attributes.
"""
r"""Dataset attributes."""
# basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
......@@ -68,10 +66,10 @@ class DatasetAttr:
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def join(self, attr: Dict[str, Any]) -> None:
def join(self, attr: dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr)
......@@ -92,10 +90,8 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
r"""
Gets the attributes of the datasets.
"""
def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets."""
if dataset_names is None:
dataset_names = []
......@@ -103,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None
else:
if dataset_dir.startswith("REMOTE:"):
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
else:
config_path = os.path.join(dataset_dir, DATA_CONFIG)
......@@ -116,7 +112,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None
dataset_list: List["DatasetAttr"] = []
dataset_list: list[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope():
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
......@@ -9,9 +23,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [
"DatasetProcessor",
"FeedbackDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"PairwiseDatasetProcessor",
"PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor",
]
......@@ -13,7 +13,7 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
......@@ -30,15 +30,15 @@ logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
prompt: list[dict[str, str]],
response: list[dict[str, str]],
kl_response: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> tuple[list[int], list[int], list[int], list[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
......@@ -82,9 +82,9 @@ class FeedbackDatasetProcessor(DatasetProcessor):
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
kl_response = [examples["_response"][-1]] + examples["_response"][:-1]
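# e.g. responses [r0, r1, r2, r3] become [r3, r0, r1, r2], so every prompt is paired with the
# previous example's completion when building the KL reference batch.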
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
......@@ -121,7 +121,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
......@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: list[dict[str, str]],
response: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> tuple[list[int], list[int], list[int], list[int]]:
chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor
)
......@@ -68,7 +68,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
......@@ -99,7 +99,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
......@@ -17,14 +17,14 @@
from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List
from typing import Any
from .processor_utils import DatasetProcessor
@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
......@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
return result
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
......@@ -15,7 +15,7 @@
import bisect
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
......@@ -27,9 +27,7 @@ if TYPE_CHECKING:
@dataclass
class DatasetProcessor(ABC):
r"""
A class for data processors.
"""
r"""A class for data processors."""
template: "Template"
tokenizer: "PreTrainedTokenizer"
......@@ -37,32 +35,24 @@ class DatasetProcessor(ABC):
data_args: "DataArguments"
@abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
r"""
Builds model inputs from the examples.
"""
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
r"""Build model inputs from the examples."""
...
@abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None:
r"""
Print a data example to stdout.
"""
def print_data_example(self, example: dict[str, list[int]]) -> None:
r"""Print a data example to stdout."""
...
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
def search_for_fit(numbers: list[int], capacity: int) -> int:
r"""Find the index of largest number that fits into the knapsack with the given capacity."""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)
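A quick usage check (arbitrary numbers; the input must already be sorted in ascending order, as the caller below guarantees):

assert search_for_fit([2, 4, 8, 16], capacity=9) == 2   # 8 is the largest value that fits
assert search_for_fit([2, 4, 8, 16], capacity=16) == 3  # exact fits count
assert search_for_fit([2, 4, 8, 16], capacity=1) == -1  # nothing fits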
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []
......@@ -83,10 +73,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks
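For example, packing toy sequence lengths into bins of capacity 10; the exact grouping depends on the greedy pop order, part of which is elided in this hunk, so only bin-level invariants are asserted:

bins = greedy_knapsack([7, 3, 5, 2, 6], capacity=10)
print(bins)  # e.g. [[7, 3], [6, 2], [5]]
assert all(sum(b) <= 10 for b in bins)                        # no bin exceeds the capacity
assert sorted(x for b in bins for x in b) == [2, 3, 5, 6, 7]  # every length is placed exactly once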
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
r"""Compute the real sequence length after truncation by the cutoff_len."""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target
......
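A worked example of the first branch shown above, assuming the usual tail arithmetic (keep the whole target and give the remaining budget to the source; the rest of the function is elided in this hunk):

source_len, target_len, cutoff_len = 100, 20, 64
assert target_len * 2 < cutoff_len                             # 40 < 64 -> truncate the source only
new_target_len = min(cutoff_len, target_len)                   # 20
new_source_len = min(cutoff_len - new_target_len, source_len)  # 44
assert (new_source_len, new_target_len) == (44, 20)            # 44 + 20 == cutoff_len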
......@@ -14,7 +14,7 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
......@@ -32,14 +32,14 @@ logger = logging.get_logger(__name__)
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: list[dict[str, str]],
response: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
......@@ -85,7 +85,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
......@@ -114,7 +114,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return model_inputs
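A minimal multiturn sketch of the masking rule in the comment above (hypothetical token ids): each prompt segment is replaced by IGNORE_INDEX in the labels, so the loss only covers the response turns.

IGNORE_INDEX = -100

turns = [([1, 11], [21, 22]), ([12, 13], [23, 2])]  # (prompt_ids, response_ids) per turn

input_ids, labels = [], []
for prompt_ids, response_ids in turns:
    input_ids += prompt_ids + response_ids
    labels += [IGNORE_INDEX] * len(prompt_ids) + response_ids  # mask only the prompt part

assert input_ids == [1, 11, 21, 22, 12, 13, 23, 2]
assert labels == [-100, -100, 21, 22, -100, -100, 23, 2]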
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
......@@ -124,7 +124,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
......@@ -165,7 +165,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos, packed_audios = [], [], []
packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
......@@ -175,6 +175,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_audios += batch_audios[index]
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
packed_position_ids += list(range(len(batch_input_ids[index])))
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
......@@ -184,6 +185,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length
packed_position_ids += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
......@@ -196,5 +198,6 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None)
model_inputs["position_ids"].append(packed_position_ids or None)
return model_inputs
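A sketch of what the neat_packing bookkeeping added above produces when two toy sequences are packed into one row of length 8: the attention mask numbers each sequence from 1 so the kernels can keep them apart, the new position ids restart at 0 per sequence, and padding gets mask 0 (token id 0 stands in for the real pad token).

cutoff_len = 8
sequences = [[11, 12, 13], [21, 22]]  # two toy tokenized examples from one knapsack

packed_input_ids, packed_attention_masks, packed_position_ids = [], [], []
for i, seq in enumerate(sequences):
    packed_input_ids += seq
    packed_attention_masks += [i + 1] * len(seq)  # neat_packing: sequence ids start from 1
    packed_position_ids += list(range(len(seq)))  # positions restart for every sequence

pad_length = cutoff_len - len(packed_input_ids)
packed_input_ids += [0] * pad_length
packed_attention_masks += [0] * pad_length
packed_position_ids += [0] * pad_length

assert packed_attention_masks == [1, 1, 1, 2, 2, 0, 0, 0]
assert packed_position_ids == [0, 1, 2, 0, 1, 0, 0, 0]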
......@@ -13,7 +13,7 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ..data_utils import Role
......@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: list[dict[str, str]],
response: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> tuple[list[int], list[int]]:
if len(response) == 1:
messages = prompt + response
else:
......@@ -56,7 +56,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
labels = labels[:target_len]
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
......@@ -84,7 +84,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
return model_inputs
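For the unsupervised format in the comment above, the prompt and the optional response are encoded separately and nothing is masked (toy ids):

input_ids = [1, 11, 12]  # `<bos> X`
labels = [21, 22, 2]     # `Y <eos>` -- kept whole, no IGNORE_INDEX masking here
# The elided lines then truncate each side to its own budget,
# ending with `labels = labels[:target_len]` as shown above.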
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
......