Commit 7ea81099 authored by chenych's avatar chenych
Browse files

update llama4

parent 84987715
# Copyright 2024 THUDM and the LlamaFactory team. # Copyright 2025 THUDM and the LlamaFactory team.
# #
# This code is inspired by the THUDM's ChatGLM implementation. # This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py # https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
...@@ -17,13 +17,15 @@ ...@@ -17,13 +17,15 @@
import asyncio import asyncio
import os import os
from collections.abc import AsyncGenerator, Generator
from threading import Thread from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence from typing import TYPE_CHECKING, Any, Optional
from ..extras.constants import EngineName from ..extras.constants import EngineName
from ..extras.misc import torch_gc from ..extras.misc import torch_gc
from ..hparams import get_infer_args from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine from .hf_engine import HuggingfaceEngine
from .sglang_engine import SGLangEngine
from .vllm_engine import VllmEngine from .vllm_engine import VllmEngine
...@@ -38,20 +40,21 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None: ...@@ -38,20 +40,21 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel: class ChatModel:
r""" r"""General class for chat models. Backed by huggingface or vllm engines.
General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods. Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores(). Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores(). Async methods: achat(), astream_chat() and aget_scores().
""" """
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args) model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
if model_args.infer_backend == EngineName.HF: if model_args.infer_backend == EngineName.HF:
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args) self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.VLLM: elif model_args.infer_backend == EngineName.VLLM:
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args) self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.SGLANG:
self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
else: else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}") raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
...@@ -61,17 +64,15 @@ class ChatModel: ...@@ -61,17 +64,15 @@ class ChatModel:
def chat( def chat(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> list["Response"]:
r""" r"""Get a list of responses of the chat model."""
Gets a list of responses of the chat model.
"""
task = asyncio.run_coroutine_threadsafe( task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
) )
...@@ -79,32 +80,28 @@ class ChatModel: ...@@ -79,32 +80,28 @@ class ChatModel:
async def achat( async def achat(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> list["Response"]:
r""" r"""Asynchronously get a list of responses of the chat model."""
Asynchronously gets a list of responses of the chat model.
"""
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs) return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
def stream_chat( def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
r""" r"""Get the response token-by-token of the chat model."""
Gets the response token-by-token of the chat model.
"""
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs) generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
while True: while True:
try: try:
...@@ -115,17 +112,15 @@ class ChatModel: ...@@ -115,17 +112,15 @@ class ChatModel:
async def astream_chat( async def astream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
r""" r"""Asynchronously get the response token-by-token of the chat model."""
Asynchronously gets the response token-by-token of the chat model.
"""
async for new_token in self.engine.stream_chat( async for new_token in self.engine.stream_chat(
messages, system, tools, images, videos, audios, **input_kwargs messages, system, tools, images, videos, audios, **input_kwargs
): ):
...@@ -133,23 +128,19 @@ class ChatModel: ...@@ -133,23 +128,19 @@ class ChatModel:
def get_scores( def get_scores(
self, self,
batch_input: List[str], batch_input: list[str],
**input_kwargs, **input_kwargs,
) -> List[float]: ) -> list[float]:
r""" r"""Get a list of scores of the reward model."""
Gets a list of scores of the reward model.
"""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop) task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result() return task.result()
async def aget_scores( async def aget_scores(
self, self,
batch_input: List[str], batch_input: list[str],
**input_kwargs, **input_kwargs,
) -> List[float]: ) -> list[float]:
r""" r"""Asynchronously get a list of scores of the reward model."""
Asynchronously gets a list of scores of the reward model.
"""
return await self.engine.get_scores(batch_input, **input_kwargs) return await self.engine.get_scores(batch_input, **input_kwargs)
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import concurrent.futures
import os import os
from collections.abc import AsyncGenerator
from threading import Thread from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
...@@ -76,15 +76,15 @@ class HuggingfaceEngine(BaseEngine): ...@@ -76,15 +76,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
template: "Template", template: "Template",
generating_args: Dict[str, Any], generating_args: dict[str, Any],
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]: ) -> tuple[dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]} mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None: if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]}) mm_input_dict.update({"images": images, "imglens": [len(images)]})
...@@ -130,7 +130,7 @@ class HuggingfaceEngine(BaseEngine): ...@@ -130,7 +130,7 @@ class HuggingfaceEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None) max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if stop is not None: if stop is not None:
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.") logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
...@@ -217,15 +217,15 @@ class HuggingfaceEngine(BaseEngine): ...@@ -217,15 +217,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
template: "Template", template: "Template",
generating_args: Dict[str, Any], generating_args: dict[str, Any],
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[dict[str, Any]] = {},
) -> List["Response"]: ) -> list["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args( gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, model,
tokenizer, tokenizer,
...@@ -272,14 +272,14 @@ class HuggingfaceEngine(BaseEngine): ...@@ -272,14 +272,14 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
template: "Template", template: "Template",
generating_args: Dict[str, Any], generating_args: dict[str, Any],
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[dict[str, Any]] = {},
) -> Callable[[], str]: ) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args( gen_kwargs, _ = HuggingfaceEngine._process_args(
model, model,
...@@ -317,12 +317,12 @@ class HuggingfaceEngine(BaseEngine): ...@@ -317,12 +317,12 @@ class HuggingfaceEngine(BaseEngine):
def _get_scores( def _get_scores(
model: "PreTrainedModelWrapper", model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
batch_input: List[str], batch_input: list[str],
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[dict[str, Any]] = {},
) -> List[float]: ) -> list[float]:
max_length: Optional[int] = input_kwargs.pop("max_length", None) max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda") device = getattr(model.pretrained_model, "device", "cuda")
inputs: Dict[str, "torch.Tensor"] = tokenizer( inputs: dict[str, torch.Tensor] = tokenizer(
batch_input, batch_input,
padding=True, padding=True,
truncation=True, truncation=True,
...@@ -330,25 +330,24 @@ class HuggingfaceEngine(BaseEngine): ...@@ -330,25 +330,24 @@ class HuggingfaceEngine(BaseEngine):
return_tensors="pt", return_tensors="pt",
add_special_tokens=False, add_special_tokens=False,
).to(device) ).to(device)
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1] values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1)) scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores return scores
@override @override
async def chat( async def chat(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> list["Response"]:
if not self.can_generate: if not self.can_generate:
raise ValueError("The current model does not support `chat`.") raise ValueError("The current model does not support `chat`.")
loop = asyncio.get_running_loop()
input_args = ( input_args = (
self.model, self.model,
self.tokenizer, self.tokenizer,
...@@ -364,24 +363,22 @@ class HuggingfaceEngine(BaseEngine): ...@@ -364,24 +363,22 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool: return await asyncio.to_thread(self._chat, *input_args)
return await loop.run_in_executor(pool, self._chat, *input_args)
@override @override
async def stream_chat( async def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
if not self.can_generate: if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.") raise ValueError("The current model does not support `stream_chat`.")
loop = asyncio.get_running_loop()
input_args = ( input_args = (
self.model, self.model,
self.tokenizer, self.tokenizer,
...@@ -397,25 +394,22 @@ class HuggingfaceEngine(BaseEngine): ...@@ -397,25 +394,22 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool: stream = self._stream_chat(*input_args)
stream = self._stream_chat(*input_args) while True:
while True: try:
try: yield await asyncio.to_thread(stream)
yield await loop.run_in_executor(pool, stream) except StopAsyncIteration:
except StopAsyncIteration: break
break
@override @override
async def get_scores( async def get_scores(
self, self,
batch_input: List[str], batch_input: list[str],
**input_kwargs, **input_kwargs,
) -> List[float]: ) -> list[float]:
if self.can_generate: if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.") raise ValueError("Cannot get scores using an auto-regressive model.")
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs) input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore: async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool: return await asyncio.to_thread(self._get_scores, *input_args)
return await loop.run_in_executor(pool, self._get_scores, *input_args)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import atexit
import json
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
import requests
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_sglang_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from .base_engine import BaseEngine, Response
if is_sglang_available():
from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore
if TYPE_CHECKING:
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class SGLangEngine(BaseEngine):
    """Inference engine for SGLang models.

    This class wraps the SGLang engine to provide a consistent interface for text generation
    that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
    better interaction and performance. The engine launches a server process and communicates
    with it via HTTP requests.

    For more details on the SGLang HTTP server approach, see:
    https://docs.sglang.ai/backend/send_request.html
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        """Launch a local SGLang HTTP server for *model_args* and prepare tokenizer/template.

        Raises:
            RuntimeError: if the server process fails to start or become ready.
        """
        self.name = EngineName.SGLANG
        self.model_args = model_args
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
        self.generating_args = generating_args.to_dict()

        launch_cmd = [
            "python3 -m sglang.launch_server",
            f"--model-path {model_args.model_name_or_path}",
            f"--dtype {model_args.infer_dtype}",
            f"--context-length {model_args.sglang_maxlen}",
            f"--mem-fraction-static {model_args.sglang_mem_fraction}",
            f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
            f"--download-dir {model_args.cache_dir}",
            "--log-level error",
        ]
        launch_cmd = " ".join(launch_cmd)
        logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
        try:
            torch_gc()  # free cached GPU memory before the server grabs its static fraction
            self.server_process, port = launch_server_cmd(launch_cmd)
            self.base_url = f"http://localhost:{port}"
            atexit.register(self._cleanup_server)  # ensure the subprocess dies with the interpreter

            logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
            wait_for_server(self.base_url, timeout=300)
            logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")

            # Best-effort sanity probe; failure here is informational only.
            try:
                response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
                if response.status_code == 200:
                    model_info = response.json()
                    logger.info(f"SGLang server model info: {model_info}")
            except Exception as e:
                logger.debug(f"Note: could not get model info: {str(e)}")

        except Exception as e:
            logger.error(f"Failed to start SGLang server: {str(e)}")
            self._cleanup_server()  # make sure to clean up any started process
            raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")

    def _cleanup_server(self):
        r"""Clean up the server process when the engine is destroyed."""
        if hasattr(self, "server_process") and self.server_process:
            try:
                logger.info("Terminating SGLang server process")
                terminate_process(self.server_process)
                logger.info("SGLang server process terminated")
            except Exception as e:
                logger.warning(f"Error terminating SGLang server: {str(e)}")

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator[dict[str, Any]]:
        """Encode the conversation and stream raw generation chunks from the server.

        Returns a generator of decoded JSON chunks from the ``/generate`` SSE stream.
        NOTE(review): despite the ``AsyncIterator`` annotation, ``asyncio.to_thread`` only
        *creates* the synchronous generator off-loop; callers iterate it with a plain
        ``for`` loop (see ``chat`` / ``stream_chat`` below).
        """
        # Prepend missing multimodal placeholders so the template can bind the inputs.
        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, images or [], videos or [], audios or [], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if num_return_sequences != 1:
            raise NotImplementedError("SGLang only supports n=1.")

        # Resolve the token budget: per-call args override configured defaults.
        # Initialize to 1 so the variable is always bound even when neither the
        # generating args nor the call provide a length (avoids UnboundLocalError).
        max_tokens = 1
        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = {
            "temperature": temperature if temperature is not None else self.generating_args["temperature"],
            "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
            "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
            "stop": stop,
            "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
            "max_new_tokens": max_tokens,
            "repetition_penalty": (
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must > 0
            "skip_special_tokens": skip_special_tokens
            if skip_special_tokens is not None
            else self.generating_args["skip_special_tokens"],
        }

        def stream_request():
            # Post token ids to the native /generate endpoint and parse the SSE stream.
            json_data = {
                "input_ids": prompt_ids,
                "sampling_params": sampling_params,
                "stream": True,
            }
            response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
            if response.status_code != 200:
                raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")

            for chunk in response.iter_lines(decode_unicode=False):
                chunk = str(chunk.decode("utf-8"))
                if chunk == "data: [DONE]":
                    break

                if chunk and chunk.startswith("data:"):
                    yield json.loads(chunk[5:].strip("\n"))

        return await asyncio.to_thread(stream_request)

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        """Generate a complete response by draining the stream and keeping the last chunk."""
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        for request_output in generator:
            final_output = request_output  # each chunk carries the full text so far

        if final_output is None:  # empty stream: fail loudly instead of subscripting None
            raise RuntimeError("SGLang server returned no output for the request.")

        # NOTE(review): on some SGLang versions meta_info["finish_reason"] is a dict
        # ({"type": "stop", ...}) rather than a string — confirm this comparison holds.
        results = [
            Response(
                response_text=final_output["text"],
                response_length=final_output["meta_info"]["completion_tokens"],
                prompt_length=final_output["meta_info"]["prompt_tokens"],
                finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
            )
        ]
        return results

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        """Yield the response incrementally, emitting only the delta of each chunk."""
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        for result in generator:
            delta_text = result["text"][len(generated_text) :]  # chunks carry cumulative text
            generated_text = result["text"]
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        """Reward-model scoring is not available through the SGLang backend."""
        raise NotImplementedError("SGLang engine does not support `get_scores`.")

    def __del__(self):
        r"""Ensure server is cleaned up when object is deleted."""
        self._cleanup_server()
        try:
            atexit.unregister(self._cleanup_server)
        except Exception:
            pass
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
import uuid import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import override from typing_extensions import override
...@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine): ...@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
self.model_args = model_args self.model_args = model_args
config = load_config(model_args) # may download model from ms hub config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16 if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "") quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto": if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16" model_args.infer_dtype = "float16"
...@@ -82,7 +83,7 @@ class VllmEngine(BaseEngine): ...@@ -82,7 +83,7 @@ class VllmEngine(BaseEngine):
"max_lora_rank": model_args.vllm_max_lora_rank, "max_lora_rank": model_args.vllm_max_lora_rank,
} }
if self.template.mm_plugin.__class__.__name__ != "BasePlugin": if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2} engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
if isinstance(model_args.vllm_config, dict): if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config) engine_args.update(model_args.vllm_config)
...@@ -101,33 +102,26 @@ class VllmEngine(BaseEngine): ...@@ -101,33 +102,26 @@ class VllmEngine(BaseEngine):
async def _generate( async def _generate(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncIterator["RequestOutput"]: ) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}" request_id = f"chatcmpl-{uuid.uuid4().hex}"
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]} if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
if images is not None: messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if videos is not None: if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]}) messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if audios is not None:
mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages( messages = self.template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor messages, images or [], videos or [], audios or [], self.processor
) )
paired_messages = messages + [{"role": "assistant", "content": ""}] paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"] system = system or self.generating_args["default_system"]
...@@ -143,7 +137,7 @@ class VllmEngine(BaseEngine): ...@@ -143,7 +137,7 @@ class VllmEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None) max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if length_penalty is not None: if length_penalty is not None:
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.") logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
...@@ -185,8 +179,24 @@ class VllmEngine(BaseEngine): ...@@ -185,8 +179,24 @@ class VllmEngine(BaseEngine):
images, images,
image_max_pixels=self.model_args.image_max_pixels, image_max_pixels=self.model_args.image_max_pixels,
image_min_pixels=self.model_args.image_min_pixels, image_min_pixels=self.model_args.image_min_pixels,
) )["images"]
} }
elif videos is not None:
multi_modal_data = {
"video": self.template.mm_plugin._regularize_videos(
videos,
image_max_pixels=self.model_args.video_max_pixels,
image_min_pixels=self.model_args.video_min_pixels,
video_fps=self.model_args.video_fps,
video_maxlen=self.model_args.video_maxlen,
)["videos"]
}
elif audios is not None:
audio_data = self.template.mm_plugin._regularize_audios(
audios,
sampling_rate=self.model_args.audio_sampling_rate,
)
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
else: else:
multi_modal_data = None multi_modal_data = None
...@@ -201,14 +211,14 @@ class VllmEngine(BaseEngine): ...@@ -201,14 +211,14 @@ class VllmEngine(BaseEngine):
@override @override
async def chat( async def chat(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> list["Response"]:
final_output = None final_output = None
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
async for request_output in generator: async for request_output in generator:
...@@ -230,12 +240,12 @@ class VllmEngine(BaseEngine): ...@@ -230,12 +240,12 @@ class VllmEngine(BaseEngine):
@override @override
async def stream_chat( async def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None, images: Optional[list["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None, videos: Optional[list["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None, audios: Optional[list["AudioInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
generated_text = "" generated_text = ""
...@@ -248,7 +258,7 @@ class VllmEngine(BaseEngine): ...@@ -248,7 +258,7 @@ class VllmEngine(BaseEngine):
@override @override
async def get_scores( async def get_scores(
self, self,
batch_input: List[str], batch_input: list[str],
**input_kwargs, **input_kwargs,
) -> List[float]: ) -> list[float]:
raise NotImplementedError("vLLM engine does not support get_scores.") raise NotImplementedError("vLLM engine does not support `get_scores`.")
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import os import os
import random
import subprocess import subprocess
import sys import sys
from enum import Enum, unique from enum import Enum, unique
...@@ -24,7 +23,7 @@ from .chat.chat_model import run_chat ...@@ -24,7 +23,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval from .eval.evaluator import run_eval
from .extras import logging from .extras import logging
from .extras.env import VERSION, print_env from .extras.env import VERSION, print_env
from .extras.misc import get_device_count, is_env_enabled, use_ray from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui from .webui.interface import run_web_demo, run_web_ui
...@@ -92,7 +91,7 @@ def main(): ...@@ -92,7 +91,7 @@ def main():
node_rank = os.getenv("NODE_RANK", "0") node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count())) nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999))) master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}") logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1: if int(nnodes) > 1:
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}") print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
......
...@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer ...@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [ __all__ = [
"TEMPLATES",
"KTODataCollatorWithPadding", "KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq", "MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding", "PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role", "Role",
"split_dataset", "SFTDataCollatorWith4DAttentionMask",
"get_dataset",
"TEMPLATES",
"Template", "Template",
"get_dataset",
"get_template_and_fix_tokenizer", "get_template_and_fix_tokenizer",
"split_dataset",
] ]
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team. # Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
# #
# This code is inspired by the OpenAccess AI Collective's axolotl library. # This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py # https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence from typing import TYPE_CHECKING, Any, Literal, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -24,6 +24,7 @@ import torch.nn.functional as F ...@@ -24,6 +24,7 @@ import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.misc import get_current_device
from ..extras.packages import is_pillow_available from ..extras.packages import is_pillow_available
...@@ -38,9 +39,10 @@ if TYPE_CHECKING: ...@@ -38,9 +39,10 @@ if TYPE_CHECKING:
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r""" r"""Expand 2d attention mask to 4d attention mask.
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking. Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.
e.g. e.g.
```python ```python
...@@ -62,24 +64,37 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype ...@@ -62,24 +64,37 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
``` ```
where `o` equals to `0.0`, `x` equals to `min_dtype`. where `o` equals to `0.0`, `x` equals to `min_dtype`.
""" """
bsz, seq_len = attention_mask_with_indices.size() _, seq_len = attention_mask_with_indices.size()
# Move to compute device if the source is CPU.
source_device = attention_mask_with_indices.device
compute_device = get_current_device() if source_device.type == "cpu" else source_device
if compute_device != source_device:
attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len) zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
padding_mask = torch.where(expanded_mask != 0, 1, 0) # Create a non-padding mask.
# Create a block-diagonal mask. non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask # Create indices for comparison.
# Use the lower triangular mask to zero out the upper triangular part indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, seq_len]
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long)) indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3) # [bsz, 1, seq_len, 1]
# Create a lower triangular mask.
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
# Invert the attention mask. # Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype) attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
# Move back to original device if needed.
if compute_device != source_device:
attention_mask_4d = attention_mask_4d.to(source_device)
return attention_mask_4d return attention_mask_4d
@dataclass @dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r""" r"""Data collator that supports VLMs.
Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios. Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
""" """
...@@ -91,7 +106,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): ...@@ -91,7 +106,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.template is None: if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.") raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], [] batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], [] batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
for feature in features: for feature in features:
...@@ -166,7 +181,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): ...@@ -166,7 +181,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for i, feature in enumerate(features): for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i] feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features) features: dict[str, torch.Tensor] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
rope_index_kwargs = { rope_index_kwargs = {
...@@ -175,10 +190,28 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): ...@@ -175,10 +190,28 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"video_grid_thw": mm_inputs.get("video_grid_thw"), "video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": features["attention_mask"], "attention_mask": features["attention_mask"],
} }
if "second_per_grid_ts" in mm_inputs: if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts") rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
if "video_second_per_grid" in mm_inputs: # for qwen2omni
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs) rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(
feature_attention_mask, dim=1
) # FIXME need to get video image lengths
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
# avoid conflict
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
features["position_ids"], features["rope_deltas"] = (
new_position_ids.clone(),
rope_deltas - delta0,
) # avoid inplace operation FIXME
else: # for qwen2vl
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask") cross_attention_mask = mm_inputs.pop("cross_attention_mask")
...@@ -198,15 +231,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): ...@@ -198,15 +231,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@dataclass @dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r""" r"""Data collator for 4d attention mask."""
Data collator for 4d attention mask.
"""
block_diag_attn: bool = False block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32 compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features) features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2": if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
...@@ -220,13 +251,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): ...@@ -220,13 +251,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
@dataclass @dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r""" r"""Data collator for pairwise data."""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
r""" r"""Pad batched data to the longest sequence in the batch.
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples. the last n examples represent rejected examples.
...@@ -249,11 +277,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): ...@@ -249,11 +277,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
@dataclass @dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r""" r"""Data collator for KTO data."""
Data collator for KTO data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
target_features = [] target_features = []
kl_features = [] kl_features = []
kto_tags = [] kto_tags = []
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import os import os
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union from typing import TYPE_CHECKING, Any, Optional, Union
from ..extras import logging from ..extras import logging
from .data_utils import Role from .data_utils import Role
...@@ -26,8 +26,12 @@ if TYPE_CHECKING: ...@@ -26,8 +26,12 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments from ..hparams import DataArguments
from .mm_plugin import AudioInput, ImageInput, VideoInput
from .parser import DatasetAttr from .parser import DatasetAttr
MediaType = Union[ImageInput, VideoInput, AudioInput]
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -36,12 +40,12 @@ class DatasetConverter: ...@@ -36,12 +40,12 @@ class DatasetConverter:
dataset_attr: "DatasetAttr" dataset_attr: "DatasetAttr"
data_args: "DataArguments" data_args: "DataArguments"
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]: def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
r""" r"""Optionally concatenate media path to media dir when loading from local disk."""
Optionally concatenates media path to media dir when loading from local disk. if medias is None:
""" return None
if not isinstance(medias, list): elif not isinstance(medias, list):
medias = [medias] if medias is not None else [] medias = [medias]
elif len(medias) == 0: elif len(medias) == 0:
return None return None
else: else:
...@@ -57,16 +61,14 @@ class DatasetConverter: ...@@ -57,16 +61,14 @@ class DatasetConverter:
return medias return medias
@abstractmethod @abstractmethod
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
r""" r"""Convert a single example in the dataset to the standard format."""
Converts a single example in the dataset to the standard format.
"""
... ...
@dataclass @dataclass
class AlpacaDatasetConverter(DatasetConverter): class AlpacaDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
prompt = [] prompt = []
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list): if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
for old_prompt, old_response in example[self.dataset_attr.history]: for old_prompt, old_response in example[self.dataset_attr.history]:
...@@ -116,7 +118,7 @@ class AlpacaDatasetConverter(DatasetConverter): ...@@ -116,7 +118,7 @@ class AlpacaDatasetConverter(DatasetConverter):
@dataclass @dataclass
class SharegptDatasetConverter(DatasetConverter): class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
tag_mapping = { tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value, self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value, self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
...@@ -216,10 +218,8 @@ DATASET_CONVERTERS = { ...@@ -216,10 +218,8 @@ DATASET_CONVERTERS = {
} }
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None: def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
r""" r"""Register a new dataset converter."""
Register a new dataset converter.
"""
if name in DATASET_CONVERTERS: if name in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} already exists.") raise ValueError(f"Dataset converter {name} already exists.")
...@@ -227,9 +227,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver ...@@ -227,9 +227,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter": def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
r""" r"""Get a dataset converter."""
Gets a dataset converter.
"""
if name not in DATASET_CONVERTERS: if name not in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} not found.") raise ValueError(f"Dataset converter {name} not found.")
...@@ -242,17 +240,17 @@ def align_dataset( ...@@ -242,17 +240,17 @@ def align_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""Align the dataset to a specific format.
Aligned dataset: Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1) _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..." _system: "..."
_tools: "...", _tools: "..."
_images: [], _images: []
_videos: [], _videos: []
_audios: [], _audios: []
""" """
column_names = list(next(iter(dataset)).keys()) column_names = list(next(iter(dataset)).keys())
kwargs = {} kwargs = {}
if not data_args.streaming: if not data_args.streaming:
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from enum import Enum, unique from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union from typing import TYPE_CHECKING, Optional, TypedDict, Union
from datasets import DatasetDict, concatenate_datasets, interleave_datasets from datasets import DatasetDict, concatenate_datasets, interleave_datasets
...@@ -29,7 +29,7 @@ if TYPE_CHECKING: ...@@ -29,7 +29,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] SLOTS = list[Union[str, set[str], dict[str, str]]]
@unique @unique
...@@ -43,15 +43,13 @@ class Role(str, Enum): ...@@ -43,15 +43,13 @@ class Role(str, Enum):
class DatasetModule(TypedDict): class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]] train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]] eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
def merge_dataset( def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""Merge multiple datasets to a unified dataset."""
Merges multiple datasets to a unified dataset.
"""
if len(all_datasets) == 1: if len(all_datasets) == 1:
return all_datasets[0] return all_datasets[0]
...@@ -78,14 +76,13 @@ def merge_dataset( ...@@ -78,14 +76,13 @@ def merge_dataset(
def split_dataset( def split_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]], dataset: Optional[Union["Dataset", "IterableDataset"]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]], eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments", data_args: "DataArguments",
seed: int, seed: int,
) -> "DatasetDict": ) -> "DatasetDict":
r""" r"""Split the dataset and returns a dataset dict containing train set and validation set.
Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset. Support both map dataset and iterable dataset.
""" """
if eval_dataset is not None and data_args.val_size > 1e-6: if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
...@@ -120,10 +117,8 @@ def split_dataset( ...@@ -120,10 +117,8 @@ def split_dataset(
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule": def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
r""" r"""Convert dataset or dataset dict to dataset module."""
Converts dataset or dataset dict to dataset module. dataset_module: DatasetModule = {}
"""
dataset_module: "DatasetModule" = {}
if isinstance(dataset, DatasetDict): # dataset dict if isinstance(dataset, DatasetDict): # dataset dict
if "train" in dataset: if "train" in dataset:
dataset_module["train_dataset"] = dataset["train"] dataset_module["train_dataset"] = dataset["train"]
......
...@@ -16,7 +16,7 @@ import json ...@@ -16,7 +16,7 @@ import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional, Union from typing import Optional, Union
from typing_extensions import override from typing_extensions import override
...@@ -31,14 +31,11 @@ class Formatter(ABC): ...@@ -31,14 +31,11 @@ class Formatter(ABC):
@abstractmethod @abstractmethod
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
r""" r"""Forms a list of slots according to the inputs to encode."""
Forms a list of slots according to the inputs to encode.
"""
... ...
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
r""" r"""Extract a list of tuples from the response message if using tools.
Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments. Each tuple consists of function name and function arguments.
""" """
...@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter): ...@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
if thought: if thought:
content = content.replace(thought.group(0), "") content = content.replace(thought.group(0), "")
functions: List["FunctionCall"] = [] functions: list[FunctionCall] = []
try: try:
tool_calls = json.loads(content) tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call if not isinstance(tool_calls, list): # parallel function call
...@@ -141,5 +138,5 @@ class ToolFormatter(Formatter): ...@@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override @override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
return self.tool_utils.tool_extractor(content) return self.tool_utils.tool_extractor(content)
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np import numpy as np
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
...@@ -54,9 +54,7 @@ def _load_single_dataset( ...@@ -54,9 +54,7 @@ def _load_single_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""Load a single dataset and aligns it to the standard format."""
Loads a single dataset and aligns it to the standard format.
"""
logger.info_rank0(f"Loading dataset {dataset_attr}...") logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
...@@ -133,10 +131,12 @@ def _load_single_dataset( ...@@ -133,10 +131,12 @@ def _load_single_dataset(
split=dataset_attr.split, split=dataset_attr.split,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token, token=model_args.hf_hub_token,
streaming=data_args.streaming,
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code, trust_remote_code=model_args.trust_remote_code,
streaming=data_args.streaming and dataset_attr.load_from != "file",
) )
if data_args.streaming and dataset_attr.load_from == "file":
dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
if dataset_attr.num_samples is not None and not data_args.streaming: if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples target_num = dataset_attr.num_samples
...@@ -158,16 +158,14 @@ def _load_single_dataset( ...@@ -158,16 +158,14 @@ def _load_single_dataset(
def _get_merged_dataset( def _get_merged_dataset(
dataset_names: Optional[Sequence[str]], dataset_names: Optional[list[str]],
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
merge: bool = True, merge: bool = True,
) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]: ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
r""" r"""Return the merged datasets in the standard format."""
Returns the merged datasets in the standard format.
"""
if dataset_names is None: if dataset_names is None:
return None return None
...@@ -192,9 +190,7 @@ def _get_dataset_processor( ...@@ -192,9 +190,7 @@ def _get_dataset_processor(
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
do_generate: bool = False, do_generate: bool = False,
) -> "DatasetProcessor": ) -> "DatasetProcessor":
r""" r"""Return the corresponding dataset processor."""
Returns the corresponding dataset processor.
"""
if stage == "pt": if stage == "pt":
dataset_processor_class = PretrainDatasetProcessor dataset_processor_class = PretrainDatasetProcessor
elif stage == "sft" and not do_generate: elif stage == "sft" and not do_generate:
...@@ -236,9 +232,7 @@ def _get_preprocessed_dataset( ...@@ -236,9 +232,7 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False, is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Optional[Union["Dataset", "IterableDataset"]]:
r""" r"""Preprocesses the dataset, including format checking and tokenization."""
Preprocesses the dataset, including format checking and tokenization.
"""
if dataset is None: if dataset is None:
return None return None
...@@ -284,9 +278,7 @@ def get_dataset( ...@@ -284,9 +278,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule": ) -> "DatasetModule":
r""" r"""Get the train dataset and optionally gets the evaluation dataset."""
Gets the train dataset and optionally gets the evaluation dataset.
"""
# Load tokenized dataset if path exists # Load tokenized dataset if path exists
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path): if has_tokenized_data(data_args.tokenized_path):
......
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect import inspect
import math import math
import re import re
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
import numpy as np import numpy as np
import torch import torch
...@@ -51,28 +68,65 @@ if TYPE_CHECKING: ...@@ -51,28 +68,65 @@ if TYPE_CHECKING:
path: Optional[str] path: Optional[str]
bytes: Optional[bytes] bytes: Optional[bytes]
ImageInput = Union[str, bytes, EncodedImage, ImageObject] ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
VideoInput = str VideoInput = Union[str, BinaryIO]
AudioInput = Union[str, NDArray] AudioInput = Union[str, BinaryIO, NDArray]
class MMProcessor(ProcessorMixin):
patch_size: int
image_seq_length: int
num_additional_image_tokens: int
vision_feature_select_strategy: Literal["default", "full"]
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
pass
def _get_paligemma_token_type_ids( def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin" r"""Get paligemma token type ids for computing loss.
) -> List[List[int]]:
r""" It is slightly different with the original token type ids where the prompt part is 0.
Gets paligemma token type ids for computing loss.
Returns: Returns:
batch_token_type_ids: shape (batch_size, sequence_length) batch_token_type_ids: shape (batch_size, seq_length)
""" """
batch_token_type_ids = [] batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens): for imglen, seqlen in zip(imglens, seqlens):
image_seqlen = imglen * getattr(processor, "image_seqlen") image_seqlen = imglen * processor.image_seq_length
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen)) batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
return batch_token_type_ids return batch_token_type_ids
def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
r"""Get gemma3 token type ids for computing loss.
Returns:
batch_token_type_ids: shape (batch_size, seq_length)
"""
image_token_id: int = getattr(processor, "image_token_id")
batch_token_type_ids = []
for token_ids in batch_ids:
token_ids = np.array(token_ids)
token_type_ids = np.zeros_like(token_ids)
token_type_ids[token_ids == image_token_id] = 1
batch_token_type_ids.append(token_type_ids.tolist())
return batch_token_type_ids
def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
r"""Make nested list of images."""
batch_images = []
for imglen in imglens:
batch_images.append(images[:imglen])
images = images[imglen:]
return batch_images
@dataclass @dataclass
class MMPluginMixin: class MMPluginMixin:
image_token: Optional[str] image_token: Optional[str]
...@@ -82,16 +136,17 @@ class MMPluginMixin: ...@@ -82,16 +136,17 @@ class MMPluginMixin:
def _validate_input( def _validate_input(
self, self,
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
) -> None: ) -> None:
r""" r"""Validate if this model accepts the input modalities."""
Validates if this model accepts the input modalities. image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
""" video_processor: BaseImageProcessor = getattr(
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None) processor, "video_processor", getattr(processor, "image_processor", None)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None) )
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
if len(images) != 0 and self.image_token is None: if len(images) != 0 and self.image_token is None:
raise ValueError( raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used." "This model does not support image input. Please check whether the correct `template` is used."
...@@ -113,15 +168,16 @@ class MMPluginMixin: ...@@ -113,15 +168,16 @@ class MMPluginMixin:
if self.image_token is not None and image_processor is None: if self.image_token is not None and image_processor is None:
raise ValueError("Image processor was not found, please check and update your processor config.") raise ValueError("Image processor was not found, please check and update your processor config.")
if self.video_token is not None and video_processor is None:
raise ValueError("Video processor was not found, please check and update your processor config.")
if self.audio_token is not None and feature_extractor is None: if self.audio_token is not None and feature_extractor is None:
raise ValueError("Audio feature extractor was not found, please check and update your processor config.") raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
def _preprocess_image( def _preprocess_image(
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
) -> "ImageObject": ) -> "ImageObject":
r""" r"""Pre-process a single image."""
Pre-processes a single image.
"""
if (image.width * image.height) > image_max_pixels: if (image.width * image.height) > image_max_pixels:
resize_factor = math.sqrt(image_max_pixels / (image.width * image.height)) resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor) width, height = int(image.width * resize_factor), int(image.height * resize_factor)
...@@ -139,10 +195,8 @@ class MMPluginMixin: ...@@ -139,10 +195,8 @@ class MMPluginMixin:
def _get_video_sample_indices( def _get_video_sample_indices(
self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
) -> List[int]: ) -> list[int]:
r""" r"""Compute video sample indices according to fps."""
Computes video sample indices according to fps.
"""
total_frames = video_stream.frames total_frames = video_stream.frames
if total_frames == 0: # infinite video if total_frames == 0: # infinite video
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32) return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
...@@ -151,13 +205,11 @@ class MMPluginMixin: ...@@ -151,13 +205,11 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames) sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]: def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
r""" r"""Regularize images to avoid error. Including reading and pre-processing."""
Regularizes images to avoid error. Including reading and pre-processing.
"""
results = [] results = []
for image in images: for image in images:
if isinstance(image, str): if isinstance(image, (str, BinaryIO)):
image = Image.open(image) image = Image.open(image)
elif isinstance(image, bytes): elif isinstance(image, bytes):
image = Image.open(BytesIO(image)) image = Image.open(BytesIO(image))
...@@ -172,53 +224,52 @@ class MMPluginMixin: ...@@ -172,53 +224,52 @@ class MMPluginMixin:
results.append(self._preprocess_image(image, **kwargs)) results.append(self._preprocess_image(image, **kwargs))
return results return {"images": results}
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]: def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
r""" r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
results = [] results = []
for video in videos: for video in videos:
container = av.open(video, "r") container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video") video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs) sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = [] frames: list[ImageObject] = []
container.seek(0) container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)): for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices: if frame_idx in sample_indices:
frames.append(frame.to_image()) frames.append(frame.to_image())
frames = self._regularize_images(frames, **kwargs) frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames) results.append(frames)
return results return {"videos": results}
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]: def _regularize_audios(
r""" self, audios: list["AudioInput"], sampling_rate: float, **kwargs
Regularizes audios to avoid error. Including reading and resampling. ) -> dict[str, Union[list["NDArray"], list[float]]]:
""" r"""Regularizes audios to avoid error. Including reading and resampling."""
results = [] results, sampling_rates = [], []
for audio in audios: for audio in audios:
if isinstance(audio, str): if isinstance(audio, (str, BinaryIO)):
audio = librosa.load(audio, sr=sampling_rate)[0] audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
if not isinstance(audio, np.ndarray): if not isinstance(audio, np.ndarray):
raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.") raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
results.append(audio) results.append(audio)
sampling_rates.append(sampling_rate)
return results return {"audios": results, "sampling_rates": sampling_rates}
def _get_mm_inputs( def _get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: "ProcessorMixin", processor: "MMProcessor",
) -> Dict[str, "torch.Tensor"]: imglens: Optional[list[int]] = None,
r""" ) -> dict[str, "torch.Tensor"]:
Processes visual inputs. r"""Process visual inputs.
Returns: (llava and paligemma) Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W) pixel_values: tensor with shape (B, C, H, W)
...@@ -226,44 +277,67 @@ class MMPluginMixin: ...@@ -226,44 +277,67 @@ class MMPluginMixin:
Returns: (qwen2-vl) Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim) pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
where num_patches == torch.prod(image_grid_thw)
Returns: (mllama)
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
It holds num_patches == torch.prod(image_grid_thw)
""" """
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
images = self._regularize_images( images = self._regularize_images(
images, images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
) )["images"]
mm_inputs.update(image_processor(images, return_tensors="pt")) if imglens is not None: # if imglens are provided, make batched images
images = _make_batched_images(images, imglens)
image_processor_kwargs = {}
if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor
image_processor_kwargs.update(
{
"do_pan_and_scan": True,
"pan_and_scan_min_crop_size": 256,
"pan_and_scan_max_num_crops": 4,
"pan_and_scan_min_ratio_to_activate": 1.2,
}
)
mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
if len(videos) != 0: if len(videos) != 0:
video_processor: BaseImageProcessor = getattr(
processor, "video_processor", getattr(processor, "image_processor", None)
)
videos = self._regularize_videos( videos = self._regularize_videos(
videos, videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0), video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128), video_maxlen=getattr(processor, "video_maxlen", 128),
) )["videos"]
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt")) mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
else: # for llava_next_video else: # for llava_next_video
mm_inputs.update(video_processor(videos, return_tensors="pt")) mm_inputs.update(video_processor(videos, return_tensors="pt"))
if len(audios) != 0: if len(audios) != 0:
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
audios = self._regularize_audios( audios = self._regularize_audios(
audios, audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
) )["audios"]
mm_inputs.update( mm_inputs.update(
feature_extractor( feature_extractor(
audios, audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
return_attention_mask=True, return_attention_mask=True,
padding="max_length", padding="max_length",
return_tensors="pt", return_tensors="pt",
...@@ -278,83 +352,95 @@ class MMPluginMixin: ...@@ -278,83 +352,95 @@ class MMPluginMixin:
class BasePlugin(MMPluginMixin): class BasePlugin(MMPluginMixin):
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
r""" r"""Pre-process input messages before tokenization for VLMs."""
Pre-processes input messages before tokenization for VLMs.
"""
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return messages return messages
def process_token_ids( def process_token_ids(
self, self,
input_ids: List[int], input_ids: list[int],
labels: Optional[List[int]], labels: Optional[list[int]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> Tuple[List[int], Optional[List[int]]]: ) -> tuple[list[int], Optional[list[int]]]:
r""" r"""Pre-process token ids after tokenization for VLMs."""
Pre-processes token ids after tokenization for VLMs.
"""
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return input_ids, labels return input_ids, labels
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
imglens: Sequence[int], imglens: list[int],
vidlens: Sequence[int], vidlens: list[int],
audlens: Sequence[int], audlens: list[int],
batch_ids: Sequence[List[int]], batch_ids: list[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
r""" r"""Build batched multimodal inputs for VLMs.
Builds batched multimodal inputs for VLMs.
Arguments: Arguments:
images: a list of image inputs, shape (num_images,) images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,) videos: a list of video inputs, shape (num_videos,)
audios: a list of audio inputs, shape (num_audios,)
imglens: number of images in each sample, shape (batch_size,) imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,) vidlens: number of videos in each sample, shape (batch_size,)
audlens: number of audios in each sample, shape (batch_size,) audlens: number of audios in each sample, shape (batch_size,)
batch_ids: token ids of input samples, shape (batch_size, seq_len) batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos processor: a processor for pre-processing images and videos
""" """
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return {} return self._get_mm_inputs(images, videos, audios, processor)
@dataclass @dataclass
class LlavaPlugin(BasePlugin): class Gemma3Plugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
messages = deepcopy(messages) messages = deepcopy(messages)
boi_token: str = getattr(processor, "boi_token")
full_image_sequence: str = getattr(processor, "full_image_sequence")
image_str = full_image_sequence if self.expand_mm_tokens else boi_token
do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
if do_pan_and_scan:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) if do_pan_and_scan:
image_placeholder_str = (
"Here is the original image {{image}} and here are some crops to help you see better "
+ " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
)
else:
image_placeholder_str = "{{image}}"
content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
num_image_tokens += 1 num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token) message["content"] = content.replace("{{image}}", image_str)
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
...@@ -364,37 +450,148 @@ class LlavaPlugin(BasePlugin): ...@@ -364,37 +450,148 @@ class LlavaPlugin(BasePlugin):
@override @override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
imglens: Sequence[int], imglens: list[int],
vidlens: Sequence[int], vidlens: list[int],
audlens: Sequence[int], audlens: list[int],
batch_ids: Sequence[List[int]], batch_ids: list[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("num_crops", None)
mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
return mm_inputs
@dataclass @dataclass
class LlavaNextPlugin(BasePlugin): class Llama4Plugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
num_patches_per_chunk = int(
(image_height // processor.patch_size)
* (image_width // processor.patch_size)
// processor.downsample_ratio
)
aspect_ratios = mm_inputs.pop("aspect_ratios")
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages:
content = message["content"]
placeholder_count = content.count(IMAGE_PLACEHOLDER)
if self.expand_mm_tokens:
prompt_splits = content.split(IMAGE_PLACEHOLDER)
new_content = []
for local_image_index, split_part in enumerate(prompt_splits):
new_content.append(split_part)
if local_image_index < placeholder_count:
tokens_for_this_image = processor._prompt_split_image(
aspect_ratios[num_image_tokens], num_patches_per_chunk
)
num_image_tokens += 1
new_content.append(tokens_for_this_image)
content = "".join(new_content)
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs: mm_inputs.pop("aspect_ratios", None)
image_sizes = iter(mm_inputs["image_sizes"].tolist()) return mm_inputs
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
@dataclass
class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
image_seqlen = (height // processor.patch_size) * (
width // processor.patch_size
) + processor.num_additional_image_tokens
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen = 1
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@dataclass
class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages: for message in messages:
content = message["content"] content = message["content"]
...@@ -402,7 +599,7 @@ class LlavaNextPlugin(BasePlugin): ...@@ -402,7 +599,7 @@ class LlavaNextPlugin(BasePlugin):
if self.expand_mm_tokens: if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes) orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy", "default") == "default": if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1 image_seqlen -= 1
else: else:
image_seqlen = 1 image_seqlen = 1
...@@ -417,73 +614,60 @@ class LlavaNextPlugin(BasePlugin): ...@@ -417,73 +614,60 @@ class LlavaNextPlugin(BasePlugin):
return messages return messages
@override
def get_mm_inputs(
    self,
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    imglens: Sequence[int],
    vidlens: Sequence[int],
    audlens: Sequence[int],
    batch_ids: Sequence[List[int]],
    processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
    r"""Validate the multimodal inputs, then delegate tensor construction to the base helper."""
    self._validate_input(processor, images, videos, audios)
    mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
    return mm_inputs
@dataclass @dataclass
class LlavaNextVideoPlugin(BasePlugin): class LlavaNextVideoPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0 num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) if self.expand_mm_tokens:
if "pixel_values" in mm_inputs: mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_sizes = iter(mm_inputs["image_sizes"].tolist()) if "pixel_values" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) image_sizes = iter(mm_inputs["image_sizes"].tolist())
for message in messages: height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
content = message["content"]
while IMAGE_PLACEHOLDER in content: for message in messages:
if self.expand_mm_tokens: content = message["content"]
orig_height, orig_width = next(image_sizes) while IMAGE_PLACEHOLDER in content:
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) if self.expand_mm_tokens:
if getattr(processor, "vision_feature_select_strategy", "default") == "default": orig_height, orig_width = next(image_sizes)
image_seqlen -= 1 image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
else: if processor.vision_feature_select_strategy == "default":
image_seqlen = 1 image_seqlen -= 1
else:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) image_seqlen = 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token) num_image_tokens += 1
if "pixel_values_videos" in mm_inputs: message["content"] = content.replace("{{image}}", self.image_token)
if self.expand_mm_tokens:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) if self.expand_mm_tokens:
height, width = get_image_size(pixel_values_video[0]) if "pixel_values_videos" in mm_inputs:
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
else: else:
video_seqlen = 1 video_seqlen = 1
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) num_video_tokens += 1
message["content"] = content.replace("{{video}}", self.video_token) message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
...@@ -493,37 +677,22 @@ class LlavaNextVideoPlugin(BasePlugin): ...@@ -493,37 +677,22 @@ class LlavaNextVideoPlugin(BasePlugin):
return messages return messages
@override
def get_mm_inputs(
    self,
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    imglens: Sequence[int],
    vidlens: Sequence[int],
    audlens: Sequence[int],
    batch_ids: Sequence[List[int]],
    processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
    r"""Check the multimodal inputs and hand the tensor construction off to ``_get_mm_inputs``."""
    self._validate_input(processor, images, videos, audios)
    processed_inputs = self._get_mm_inputs(images, videos, audios, processor)
    return processed_inputs
@dataclass @dataclass
class MiniCPMVPlugin(BasePlugin): class MiniCPMVPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {} mm_inputs = {}
audio_inputs = {} audio_inputs = {}
if len(images) != 0 and len(videos) != 0: if len(images) != 0 and len(videos) != 0:
...@@ -614,21 +783,20 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -614,21 +783,20 @@ class MiniCPMVPlugin(BasePlugin):
@override @override
def _get_mm_inputs( def _get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: "ProcessorMixin", processor: "MMProcessor",
**kwargs, **kwargs,
) -> Dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: BaseImageProcessor = getattr(processor, "image_processor")
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
images, images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
) )["images"]
if "valid_image_nums_ls" in kwargs: if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"] valid_image_nums_ls = kwargs["valid_image_nums_ls"]
new_images = [] new_images = []
...@@ -651,15 +819,15 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -651,15 +819,15 @@ class MiniCPMVPlugin(BasePlugin):
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0), video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128), video_maxlen=getattr(processor, "video_maxlen", 128),
) )["videos"]
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt") video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
mm_inputs.update(video_inputs) mm_inputs.update(video_inputs)
if len(audios) != 0: if len(audios) != 0:
audios = self._regularize_audios( audios = self._regularize_audios(
audios, audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
) )["audios"]
if "valid_audio_nums_ls" in kwargs: if "valid_audio_nums_ls" in kwargs:
valid_audio_nums_ls = kwargs["valid_audio_nums_ls"] valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
audios_ls = [] audios_ls = []
...@@ -673,7 +841,7 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -673,7 +841,7 @@ class MiniCPMVPlugin(BasePlugin):
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract( audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
audios_ls, audios_ls,
chunk_input=True, chunk_input=True,
sampling_rate=16000, sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
) )
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens] audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}) mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
...@@ -685,15 +853,15 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -685,15 +853,15 @@ class MiniCPMVPlugin(BasePlugin):
@override @override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
imglens: Sequence[int], imglens: list[int],
vidlens: Sequence[int], vidlens: list[int],
audlens: Sequence[int], audlens: list[int],
batch_ids: Sequence[List[int]], batch_ids: list[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
# image bound # image bound
image_bounds_list = [] image_bounds_list = []
...@@ -756,12 +924,12 @@ class MllamaPlugin(BasePlugin): ...@@ -756,12 +924,12 @@ class MllamaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
...@@ -775,61 +943,24 @@ class MllamaPlugin(BasePlugin): ...@@ -775,61 +943,24 @@ class MllamaPlugin(BasePlugin):
return messages return messages
@override
def _get_mm_inputs(
    self,
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    audios: Sequence["AudioInput"],
    processor: "ProcessorMixin",
    imglens: List[int],
) -> Dict[str, "torch.Tensor"]:
    r"""
    Processes visual inputs for mllama, whose image processor only accepts List[List[ImageInput]].

    Returns:
        pixel_values: tensor with shape
            (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
            For example, (2, 1, 4, 3, 560, 560).
        aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
        aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
        num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
    """
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    mm_inputs = {}
    if len(images) > 0:
        regularized = self._regularize_images(
            images,
            image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
            image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
        )
        # Re-group the flat image list into one sub-list per sample, as mllama expects.
        batch_images, offset = [], 0
        for image_length in imglens:
            batch_images.append(regularized[offset : offset + image_length])
            offset += image_length

        mm_inputs.update(image_processor(batch_images, return_tensors="pt"))

    return mm_inputs
@override @override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
imglens: Sequence[int], imglens: list[int],
vidlens: Sequence[int], vidlens: list[int],
audlens: Sequence[int], audlens: list[int],
batch_ids: Sequence[List[int]], batch_ids: list[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens) mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
if mm_inputs: if mm_inputs:
num_tiles = mm_inputs.pop("num_tiles") num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id") image_token_id: int = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles") max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [ cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
] ]
...@@ -850,22 +981,22 @@ class PaliGemmaPlugin(BasePlugin): ...@@ -850,22 +981,22 @@ class PaliGemmaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) content = content.replace(IMAGE_PLACEHOLDER, "", 1)
num_image_tokens += 1 num_image_tokens += 1
message["content"] = content.replace("{{image}}", "") message["content"] = content
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
...@@ -875,36 +1006,36 @@ class PaliGemmaPlugin(BasePlugin): ...@@ -875,36 +1006,36 @@ class PaliGemmaPlugin(BasePlugin):
@override @override
def process_token_ids( def process_token_ids(
self, self,
input_ids: List[int], input_ids: list[int],
labels: Optional[List[int]], labels: Optional[list[int]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> Tuple[List[int], Optional[List[int]]]: ) -> tuple[list[int], Optional[list[int]]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_images = len(images) num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids input_ids = [image_token_id] * num_images * image_seqlen + input_ids
if labels is not None: if labels is not None:
labels = [IGNORE_INDEX] * image_seqlen + labels labels = [IGNORE_INDEX] * num_images * image_seqlen + labels
return input_ids, labels return input_ids, labels
@override @override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
imglens: Sequence[int], imglens: list[int],
vidlens: Sequence[int], vidlens: list[int],
audlens: Sequence[int], audlens: list[int],
batch_ids: Sequence[List[int]], batch_ids: list[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
seqlens = [len(input_ids) for input_ids in batch_ids] seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
...@@ -917,37 +1048,39 @@ class PixtralPlugin(BasePlugin): ...@@ -917,37 +1048,39 @@ class PixtralPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token")
image_end_token = getattr(processor, "image_end_token")
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) if self.expand_mm_tokens:
if "pixel_values" in mm_inputs: mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_sizes = iter(mm_inputs["image_sizes"].tolist()) if "pixel_values" in mm_inputs:
# BC for transformers < 4.49.0
if isinstance(mm_inputs["image_sizes"], list):
image_sizes = iter(mm_inputs["image_sizes"][0])
else:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
image_break_token: str = getattr(processor, "image_break_token")
image_end_token: str = getattr(processor, "image_end_token")
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens: if self.expand_mm_tokens:
height, width = next(image_sizes) height, width = next(image_sizes)
num_height_tokens = height // patch_size num_height_tokens = height // processor.patch_size
num_width_tokens = width // patch_size num_width_tokens = width // processor.patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens) replace_str = "".join(replace_tokens)
else: else:
replace_str = image_token replace_str = self.image_token
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1) content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1 num_image_tokens += 1
...@@ -962,18 +1095,22 @@ class PixtralPlugin(BasePlugin): ...@@ -962,18 +1095,22 @@ class PixtralPlugin(BasePlugin):
@override @override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
imglens: Sequence[int], imglens: list[int],
vidlens: Sequence[int], vidlens: list[int],
audlens: Sequence[int], audlens: list[int],
batch_ids: Sequence[List[int]], batch_ids: list[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("image_sizes", None) # ref to this commit https://github.com/huggingface/transformers/pull/35122
# after transformers 4.49.0, the `image_sizes` is mandatory as an input parameter for Pixtral VisionEncoder forwarding.
# it can be passed into `LlavaConditionalGeneration` as a parameter.
if not is_transformers_version_greater_than("4.49.0"):
mm_inputs.pop("image_sizes", None)
return mm_inputs return mm_inputs
...@@ -982,21 +1119,22 @@ class Qwen2AudioPlugin(BasePlugin): ...@@ -982,21 +1119,22 @@ class Qwen2AudioPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token") bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token") eos_token: str = getattr(processor, "audio_eos_token")
num_audio_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs([], [], audios, processor) if self.expand_mm_tokens:
if "feature_attention_mask" in mm_inputs: mm_inputs = self._get_mm_inputs([], [], audios, processor)
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist() if "feature_attention_mask" in mm_inputs:
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
num_audio_tokens = 0
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while AUDIO_PLACEHOLDER in content: while AUDIO_PLACEHOLDER in content:
...@@ -1022,15 +1160,15 @@ class Qwen2AudioPlugin(BasePlugin): ...@@ -1022,15 +1160,15 @@ class Qwen2AudioPlugin(BasePlugin):
@override @override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
imglens: Sequence[int], imglens: list[int],
vidlens: Sequence[int], vidlens: list[int],
audlens: Sequence[int], audlens: list[int],
batch_ids: Sequence[List[int]], batch_ids: list[list[int]],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor) return self._get_mm_inputs(images, videos, audios, processor)
...@@ -1056,14 +1194,14 @@ class Qwen2VLPlugin(BasePlugin): ...@@ -1056,14 +1194,14 @@ class Qwen2VLPlugin(BasePlugin):
@override @override
def _regularize_videos( def _regularize_videos(
self, videos: Sequence["VideoInput"], **kwargs self, videos: list["VideoInput"], **kwargs
) -> Tuple[List[List["ImageObject"]], List[float]]: ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
results, fps_per_video = [], [] results, fps_per_video = [], []
for video in videos: for video in videos:
container = av.open(video, "r") container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video") video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs) sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = [] frames: list[ImageObject] = []
container.seek(0) container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)): for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices: if frame_idx in sample_indices:
...@@ -1072,59 +1210,61 @@ class Qwen2VLPlugin(BasePlugin): ...@@ -1072,59 +1210,61 @@ class Qwen2VLPlugin(BasePlugin):
if len(frames) % 2 != 0: # qwen2-vl requires even number of frames if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
frames.append(frames[-1]) frames.append(frames[-1])
frames = self._regularize_images(frames, **kwargs) frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames) results.append(frames)
if video_stream.duration is None: if video_stream.duration is None:
fps_per_video.append(2.0) fps_per_video.append(2.0)
else: else:
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base)) fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
return results, fps_per_video return {"videos": results, "fps_per_video": fps_per_video}
@override @override
def _get_mm_inputs( def _get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: "ProcessorMixin", processor: "MMProcessor",
) -> Dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None) image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
images, images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
) )["images"]
mm_inputs.update(image_processor(images, return_tensors="pt")) mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0: if len(videos) != 0:
videos, fps_per_video = self._regularize_videos( video_data = self._regularize_videos(
videos, videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0), video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128), video_maxlen=getattr(processor, "video_maxlen", 128),
) )
mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt")) mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
mm_inputs["fps_per_video"] = fps_per_video temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names:
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
return mm_inputs return mm_inputs
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0 num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2 merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens: if self.expand_mm_tokens:
...@@ -1167,76 +1307,180 @@ class Qwen2VLPlugin(BasePlugin): ...@@ -1167,76 +1307,180 @@ class Qwen2VLPlugin(BasePlugin):
return messages return messages
class Qwen2OmniPlugin(Qwen2VLPlugin):
@override @override
def get_mm_inputs( def _get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
imglens: Sequence[int], processor: "MMProcessor",
vidlens: Sequence[int], ) -> dict[str, "torch.Tensor"]:
audlens: Sequence[int], image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
batch_ids: Sequence[List[int]], feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
processor: Optional["ProcessorMixin"], mm_inputs = {}
) -> Dict[str, Union[List[int], "torch.Tensor"]]: if len(images) != 0:
self._validate_input(processor, images, videos, audios) images = self._regularize_images(
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) images,
fps_per_video = mm_inputs.pop("fps_per_video", []) image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
if "second_per_grid_ts" in processor.model_input_names and fps_per_video: )["images"]
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video] mm_inputs.update(image_processor(images, return_tensors="pt"))
return mm_inputs if len(videos) != 0:
video_dict = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
mm_inputs["video_second_per_grid"] = torch.tensor(
[temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
)
if len(audios) != 0:
audios = self._regularize_audios(
audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)["audios"]
mm_inputs.update(
feature_extractor(
audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
return_attention_mask=True,
padding="max_length",
return_tensors="pt",
)
)
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
return mm_inputs
@dataclass
class VideoLlavaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["MMProcessor"],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor) if self.expand_mm_tokens:
num_frames = 0 mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
has_images = "pixel_values_images" in mm_inputs else:
has_videos = "pixel_values_videos" in mm_inputs mm_inputs = {}
if has_images or has_videos:
if self.expand_mm_tokens:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if has_videos: num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) use_audio_in_video = getattr(processor, "use_audio_in_video", False)
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1 # get length or size from mm_inputs
video_seqlen = image_seqlen * num_frames if "feature_attention_mask" in mm_inputs:
if getattr(processor, "vision_feature_select_strategy", "default") == "default": input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
image_seqlen -= 1 audio_lengths = (input_lengths - 2) // 2 + 1
else:
image_seqlen, video_seqlen = 1, 1 if mm_inputs.get("image_grid_thw", None) is not None:
image_grid_thw = mm_inputs["image_grid_thw"]
merge_length = processor.image_processor.merge_size**2
if mm_inputs.get("video_grid_thw", None) is not None:
video_grid_thw = mm_inputs["video_grid_thw"]
merge_length = processor.image_processor.merge_size**2
if use_audio_in_video:
if audio_lengths is None:
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
if not mm_inputs.get("video_grid_thw", None):
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
positions_list = []
for i, message in enumerate(messages): # get multimodal index when use_audio
positions = []
for special_token in [self.audio_token, self.image_token, self.video_token]:
start = 0
while True:
pos = message[i].find(special_token, start)
if pos == -1:
break
positions.append((pos, special_token))
start = pos + len(special_token)
positions_list.append(positions.sort(key=lambda x: x[0]))
for message in messages:
content = message["content"]
# separate with audio-video
while IMAGE_PLACEHOLDER in content:
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
content = content.replace(
IMAGE_PLACEHOLDER,
f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
1,
)
num_image_tokens += 1
for message in messages: if not use_audio_in_video:
content = message["content"] while AUDIO_PLACEHOLDER in content:
while IMAGE_PLACEHOLDER in content: audio_token_replace_length = audio_lengths[num_audio_tokens]
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) content = content.replace(
num_image_tokens += 1 AUDIO_PLACEHOLDER,
f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
1,
)
num_audio_tokens += 1
# TODO handle video_input and use_audio_in_video
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
)
num_video_tokens += 1 num_video_tokens += 1
content = content.replace("{{image}}", self.image_token) else: # if use the audio of video # deal video token and audio token togather
message["content"] = content.replace("{{video}}", self.video_token) while VIDEO_PLACEHOLDER in content:
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
video_t_index = (
torch.arange(video_grid_thw[num_video_tokens][0])
.view(-1, 1, 1)
.expand(
-1,
video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
)
.flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens]
* 25 # FIXME hardcode of position_id_per_seconds=25
).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
if video_chunk_index is not None:
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
if audio_chunk_index is not None:
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
num_audio_tokens += 1
num_video_tokens += 1
message["content"] = content
if len(audios) != num_audio_tokens:
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
...@@ -1246,24 +1490,69 @@ class VideoLlavaPlugin(BasePlugin): ...@@ -1246,24 +1490,69 @@ class VideoLlavaPlugin(BasePlugin):
return messages return messages
@dataclass
class VideoLlavaPlugin(BasePlugin):
@override @override
def get_mm_inputs( def process_messages(
self, self,
images: Sequence["ImageInput"], messages: list[dict[str, str]],
videos: Sequence["VideoInput"], images: list["ImageInput"],
audios: Sequence["AudioInput"], videos: list["VideoInput"],
imglens: Sequence[int], audios: list["AudioInput"],
vidlens: Sequence[int], processor: Optional["MMProcessor"],
audlens: Sequence[int], ) -> list[dict[str, str]]:
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor) num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
num_frames = 0
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values_images" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
num_frames = 1
if "pixel_values_videos" in mm_inputs:
one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim
if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
image_seqlen = (height // processor.patch_size) * (
width // processor.patch_size
) + processor.num_additional_image_tokens
video_seqlen = image_seqlen * num_frames
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen, video_seqlen = 1, 1
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
num_video_tokens += 1
content = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
PLUGINS = { PLUGINS = {
"base": BasePlugin, "base": BasePlugin,
"gemma3": Gemma3Plugin,
"llama4": Llama4Plugin,
"llava": LlavaPlugin, "llava": LlavaPlugin,
"llava_next": LlavaNextPlugin, "llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin, "llava_next_video": LlavaNextVideoPlugin,
...@@ -1272,15 +1561,14 @@ PLUGINS = { ...@@ -1272,15 +1561,14 @@ PLUGINS = {
"paligemma": PaliGemmaPlugin, "paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin, "pixtral": PixtralPlugin,
"qwen2_audio": Qwen2AudioPlugin, "qwen2_audio": Qwen2AudioPlugin,
"qwen2_omni": Qwen2OmniPlugin,
"qwen2_vl": Qwen2VLPlugin, "qwen2_vl": Qwen2VLPlugin,
"video_llava": VideoLlavaPlugin, "video_llava": VideoLlavaPlugin,
} }
def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None: def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
r""" r"""Register a multimodal plugin."""
Registers a multimodal plugin.
"""
if name in PLUGINS: if name in PLUGINS:
raise ValueError(f"Multimodal plugin {name} already exists.") raise ValueError(f"Multimodal plugin {name} already exists.")
...@@ -1293,9 +1581,7 @@ def get_mm_plugin( ...@@ -1293,9 +1581,7 @@ def get_mm_plugin(
video_token: Optional[str] = None, video_token: Optional[str] = None,
audio_token: Optional[str] = None, audio_token: Optional[str] = None,
) -> "BasePlugin": ) -> "BasePlugin":
r""" r"""Get plugin for multimodal inputs."""
Gets plugin for multimodal inputs.
"""
if name not in PLUGINS: if name not in PLUGINS:
raise ValueError(f"Multimodal plugin `{name}` not found.") raise ValueError(f"Multimodal plugin `{name}` not found.")
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence from typing import Any, Literal, Optional
from transformers.utils import cached_file from huggingface_hub import hf_hub_download
from ..extras.constants import DATA_CONFIG from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope, use_openmind from ..extras.misc import use_modelscope, use_openmind
...@@ -25,9 +25,7 @@ from ..extras.misc import use_modelscope, use_openmind ...@@ -25,9 +25,7 @@ from ..extras.misc import use_modelscope, use_openmind
@dataclass @dataclass
class DatasetAttr: class DatasetAttr:
r""" r"""Dataset attributes."""
Dataset attributes.
"""
# basic configs # basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
...@@ -68,10 +66,10 @@ class DatasetAttr: ...@@ -68,10 +66,10 @@ class DatasetAttr:
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default)) setattr(self, key, obj.get(key, default))
def join(self, attr: Dict[str, Any]) -> None: def join(self, attr: dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca") self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False) self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr) self.set_attr("subset", attr)
...@@ -92,10 +90,8 @@ class DatasetAttr: ...@@ -92,10 +90,8 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"]) self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
r""" r"""Get the attributes of the datasets."""
Gets the attributes of the datasets.
"""
if dataset_names is None: if dataset_names is None:
dataset_names = [] dataset_names = []
...@@ -103,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - ...@@ -103,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None dataset_info = None
else: else:
if dataset_dir.startswith("REMOTE:"): if dataset_dir.startswith("REMOTE:"):
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
else: else:
config_path = os.path.join(dataset_dir, DATA_CONFIG) config_path = os.path.join(dataset_dir, DATA_CONFIG)
...@@ -116,7 +112,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - ...@@ -116,7 +112,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None dataset_info = None
dataset_list: List["DatasetAttr"] = [] dataset_list: list[DatasetAttr] = []
for name in dataset_names: for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope(): if use_modelscope():
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .feedback import FeedbackDatasetProcessor from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor from .pretrain import PretrainDatasetProcessor
...@@ -9,9 +23,9 @@ from .unsupervised import UnsupervisedDatasetProcessor ...@@ -9,9 +23,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [ __all__ = [
"DatasetProcessor", "DatasetProcessor",
"FeedbackDatasetProcessor", "FeedbackDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"PairwiseDatasetProcessor", "PairwiseDatasetProcessor",
"PretrainDatasetProcessor", "PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor", "SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor", "UnsupervisedDatasetProcessor",
] ]
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
...@@ -30,15 +30,15 @@ logger = logging.get_logger(__name__) ...@@ -30,15 +30,15 @@ logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor): class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: list[dict[str, str]],
response: Sequence[Dict[str, str]], response: list[dict[str, str]],
kl_response: Sequence[Dict[str, str]], kl_response: list[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]: ) -> tuple[list[int], list[int], list[int], list[int], bool]:
if response[0]["content"]: # desired example if response[0]["content"]: # desired example
kto_tag = True kto_tag = True
messages = prompt + [response[0]] messages = prompt + [response[0]]
...@@ -82,9 +82,9 @@ class FeedbackDatasetProcessor(DatasetProcessor): ...@@ -82,9 +82,9 @@ class FeedbackDatasetProcessor(DatasetProcessor):
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs # Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
kl_response = examples["_response"][::-1] kl_response = [examples["_response"][-1]] + examples["_response"][:-1]
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
...@@ -121,7 +121,7 @@ class FeedbackDatasetProcessor(DatasetProcessor): ...@@ -121,7 +121,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
...@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__) ...@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor): class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: list[dict[str, str]],
response: Sequence[Dict[str, str]], response: list[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]: ) -> tuple[list[int], list[int], list[int], list[int]]:
chosen_messages = self.template.mm_plugin.process_messages( chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor prompt + [response[0]], images, videos, audios, self.processor
) )
...@@ -68,7 +68,7 @@ class PairwiseDatasetProcessor(DatasetProcessor): ...@@ -68,7 +68,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
...@@ -99,7 +99,7 @@ class PairwiseDatasetProcessor(DatasetProcessor): ...@@ -99,7 +99,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"])) valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's transformers library. # This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
from dataclasses import dataclass from dataclasses import dataclass
from itertools import chain from itertools import chain
from typing import Any, Dict, List from typing import Any
from .processor_utils import DatasetProcessor from .processor_utils import DatasetProcessor
@dataclass @dataclass
class PretrainDatasetProcessor(DatasetProcessor): class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
...@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor): ...@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
return result return result
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import bisect import bisect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -27,9 +27,7 @@ if TYPE_CHECKING: ...@@ -27,9 +27,7 @@ if TYPE_CHECKING:
@dataclass @dataclass
class DatasetProcessor(ABC): class DatasetProcessor(ABC):
r""" r"""A class for data processors."""
A class for data processors.
"""
template: "Template" template: "Template"
tokenizer: "PreTrainedTokenizer" tokenizer: "PreTrainedTokenizer"
...@@ -37,32 +35,24 @@ class DatasetProcessor(ABC): ...@@ -37,32 +35,24 @@ class DatasetProcessor(ABC):
data_args: "DataArguments" data_args: "DataArguments"
@abstractmethod @abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
r""" r"""Build model inputs from the examples."""
Builds model inputs from the examples.
"""
... ...
@abstractmethod @abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
r""" r"""Print a data example to stdout."""
Print a data example to stdout.
"""
... ...
def search_for_fit(numbers: Sequence[int], capacity: int) -> int: def search_for_fit(numbers: list[int], capacity: int) -> int:
r""" r"""Find the index of largest number that fits into the knapsack with the given capacity."""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
index = bisect.bisect(numbers, capacity) index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1) return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
r""" r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers.sort() # sort numbers in ascending order for binary search numbers.sort() # sort numbers in ascending order for binary search
knapsacks = [] knapsacks = []
...@@ -83,10 +73,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: ...@@ -83,10 +73,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks return knapsacks
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
r""" r"""Compute the real sequence length after truncation by the cutoff_len."""
Computes the real sequence length after truncation by the cutoff_len.
"""
if target_len * 2 < cutoff_len: # truncate source if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target elif source_len * 2 < cutoff_len: # truncate target
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
...@@ -32,14 +32,14 @@ logger = logging.get_logger(__name__) ...@@ -32,14 +32,14 @@ logger = logging.get_logger(__name__)
class SupervisedDatasetProcessor(DatasetProcessor): class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: list[dict[str, str]],
response: Sequence[Dict[str, str]], response: list[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor) messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids( input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor [], [], images, videos, audios, self.tokenizer, self.processor
...@@ -85,7 +85,7 @@ class SupervisedDatasetProcessor(DatasetProcessor): ...@@ -85,7 +85,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return input_ids, labels return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>` # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair. # for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
...@@ -114,7 +114,7 @@ class SupervisedDatasetProcessor(DatasetProcessor): ...@@ -114,7 +114,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
...@@ -124,7 +124,7 @@ class SupervisedDatasetProcessor(DatasetProcessor): ...@@ -124,7 +124,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
@dataclass @dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# TODO: use `position_ids` to achieve packing # TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>` # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>` # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
...@@ -165,7 +165,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): ...@@ -165,7 +165,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len) knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks: for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], [] packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos, packed_audios = [], [], [] packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
for i, length in enumerate(knapsack): for i, length in enumerate(knapsack):
index = length2indexes[length].pop() index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index] packed_input_ids += batch_input_ids[index]
...@@ -175,6 +175,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): ...@@ -175,6 +175,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_audios += batch_audios[index] packed_audios += batch_audios[index]
if self.data_args.neat_packing: if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
packed_position_ids += list(range(len(batch_input_ids[index])))
else: else:
packed_attention_masks += [1] * len(batch_input_ids[index]) packed_attention_masks += [1] * len(batch_input_ids[index])
...@@ -184,6 +185,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): ...@@ -184,6 +185,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_labels += [IGNORE_INDEX] * pad_length packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing: if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length packed_attention_masks += [0] * pad_length
packed_position_ids += [0] * pad_length
else: else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn packed_attention_masks += [1] * pad_length # more efficient flash_attn
...@@ -196,5 +198,6 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): ...@@ -196,5 +198,6 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs["images"].append(packed_images or None) model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None) model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None) model_inputs["audios"].append(packed_audios or None)
model_inputs["position_ids"].append(packed_position_ids or None)
return model_inputs return model_inputs
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
from ..data_utils import Role from ..data_utils import Role
...@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__) ...@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor): class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example( def _encode_data_example(
self, self,
prompt: Sequence[Dict[str, str]], prompt: list[dict[str, str]],
response: Sequence[Dict[str, str]], response: list[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: list["ImageInput"],
videos: Sequence["VideoInput"], videos: list["VideoInput"],
audios: Sequence["AudioInput"], audios: list["AudioInput"],
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
if len(response) == 1: if len(response) == 1:
messages = prompt + response messages = prompt + response
else: else:
...@@ -56,7 +56,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor): ...@@ -56,7 +56,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
labels = labels[:target_len] labels = labels[:target_len]
return input_ids, labels return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>` # build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
...@@ -84,7 +84,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor): ...@@ -84,7 +84,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
return model_inputs return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None: def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"])) print("label_ids:\n{}".format(example["labels"]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment