Commit 7ea81099 authored by chenych's avatar chenych

update llama4

parent 84987715
# Copyright 2024 THUDM and the LlamaFactory team.
# Copyright 2025 THUDM and the LlamaFactory team.
#
# This code is inspired by THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
......@@ -17,13 +17,15 @@
import asyncio
import os
from collections.abc import AsyncGenerator, Generator
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from typing import TYPE_CHECKING, Any, Optional
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .sglang_engine import SGLangEngine
from .vllm_engine import VllmEngine
......@@ -38,20 +40,21 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel:
r"""
General class for chat models. Backed by huggingface or vllm engines.
r"""General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores().
"""
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
if model_args.infer_backend == EngineName.HF:
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.VLLM:
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.SGLANG:
self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
......@@ -61,17 +64,15 @@ class ChatModel:
def chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Get a list of responses of the chat model."""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
)
......@@ -79,32 +80,28 @@ class ChatModel:
async def achat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Asynchronously get a list of responses of the chat model."""
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
r"""Get the response token-by-token of the chat model."""
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
while True:
try:
......@@ -115,17 +112,15 @@ class ChatModel:
async def astream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
r"""Asynchronously get the response token-by-token of the chat model."""
async for new_token in self.engine.stream_chat(
messages, system, tools, images, videos, audios, **input_kwargs
):
......@@ -133,23 +128,19 @@ class ChatModel:
def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Get a list of scores of the reward model."""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
async def aget_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Asynchronously gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Asynchronously get a list of scores of the reward model."""
return await self.engine.get_scores(batch_input, **input_kwargs)
......
......@@ -13,10 +13,10 @@
# limitations under the License.
import asyncio
import concurrent.futures
import os
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer
......@@ -76,15 +76,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
generating_args: dict[str, Any],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
input_kwargs: Optional[dict[str, Any]] = {},
) -> tuple[dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
......@@ -130,7 +130,7 @@ class HuggingfaceEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if stop is not None:
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
......@@ -217,15 +217,15 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
generating_args: dict[str, Any],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
input_kwargs: Optional[dict[str, Any]] = {},
) -> list["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model,
tokenizer,
......@@ -272,14 +272,14 @@ class HuggingfaceEngine(BaseEngine):
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
generating_args: dict[str, Any],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
input_kwargs: Optional[dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model,
......@@ -317,12 +317,12 @@ class HuggingfaceEngine(BaseEngine):
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: List[str],
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List[float]:
batch_input: list[str],
input_kwargs: Optional[dict[str, Any]] = {},
) -> list[float]:
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs: Dict[str, "torch.Tensor"] = tokenizer(
inputs: dict[str, torch.Tensor] = tokenizer(
batch_input,
padding=True,
truncation=True,
......@@ -330,25 +330,24 @@ class HuggingfaceEngine(BaseEngine):
return_tensors="pt",
add_special_tokens=False,
).to(device)
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
) -> list["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
......@@ -364,24 +363,22 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
return await asyncio.to_thread(self._chat, *input_args)
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
......@@ -397,25 +394,22 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
try:
yield await loop.run_in_executor(pool, stream)
yield await asyncio.to_thread(stream)
except StopAsyncIteration:
break
@override
async def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
) -> list[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._get_scores, *input_args)
return await asyncio.to_thread(self._get_scores, *input_args)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import atexit
import json
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
import requests
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_sglang_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from .base_engine import BaseEngine, Response
if is_sglang_available():
from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore
if TYPE_CHECKING:
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class SGLangEngine(BaseEngine):
"""Inference engine for SGLang models.
This class wraps the SGLang engine to provide a consistent interface for text generation
that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
better interaction and performance. The engine launches a server process and communicates
with it via HTTP requests.
For more details on the SGLang HTTP server approach, see:
https://docs.sglang.ai/backend/send_request.html
"""
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.SGLANG
self.model_args = model_args
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
self.generating_args = generating_args.to_dict()
launch_cmd = [
"python3 -m sglang.launch_server",
f"--model-path {model_args.model_name_or_path}",
f"--dtype {model_args.infer_dtype}",
f"--context-length {model_args.sglang_maxlen}",
f"--mem-fraction-static {model_args.sglang_mem_fraction}",
f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
f"--download-dir {model_args.cache_dir}",
"--log-level error",
]
launch_cmd = " ".join(launch_cmd)
logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
try:
torch_gc()
self.server_process, port = launch_server_cmd(launch_cmd)
self.base_url = f"http://localhost:{port}"
atexit.register(self._cleanup_server)
logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
wait_for_server(self.base_url, timeout=300)
logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")
try:
response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
if response.status_code == 200:
model_info = response.json()
logger.info(f"SGLang server model info: {model_info}")
except Exception as e:
logger.debug(f"Note: could not get model info: {str(e)}")
except Exception as e:
logger.error(f"Failed to start SGLang server: {str(e)}")
self._cleanup_server() # make sure to clean up any started process
raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")
def _cleanup_server(self):
r"""Clean up the server process when the engine is destroyed."""
if hasattr(self, "server_process") and self.server_process:
try:
logger.info("Terminating SGLang server process")
terminate_process(self.server_process)
logger.info("SGLang server process terminated")
except Exception as e:
logger.warning(f"Error terminating SGLang server: {str(e)}")
async def _generate(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncIterator[dict[str, Any]]:
if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if num_return_sequences != 1:
raise NotImplementedError("SGLang only supports n=1.")
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
if max_new_tokens:
max_tokens = max_new_tokens
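# effective precedence for the token budget: an explicit max_new_tokens wins, then an explicit
# max_length (minus the prompt length), then the configured generating_args defaults above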
sampling_params = {
"temperature": temperature if temperature is not None else self.generating_args["temperature"],
"top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
"top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0
"stop": stop,
"stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
"max_new_tokens": max_tokens,
"repetition_penalty": (
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
)
or 1.0, # repetition_penalty must be > 0
"skip_special_tokens": skip_special_tokens
if skip_special_tokens is not None
else self.generating_args["skip_special_tokens"],
}
def stream_request():
json_data = {
"input_ids": prompt_ids,
"sampling_params": sampling_params,
"stream": True,
}
response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
if response.status_code != 200:
raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
for chunk in response.iter_lines(decode_unicode=False):
chunk = str(chunk.decode("utf-8"))
if chunk == "data: [DONE]":
break
if chunk and chunk.startswith("data:"):
yield json.loads(chunk[5:].strip("\n"))
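# each parsed chunk is a dict with "text" (the text generated so far) and "meta_info"
# (prompt_tokens / completion_tokens / finish_reason), as consumed by chat() and stream_chat() below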
return await asyncio.to_thread(stream_request)
@override
async def chat(
self,
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> list["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
for request_output in generator:
final_output = request_output
results = [
Response(
response_text=final_output["text"],
response_length=final_output["meta_info"]["completion_tokens"],
prompt_length=final_output["meta_info"]["prompt_tokens"],
finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
)
]
return results
@override
async def stream_chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
for result in generator:
delta_text = result["text"][len(generated_text) :]
generated_text = result["text"]
yield delta_text
@override
async def get_scores(
self,
batch_input: list[str],
**input_kwargs,
) -> list[float]:
raise NotImplementedError("SGLang engine does not support `get_scores`.")
def __del__(self):
r"""Ensure server is cleaned up when object is deleted."""
self._cleanup_server()
try:
atexit.unregister(self._cleanup_server)
except Exception:
pass
......@@ -13,7 +13,8 @@
# limitations under the License.
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import override
......@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
self.model_args = model_args
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
......@@ -82,7 +83,7 @@ class VllmEngine(BaseEngine):
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
......@@ -101,33 +102,26 @@ class VllmEngine(BaseEngine):
async def _generate(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}"
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if videos is not None:
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if audios is not None:
mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
......@@ -143,7 +137,7 @@ class VllmEngine(BaseEngine):
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
if length_penalty is not None:
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
......@@ -185,8 +179,24 @@ class VllmEngine(BaseEngine):
images,
image_max_pixels=self.model_args.image_max_pixels,
image_min_pixels=self.model_args.image_min_pixels,
)
)["images"]
}
elif videos is not None:
multi_modal_data = {
"video": self.template.mm_plugin._regularize_videos(
videos,
image_max_pixels=self.model_args.video_max_pixels,
image_min_pixels=self.model_args.video_min_pixels,
video_fps=self.model_args.video_fps,
video_maxlen=self.model_args.video_maxlen,
)["videos"]
}
elif audios is not None:
audio_data = self.template.mm_plugin._regularize_audios(
audios,
sampling_rate=self.model_args.audio_sampling_rate,
)
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
else:
multi_modal_data = None
......@@ -201,14 +211,14 @@ class VllmEngine(BaseEngine):
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
) -> list["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
async for request_output in generator:
......@@ -230,12 +240,12 @@ class VllmEngine(BaseEngine):
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
......@@ -248,7 +258,7 @@ class VllmEngine(BaseEngine):
@override
async def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")
) -> list[float]:
raise NotImplementedError("vLLM engine does not support `get_scores`.")
......@@ -13,7 +13,6 @@
# limitations under the License.
import os
import random
import subprocess
import sys
from enum import Enum, unique
......@@ -24,7 +23,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import get_device_count, is_env_enabled, use_ray
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
......@@ -92,7 +91,7 @@ def main():
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
......
......@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"TEMPLATES",
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"SFTDataCollatorWith4DAttentionMask",
"Template",
"get_dataset",
"get_template_and_fix_tokenizer",
"split_dataset",
]
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
......@@ -16,7 +16,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
from typing import TYPE_CHECKING, Any, Literal, Optional
import numpy as np
import torch
......@@ -24,6 +24,7 @@ import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.misc import get_current_device
from ..extras.packages import is_pillow_available
......@@ -38,9 +39,10 @@ if TYPE_CHECKING:
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
r"""Expand 2d attention mask to 4d attention mask.
Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
handle packed sequences and transform the mask to lower triangular form to prevent future peeking.
e.g.
```python
......@@ -62,24 +64,37 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
```
where `o` equals `0.0` and `x` equals `min_dtype`.
"""
bsz, seq_len = attention_mask_with_indices.size()
_, seq_len = attention_mask_with_indices.size()
# Move to compute device if the source is CPU.
source_device = attention_mask_with_indices.device
compute_device = get_current_device() if source_device.type == "cpu" else source_device
if compute_device != source_device:
attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
min_dtype = torch.finfo(dtype).min
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
padding_mask = torch.where(expanded_mask != 0, 1, 0)
# Create a block-diagonal mask.
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
# Use the lower triangular mask to zero out the upper triangular part
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
# Create a non-padding mask.
non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
# Create indices for comparison.
indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, seq_len]
indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3) # [bsz, 1, seq_len, 1]
# Create a lower triangular mask.
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
# Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
# Move back to original device if needed.
if compute_device != source_device:
attention_mask_4d = attention_mask_4d.to(source_device)
return attention_mask_4d
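A quick sanity check of the packed-sequence semantics (a minimal sketch, assuming the function is importable from `llamafactory.data.collator`):

```python
import torch

from llamafactory.data.collator import prepare_4d_attention_mask

# two packed sequences (indices 1 and 2) followed by one padding position (0)
mask_with_indices = torch.tensor([[1, 1, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(mask_with_indices, torch.float32)
print(mask_4d.shape)  # torch.Size([1, 1, 5, 5])
# each position only attends to earlier positions within the same packed sequence;
# every other entry is filled with torch.finfo(torch.float32).min
```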
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
r"""Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
"""
......@@ -91,7 +106,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
for feature in features:
......@@ -166,7 +181,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features: dict[str, torch.Tensor] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
rope_index_kwargs = {
......@@ -175,9 +190,27 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": features["attention_mask"],
}
if "second_per_grid_ts" in mm_inputs:
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
if "video_second_per_grid" in mm_inputs: # for qwen2omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(
feature_attention_mask, dim=1
) # FIXME need to get video image lengths
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
# avoid conflict
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
features["position_ids"], features["rope_deltas"] = (
new_position_ids.clone(),
rope_deltas - delta0,
) # avoid inplace operation FIXME
else: # for qwen2vl
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
......@@ -198,15 +231,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
r"""Data collator for 4d attention mask."""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
......@@ -220,13 +251,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
r"""Data collator for pairwise data."""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
r"""
Pads batched data to the longest sequence in the batch.
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
r"""Pad batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
......@@ -249,11 +277,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
r"""Data collator for KTO data."""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
target_features = []
kl_features = []
kto_tags = []
......
......@@ -15,7 +15,7 @@
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from ..extras import logging
from .data_utils import Role
......@@ -26,8 +26,12 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .mm_plugin import AudioInput, ImageInput, VideoInput
from .parser import DatasetAttr
MediaType = Union[ImageInput, VideoInput, AudioInput]
logger = logging.get_logger(__name__)
......@@ -36,12 +40,12 @@ class DatasetConverter:
dataset_attr: "DatasetAttr"
data_args: "DataArguments"
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
r"""
Optionally concatenates media path to media dir when loading from local disk.
"""
if not isinstance(medias, list):
medias = [medias] if medias is not None else []
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
r"""Optionally concatenate media path to media dir when loading from local disk."""
if medias is None:
return None
elif not isinstance(medias, list):
medias = [medias]
elif len(medias) == 0:
return None
else:
......@@ -57,16 +61,14 @@ class DatasetConverter:
return medias
@abstractmethod
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
r"""
Converts a single example in the dataset to the standard format.
"""
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
r"""Convert a single example in the dataset to the standard format."""
...
@dataclass
class AlpacaDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
prompt = []
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
for old_prompt, old_response in example[self.dataset_attr.history]:
......@@ -116,7 +118,7 @@ class AlpacaDatasetConverter(DatasetConverter):
@dataclass
class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
......@@ -216,10 +218,8 @@ DATASET_CONVERTERS = {
}
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
r"""
Register a new dataset converter.
"""
def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
r"""Register a new dataset converter."""
if name in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} already exists.")
......@@ -227,9 +227,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
r"""
Gets a dataset converter.
"""
r"""Get a dataset converter."""
if name not in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} not found.")
......@@ -242,17 +240,17 @@ def align_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
r"""Align the dataset to a specific format.
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
_audios: [],
_tools: "..."
_images: []
_videos: []
_audios: []
"""
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
......
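For reference, a single example in the aligned format described in the `align_dataset` docstring above looks roughly like this (hypothetical values):

```python
aligned_example = {
    "_prompt": [{"role": "user", "content": "What is 1 + 1?"}],
    "_response": [{"role": "assistant", "content": "2"}],
    "_system": "You are a helpful assistant.",
    "_tools": "",
    "_images": [],
    "_videos": [],
    "_audios": [],
}
```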
......@@ -13,7 +13,7 @@
# limitations under the License.
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
from typing import TYPE_CHECKING, Optional, TypedDict, Union
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
......@@ -29,7 +29,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
SLOTS = list[Union[str, set[str], dict[str, str]]]
@unique
......@@ -43,15 +43,13 @@ class Role(str, Enum):
class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets to a unified dataset.
"""
r"""Merge multiple datasets to a unified dataset."""
if len(all_datasets) == 1:
return all_datasets[0]
......@@ -78,14 +76,13 @@ def merge_dataset(
def split_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments",
seed: int,
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set and validation set.
r"""Split the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
Support both map dataset and iterable dataset.
"""
if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
......@@ -120,10 +117,8 @@ def split_dataset(
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
r"""
Converts dataset or dataset dict to dataset module.
"""
dataset_module: "DatasetModule" = {}
r"""Convert dataset or dataset dict to dataset module."""
dataset_module: DatasetModule = {}
if isinstance(dataset, DatasetDict): # dataset dict
if "train" in dataset:
dataset_module["train_dataset"] = dataset["train"]
......
......@@ -16,7 +16,7 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Union
from typing import Optional, Union
from typing_extensions import override
......@@ -31,14 +31,11 @@ class Formatter(ABC):
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
r"""
Forms a list of slots according to the inputs to encode.
"""
r"""Forms a list of slots according to the inputs to encode."""
...
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extract a list of tuples from the response message if using tools.
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments.
"""
......@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
if thought:
content = content.replace(thought.group(0), "")
functions: List["FunctionCall"] = []
functions: list[FunctionCall] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
......@@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
return self.tool_utils.tool_extractor(content)
......@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import load_dataset, load_from_disk
......@@ -54,9 +54,7 @@ def _load_single_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Loads a single dataset and aligns it to the standard format.
"""
r"""Load a single dataset and aligns it to the standard format."""
logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
......@@ -133,10 +131,12 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=data_args.streaming,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
streaming=data_args.streaming and dataset_attr.load_from != "file",
)
if data_args.streaming and dataset_attr.load_from == "file":
dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
......@@ -158,16 +158,14 @@ def _load_single_dataset(
def _get_merged_dataset(
dataset_names: Optional[Sequence[str]],
dataset_names: Optional[list[str]],
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
merge: bool = True,
) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
r"""
Returns the merged datasets in the standard format.
"""
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
r"""Return the merged datasets in the standard format."""
if dataset_names is None:
return None
......@@ -192,9 +190,7 @@ def _get_dataset_processor(
processor: Optional["ProcessorMixin"],
do_generate: bool = False,
) -> "DatasetProcessor":
r"""
Returns the corresponding dataset processor.
"""
r"""Return the corresponding dataset processor."""
if stage == "pt":
dataset_processor_class = PretrainDatasetProcessor
elif stage == "sft" and not do_generate:
......@@ -236,9 +232,7 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
r"""Preprocesses the dataset, including format checking and tokenization."""
if dataset is None:
return None
......@@ -284,9 +278,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
r"""
Gets the train dataset and optionally gets the evaluation dataset.
"""
r"""Get the train dataset and optionally gets the evaluation dataset."""
# Load tokenized dataset if path exists
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
......
This diff is collapsed.
......@@ -15,9 +15,9 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence
from typing import Any, Literal, Optional
from transformers.utils import cached_file
from huggingface_hub import hf_hub_download
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope, use_openmind
......@@ -25,9 +25,7 @@ from ..extras.misc import use_modelscope, use_openmind
@dataclass
class DatasetAttr:
r"""
Dataset attributes.
"""
r"""Dataset attributes."""
# basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
......@@ -68,10 +66,10 @@ class DatasetAttr:
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def join(self, attr: Dict[str, Any]) -> None:
def join(self, attr: dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr)
......@@ -92,10 +90,8 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
r"""
Gets the attributes of the datasets.
"""
def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets."""
if dataset_names is None:
dataset_names = []
......@@ -103,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None
else:
if dataset_dir.startswith("REMOTE:"):
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
else:
config_path = os.path.join(dataset_dir, DATA_CONFIG)
......@@ -116,7 +112,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_info = None
dataset_list: List["DatasetAttr"] = []
dataset_list: list[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope():
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
......@@ -9,9 +23,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [
"DatasetProcessor",
"FeedbackDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"PairwiseDatasetProcessor",
"PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor",
]
......@@ -13,7 +13,7 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
......@@ -30,15 +30,15 @@ logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
prompt: list[dict[str, str]],
response: list[dict[str, str]],
kl_response: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> tuple[list[int], list[int], list[int], list[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
......@@ -82,9 +82,9 @@ class FeedbackDatasetProcessor(DatasetProcessor):
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
kl_response = [examples["_response"][-1]] + examples["_response"][:-1]
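# e.g., with responses [r0, r1, r2] the KL responses become [r2, r0, r1],
# i.e. each prompt i is paired with the completion originally belonging to prompt i - 1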
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
......@@ -121,7 +121,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
......@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: list[dict[str, str]],
response: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> tuple[list[int], list[int], list[int], list[int]]:
chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor
)
......@@ -68,7 +68,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
......@@ -99,7 +99,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
......
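A hedged sketch of the pairwise encoding above: the chosen and rejected sequences share the same prompt, and the prompt part of each label sequence is masked with IGNORE_INDEX. The token ids below are made up; only the masking pattern is taken from the diff.

# Hedged sketch of the pairwise label masking; token ids are illustrative.
IGNORE_INDEX = -100
prompt_ids = [1, 10, 11]        # `<bos> X`
chosen_ids = [20, 21, 2]        # `Y1 <eos>`
rejected_ids = [30, 2]          # `Y2 <eos>`

source_len = len(prompt_ids)
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids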
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
......@@ -17,14 +17,14 @@
from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List
from typing import Any
from .processor_utils import DatasetProcessor
@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
......@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
return result
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
......@@ -15,7 +15,7 @@
import bisect
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
......@@ -27,9 +27,7 @@ if TYPE_CHECKING:
@dataclass
class DatasetProcessor(ABC):
r"""
A class for data processors.
"""
r"""A class for data processors."""
template: "Template"
tokenizer: "PreTrainedTokenizer"
......@@ -37,32 +35,24 @@ class DatasetProcessor(ABC):
data_args: "DataArguments"
@abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
r"""
Builds model inputs from the examples.
"""
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
r"""Build model inputs from the examples."""
...
@abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None:
r"""
Print a data example to stdout.
"""
def print_data_example(self, example: dict[str, list[int]]) -> None:
r"""Print a data example to stdout."""
...
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
def search_for_fit(numbers: list[int], capacity: int) -> int:
r"""Find the index of largest number that fits into the knapsack with the given capacity."""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)
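A quick usage note for search_for_fit: because it relies on bisect, the numbers argument must already be sorted in ascending order (greedy_knapsack sorts before calling it). Assuming the function above is in scope, the values below are illustrative.

lengths = [2, 3, 5, 8]               # must be sorted ascending for bisect
print(search_for_fit(lengths, 6))    # 2 -> lengths[2] == 5 is the largest fit
print(search_for_fit(lengths, 1))    # -1 -> nothing fits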
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []
......@@ -83,10 +73,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks
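A usage sketch for greedy_knapsack, hedged because the body of the loop is elided in this diff: the exact grouping depends on that elided code, but every returned group should fit within the capacity.

groups = greedy_knapsack([3, 5, 2, 7, 4], 8)
assert all(sum(group) <= 8 for group in groups)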
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
r"""Compute the real sequence length after truncation by the cutoff_len."""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target
......
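A worked example of the truncation rule above, hedged because the tail of infer_seqlen is elided in this diff: when the target is short relative to cutoff_len, the source absorbs the truncation, and vice versa; otherwise both sides are scaled. The final clamping step below is an assumed reconstruction.

# Hedged reconstruction of the short-target branch of infer_seqlen.
cutoff_len = 8
source_len, target_len = 10, 3                                  # target_len * 2 < cutoff_len
new_target_len = min(target_len, cutoff_len)                    # 3: keep the whole target
new_source_len = min(source_len, cutoff_len - new_target_len)   # 5: truncate the source
assert new_source_len + new_target_len <= cutoff_len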
......@@ -14,7 +14,7 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
......@@ -32,14 +32,14 @@ logger = logging.get_logger(__name__)
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: list[dict[str, str]],
response: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
......@@ -85,7 +85,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
......@@ -114,7 +114,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
return model_inputs
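A hedged illustration of the multiturn masking described in the comment above: within each prompt-response pair only the prompt tokens are replaced by IGNORE_INDEX, so the loss is computed on every response turn. The ids below are made up.

IGNORE_INDEX = -100
turn1_prompt, turn1_resp = [1, 10], [20, 2]      # `<bos> X1`, `Y1`
turn2_prompt, turn2_resp = [11, 12], [21, 2]     # `X2`, `Y2 <eos>`

input_ids = turn1_prompt + turn1_resp + turn2_prompt + turn2_resp
labels = (
    [IGNORE_INDEX] * len(turn1_prompt) + turn1_resp
    + [IGNORE_INDEX] * len(turn2_prompt) + turn2_resp
)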
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
......@@ -124,7 +124,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
......@@ -165,7 +165,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos, packed_audios = [], [], []
packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
......@@ -175,6 +175,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_audios += batch_audios[index]
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
packed_position_ids += list(range(len(batch_input_ids[index])))
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
......@@ -184,6 +185,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length
packed_position_ids += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
......@@ -196,5 +198,6 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None)
model_inputs["position_ids"].append(packed_position_ids or None)
return model_inputs
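To make the neat_packing branch above concrete, here is a small illustration with made-up sequence lengths: each packed segment gets its own attention-mask id starting from 1 and position ids that restart at 0, while padding positions get mask 0.

segment_lengths = [3, 2]       # lengths of two packed examples (made up)
pad_length = 2                 # padding up to the cutoff length

packed_attention_masks, packed_position_ids = [], []
for i, seg_len in enumerate(segment_lengths):
    packed_attention_masks += [i + 1] * seg_len      # segment ids start from 1
    packed_position_ids += list(range(seg_len))      # positions restart per segment

packed_attention_masks += [0] * pad_length           # padding is masked out
packed_position_ids += [0] * pad_length

print(packed_attention_masks)  # [1, 1, 1, 2, 2, 0, 0]
print(packed_position_ids)     # [0, 1, 2, 0, 1, 0, 0]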
......@@ -13,7 +13,7 @@
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ..data_utils import Role
......@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: list[dict[str, str]],
response: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> tuple[list[int], list[int]]:
if len(response) == 1:
messages = prompt + response
else:
......@@ -56,7 +56,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
labels = labels[:target_len]
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
......@@ -84,7 +84,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
......
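A hedged sketch of the unsupervised encoding above: the prompt becomes the model input (`<bos> X`) and the expected completion is kept only as labels (`Y <eos>`), with no masking needed. Token ids are made up.

prompt_ids = [1, 10, 11]      # `<bos> X` -> becomes input_ids
response_ids = [20, 21, 2]    # `Y <eos>` -> becomes labels
input_ids, labels = prompt_ids, response_ids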