Commit ca625f43 authored by shihm

update data

parent 7164651d
......@@ -24,9 +24,6 @@ from typing import TYPE_CHECKING, Any, Optional
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
if TYPE_CHECKING:
......@@ -49,12 +46,41 @@ class ChatModel:
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
if model_args.infer_backend == EngineName.HF:
from .hf_engine import HuggingfaceEngine
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == EngineName.VLLM:
try:
from .vllm_engine import VllmEngine
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
except ImportError as e:
raise ImportError(
    "vLLM is not installed. You may need to run `pip install vllm`,\n"
    "or try using the HuggingFace backend: --infer_backend huggingface"
) from e
elif model_args.infer_backend == EngineName.SGLANG:
try:
from .sglang_engine import SGLangEngine
self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
except ImportError as e:
raise ImportError(
    "SGLang is not installed. You may need to run `pip install sglang[all]`,\n"
    "or try using the HuggingFace backend: --infer_backend huggingface"
) from e
elif model_args.infer_backend == EngineName.KT:
try:
from .kt_engine import KTransformersEngine
self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
except ImportError as e:
raise ImportError(
    "KTransformers is not installed. You may need to run `pip install ktransformers`,\n"
    "or try using the HuggingFace backend: --infer_backend huggingface"
) from e
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
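# Hedged usage sketch (not part of this commit): with the lazy imports above, a
# missing backend package only raises when that backend is actually selected.
# The argument values below are illustrative placeholders.
#
#   chat_model = ChatModel({"model_name_or_path": "path/to/model", "infer_backend": "huggingface"})
#   print(chat_model.engine.name)  # EngineName.HF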
......
......@@ -14,9 +14,9 @@
import asyncio
import os
from collections.abc import AsyncGenerator, Callable
from threading import Thread
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer
......
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import platform
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Optional
import torch
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import EngineName
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from trl import PreTrainedModelWrapper
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.server.config.config import Config
from ktransformers.util.utils import (
get_compute_capability,
prefill_and_generate_capture,
)
from ktransformers.util.vendors import GPUVendor, device_manager
logger = logging.get_logger(__name__)
class KTransformersEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.KT
self.can_generate = finetuning_args.stage == "sft"
tok_mod = load_tokenizer(model_args)
self.tokenizer = tok_mod["tokenizer"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.generating_args = generating_args.to_dict()
self.max_new_tokens = model_args.kt_maxlen
self.use_cuda_graph = model_args.kt_use_cuda_graph
self.mode = model_args.kt_mode
self.force_think = model_args.kt_force_think
self.chunk_size = model_args.chunk_size
try:
asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@staticmethod
@torch.inference_mode()
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: list[str],
input_kwargs: Optional[dict[str, Any]] = {},
) -> list[float]:
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs = tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=False,
).to(device)
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
async def _generate(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
paired = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
prompt_len = len(prompt_ids)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
if "max_new_tokens" in self.generating_args:
max_tokens = int(self.generating_args["max_new_tokens"])
elif "max_length" in self.generating_args:
gl = int(self.generating_args["max_length"])
max_tokens = gl - prompt_len if gl > prompt_len else 1
else:
max_tokens = self.max_new_tokens or 256
if max_length is not None:
max_tokens = max(max_length - prompt_len, 1)
if max_new_tokens is not None:
max_tokens = int(max_new_tokens)
max_tokens = max(1, int(max_tokens))
if self.mode == "long_context":
max_len_cfg = Config().long_context_config["max_seq_len"]
need = prompt_len + max_tokens
assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"
device = next(self.model.parameters()).device
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
if self.force_think:
think = torch.tensor(
[self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
)
input_tensor = torch.cat([input_tensor, think], dim=1)
use_flashinfer = (
platform.system() != "Windows"
and getattr(self.model.config, "architectures", [""])[0]
in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
and flashinfer_enabled
and get_compute_capability() >= 8
and device_manager.gpu_vendor == GPUVendor.NVIDIA
)
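# FlashInfer MLA is only attempted on non-Windows NVIDIA GPUs with compute
# capability >= 8 running DeepSeek-V2/V3 architectures; every other setup
# falls back to the default prefill-and-generate path below.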
def make_gen():
if use_flashinfer:
return prefill_and_generate_capture(
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
use_flashinfer_mla=True,
num_heads=self.model.config.num_attention_heads,
head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
+ getattr(self.model.config, "qk_nope_head_dim", 0),
echo_stream=False,
)
else:
return prefill_and_generate_capture(
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
echo_stream=False,
)
loop = asyncio.get_running_loop()
q: asyncio.Queue[Optional[str]] = asyncio.Queue()
def producer():
try:
gen = make_gen()
if hasattr(gen, "__aiter__"):
async def drain_async():
async for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
asyncio.run(drain_async())
elif hasattr(gen, "__iter__"):
for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
else:
loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
finally:
loop.call_soon_threadsafe(q.put_nowait, None)
Thread(target=producer, daemon=True).start()
while True:
item = await q.get()
if item is None:
break
yield item
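# Note: the producer thread bridges a sync or async token generator into this
# coroutine; tokens are handed over via loop.call_soon_threadsafe and a None
# sentinel marks completion.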
@override
async def chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> list["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
async with self.semaphore:
produced = ""
final_text = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
    # tolerate both cumulative and delta-style generator outputs, as stream_chat does
    delta = t[len(produced) :] if t.startswith(produced) else t
    produced += delta
    if delta:
        final_text += delta
prompt_ids, _ = self.template.encode_oneturn(
self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
)
return [
Response(
response_text=final_text,
response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
prompt_length=len(prompt_ids),
finish_reason="stop",
)
]
@override
async def stream_chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
async with self.semaphore:
produced = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
delta = t[len(produced) :] if t.startswith(produced) else t
produced = t
if delta:
yield delta
@override
async def get_scores(
self,
batch_input: list[str],
**input_kwargs,
) -> list[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore:
return await asyncio.to_thread(self._get_scores, *args)
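# Hedged usage sketch (argument names are illustrative, not part of this commit):
#
#   engine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
#   async for delta in engine.stream_chat([{"role": "user", "content": "Hello"}]):
#       print(delta, end="", flush=True)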
......@@ -16,6 +16,7 @@ import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any, Optional, Union
from packaging import version
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
......@@ -77,11 +78,18 @@ class VllmEngine(BaseEngine):
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
"disable_log_stats": True,
"enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
import vllm
if version.parse(vllm.__version__) <= version.parse("0.10.0"):
engine_args["disable_log_requests"] = True
else:
engine_args["enable_log_requests"] = False
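# Sketch of the packaging.version comparison used above:
#   version.parse("0.9.2") <= version.parse("0.10.0")   # True  -> disable_log_requests
#   version.parse("0.10.1") <= version.parse("0.10.0")  # False -> enable_log_requests=False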
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
......
......@@ -12,145 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
def main():
from .extras.misc import is_env_enabled
if is_env_enabled("USE_V1"):
from .v1 import launcher
else:
from . import launcher
launcher.launch()
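# Hedged usage note: setting USE_V1=1 in the environment routes the CLI through
# the new v1 launcher; otherwise the legacy launcher handles dispatch as before.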
if __name__ == "__main__":
......
......@@ -194,7 +194,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None: # FIXME: need to get video image lengths
......@@ -205,16 +205,25 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
dim=-1
).unsqueeze(-1)
else: # for qwen vl
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
if (
self.model is not None
and getattr(self.model.config, "model_type", None)
in [
"glm4v",
"Keye",
"qwen2_vl",
"qwen2_5_vl",
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
"qwen3_vl",
"qwen3_vl_moe",
]
and ("position_ids" not in features or features["position_ids"].dim() != 3)
):
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
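# For mrope, position_ids are expected to have shape (3, batch_size, seq_len),
# one plane each for the temporal, height and width components, hence the
# dim() != 3 check above.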
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
......
......@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Union
from ..extras import logging
from .data_utils import Role
......@@ -40,7 +40,7 @@ class DatasetConverter:
dataset_attr: "DatasetAttr"
data_args: "DataArguments"
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
r"""Optionally concatenate media path to media dir when loading from local disk."""
if medias is None:
return None
......@@ -227,9 +227,150 @@ class SharegptDatasetConverter(DatasetConverter):
return output
@dataclass
class OpenAIDatasetConverter(DatasetConverter):
def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
self.dataset_attr.observation_tag: Role.OBSERVATION.value,
self.dataset_attr.function_tag: Role.FUNCTION.value,
self.dataset_attr.system_tag: Role.SYSTEM.value,
}
messages = example[self.dataset_attr.messages]
if (
self.dataset_attr.system_tag
and len(messages) != 0
and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
):
system = messages[0][self.dataset_attr.content_tag]
messages = messages[1:]
else:
system = example.get(self.dataset_attr.system, "") if self.dataset_attr.system else ""
aligned_messages = []
tool_responses = []
broken_data = False
for turn_idx, message in enumerate(messages):
role = message[self.dataset_attr.role_tag]
content = message[self.dataset_attr.content_tag]
if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
if "tool_calls" in message and len(message["tool_calls"]) > 0:
tool_calls_list = [tool["function"] for tool in message["tool_calls"]]
content = json.dumps(tool_calls_list, ensure_ascii=False)
role = self.dataset_attr.function_tag
if role == self.dataset_attr.observation_tag:
tool_responses.append(content)
continue
elif len(tool_responses) > 0:
_content = "\n</tool_response>\n<tool_response>\n".join(tool_responses)
aligned_messages.append(
{
"role": Role.OBSERVATION.value,
"content": _content,
}
)
tool_responses = []
aligned_messages.append(
{
"role": tag_mapping[role],
"content": content,
}
)
odd_tags = (Role.USER.value, Role.OBSERVATION.value)
even_tags = (Role.ASSISTANT.value, Role.FUNCTION.value)
accept_tags = (odd_tags, even_tags)
for turn_idx, message in enumerate(aligned_messages):
if message["role"] not in accept_tags[turn_idx % 2]:
logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True
break
if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True
if broken_data:
logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], []
elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if example[self.dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
self.dataset_attr.ranking
and isinstance(example[self.dataset_attr.chosen], dict)
and isinstance(example[self.dataset_attr.rejected], dict)
): # pairwise example
chosen = example[self.dataset_attr.chosen]
rejected = example[self.dataset_attr.rejected]
if (
chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True
prompt = aligned_messages
response = [
{
"role": tag_mapping[chosen[self.dataset_attr.role_tag]],
"content": chosen[self.dataset_attr.content_tag],
},
{
"role": tag_mapping[rejected[self.dataset_attr.role_tag]],
"content": rejected[self.dataset_attr.content_tag],
},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
tools = example.get(self.dataset_attr.tools, "") if self.dataset_attr.tools else ""
if isinstance(tools, (dict, list)):
tools = json.dumps(tools, ensure_ascii=False)
short_system_prompt = "detailed thinking off"
if not system:
    if not tools:
        system = short_system_prompt
elif not tools:
    if "detailed thinking on" not in system and "detailed thinking off" not in system:
        system += "\n" + short_system_prompt
else:
    system += "\n"
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": tools,
"_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
"_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
"_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
}
return output
DATASET_CONVERTERS = {
"alpaca": AlpacaDatasetConverter,
"sharegpt": SharegptDatasetConverter,
"openai": OpenAIDatasetConverter,
}
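# Hedged example of an "openai"-format record the converter above accepts
# (field names follow the default tags; values are illustrative):
#
#   {
#       "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "What time is it in UTC?"},
#           {"role": "assistant", "content": "", "tool_calls": [
#               {"function": {"name": "get_time", "arguments": {"tz": "UTC"}}}
#           ]},
#       ]
#   }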
......
......@@ -81,41 +81,48 @@ def split_dataset(
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
data_args: "DataArguments",
seed: int,
) -> tuple[dict, dict]:
r"""Split the dataset and return two dicts containing the train set and the validation set.
Support both map dataset and iterable dataset.
Returns:
train_dict: Dictionary containing training data with key "train"
eval_dict: Dictionary containing evaluation data with keys "validation" or "validation_{name}"
"""
if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
# return the train and eval splits as separate dicts so callers can handle them clearly
train_dict, eval_dict = {}, {}
if dataset is not None:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
if data_args.val_size > 1e-6:
if data_args.streaming:
eval_dict["validation"] = dataset.take(int(data_args.val_size))
train_dict["train"] = dataset.skip(int(data_args.val_size))
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
split_result = dataset.train_test_split(test_size=val_size, seed=seed)
train_dict["train"] = split_result["train"]
eval_dict["validation"] = split_result["test"]
else:
train_dict["train"] = dataset
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
for name, data in eval_dataset.items():
eval_dict[f"validation_{name}"] = data
else:
if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
eval_dict["validation"] = eval_dataset
return train_dict, eval_dict
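# Hedged usage sketch: callers merge the two dicts back into a DatasetDict,
# mirroring what get_dataset does:
#
#   train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=42)
#   dataset_dict = DatasetDict({**train_dict, **eval_dict})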
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
......
......@@ -16,7 +16,6 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional, Union
from typing_extensions import override
......@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: str | None = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
r"""Forms a list of slots according to the inputs to encode."""
...
def extract(self, content: str) -> str | list["FunctionCall"]:
r"""Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments.
......@@ -97,28 +96,46 @@ class FunctionFormatter(StringFormatter):
@override
def apply(self, **kwargs) -> SLOTS:
content: str = kwargs.pop("content")
thought_words = kwargs.pop("thought_words", None)
tool_call_words = kwargs.pop("tool_call_words", None)
def _parse_functions(json_content: str) -> list["FunctionCall"]:
try:
tool_calls = json.loads(json_content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
return [FunctionCall(tc["name"], json.dumps(tc["arguments"], ensure_ascii=False)) for tc in tool_calls]
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")
tool_call_match = None
if tool_call_words and len(tool_call_words) == 2:
tool_call_regex = re.compile(
rf"{re.escape(tool_call_words[0])}(.*?){re.escape(tool_call_words[1])}", re.DOTALL
)
tool_call_match = re.search(tool_call_regex, content)
if tool_call_match is None:
thought_match = None
if thought_words and len(thought_words) == 2:
regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
thought_match = re.search(regex, content)
if thought_match:
json_part = content.replace(thought_match.group(0), "")
else:
json_part = content
functions = _parse_functions(json_part)
function_str = self.tool_utils.function_formatter(functions)
if thought_match:
function_str = thought_match.group(0) + function_str
else:
thought_content = content.replace(tool_call_match.group(0), "")
functions = _parse_functions(tool_call_match.group(1))
function_str = self.tool_utils.function_formatter(functions)
function_str = thought_content + function_str
return super().apply(content=function_str)
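# Hedged examples of the two content shapes handled above, assuming
# thought_words = ("<think>", "</think>") and tool_call_words = ("<tool_call>", "</tool_call>"):
#
#   '<think>plan</think>[{"name": "get_time", "arguments": {"tz": "UTC"}}]'
#   'plan<tool_call>[{"name": "get_time", "arguments": {"tz": "UTC"}}]</tool_call>'
#
# In both cases the JSON is parsed into FunctionCall objects and the surrounding
# thought text is re-attached in front of the formatted function string.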
......@@ -138,5 +155,5 @@ class ToolFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> str | list["FunctionCall"]:
return self.tool_utils.tool_extractor(content)
......@@ -16,7 +16,7 @@ import os
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
......@@ -137,7 +137,6 @@ def _load_single_dataset(
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
num_proc=data_args.preprocessing_num_workers,
streaming=data_args.streaming and dataset_attr.load_from != "file",
)
if data_args.streaming and dataset_attr.load_from == "file":
......@@ -163,13 +162,13 @@ def _load_single_dataset(
def _get_merged_dataset(
dataset_names: list[str] | None,
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
return_dict: bool = False,
) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
r"""Return the merged datasets in the standard format."""
if dataset_names is None:
return None
......@@ -228,7 +227,7 @@ def _get_dataset_processor(
def _get_preprocessed_dataset(
dataset: Union["Dataset", "IterableDataset"] | None,
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
......@@ -236,7 +235,7 @@ def _get_preprocessed_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False,
) -> Union["Dataset", "IterableDataset"] | None:
r"""Preprocesses the dataset, including format checking and tokenization."""
if dataset is None:
return None
......@@ -312,20 +311,22 @@ def get_dataset(
)
with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
# split first so that eval_dataset (whether passed in or split from train) is preprocessed appropriately
train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
if "train" in train_dict:
train_dict["train"] = _get_preprocessed_dataset(
train_dict["train"], data_args, training_args, stage, template, tokenizer, processor, is_eval=False
)
for key in eval_dict:
eval_dict[key] = _get_preprocessed_dataset(
eval_dict[key], data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
# Combine train and eval dictionaries
dataset_dict = DatasetDict({**train_dict, **eval_dict})
if data_args.tokenized_path is not None: # save tokenized dataset to disk
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
......
......@@ -22,10 +22,11 @@ import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
import numpy as np
import torch
import torchaudio
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense,
......@@ -34,16 +35,7 @@ from transformers.models.mllama.processing_mllama import (
from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
if is_pillow_available():
......@@ -68,15 +60,28 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.video_processing_utils import BaseVideoProcessor
class EncodedImage(TypedDict):
path: str | None
bytes: bytes | None
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
AudioInput = Union[str, BinaryIO, NDArray]
class RegularizedImageOutput(TypedDict):
images: list[ImageObject]
class RegularizedVideoOutput(TypedDict):
videos: list[list[ImageObject]]
durations: list[float]
fps_per_video: NotRequired[list[float]]
class RegularizedAudioOutput(TypedDict):
audios: list[NDArray]
sampling_rates: list[float]
class MMProcessor(ProcessorMixin):
patch_size: int
image_seq_length: int
......@@ -134,14 +139,14 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis
def _check_video_is_nested_images(video: "VideoInput") -> bool:
r"""Check if the video is nested images."""
return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict, ImageObject)) for frame in video)
@dataclass
class MMPluginMixin:
image_token: str | None
video_token: str | None
audio_token: str | None
expand_mm_tokens: bool = True
def _validate_input(
......@@ -244,7 +249,7 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
r"""Regularize images to avoid errors, including reading and pre-processing."""
results = []
for image in images:
......@@ -265,9 +270,10 @@ class MMPluginMixin:
return {"images": results}
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
r"""Regularize videos to avoid errors, including reading, resizing and converting."""
results = []
durations = []
for video in videos:
frames: list[ImageObject] = []
if _check_video_is_nested_images(video):
......@@ -275,6 +281,7 @@ class MMPluginMixin:
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
raise ValueError("Invalid image found in video frames.")
frames = video
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
......@@ -284,19 +291,31 @@ class MMPluginMixin:
if frame_idx in sample_indices:
frames.append(frame.to_image())
if video_stream.duration is None:
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
durations.append(float(video_stream.duration * video_stream.time_base))
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
return {"videos": results, "durations": durations}
def _regularize_audios(
self, audios: list["AudioInput"], sampling_rate: float, **kwargs
) -> "RegularizedAudioOutput":
r"""Regularize audios to avoid errors, including reading and resampling."""
results, sampling_rates = [], []
for audio in audios:
if not isinstance(audio, np.ndarray):
audio, sr = torchaudio.load(audio)
if audio.shape[0] > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != sampling_rate:
audio = torchaudio.functional.resample(audio, sr, sampling_rate)
audio = audio.squeeze(0).numpy()
results.append(audio)
sampling_rates.append(sampling_rate)
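# e.g. a stereo 44.1 kHz clip is loaded with torchaudio, downmixed to mono,
# resampled to `sampling_rate` (commonly 16 kHz for speech feature extractors),
# and returned as a 1-D numpy array.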
......@@ -309,7 +328,7 @@ class MMPluginMixin:
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
imglens: list[int] | None = None,
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs.
......@@ -407,13 +426,13 @@ class BasePlugin(MMPluginMixin):
def process_token_ids(
self,
input_ids: list[int],
labels: list[int] | None,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"],
) -> tuple[list[int], list[int] | None]:
r"""Pre-process token ids after tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return input_ids, labels
......@@ -446,6 +465,57 @@ class BasePlugin(MMPluginMixin):
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class ErnieVLPlugin(BasePlugin):
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
else:
image_grid_thw = [None] * len(images)
video_grid_thw = [None] * len(videos)
image_idx, video_idx = 0, 0
for message in messages:
content = message["content"]
image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
1,
)
image_idx += 1
while VIDEO_PLACEHOLDER in content:
video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER,
f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
1,
)
video_idx += 1
message["content"] = content
return messages
@dataclass
class Gemma3Plugin(BasePlugin):
@override
......@@ -1235,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
def process_token_ids(
self,
input_ids: list[int],
labels: list[int] | None,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"],
) -> tuple[list[int], list[int] | None]:
self._validate_input(processor, images, videos, audios)
num_images = len(images)
image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
......@@ -1397,6 +1467,9 @@ class Qwen2AudioPlugin(BasePlugin):
@dataclass
class Qwen2VLPlugin(BasePlugin):
vision_bos_token: str = "<|vision_start|>"
vision_eos_token: str = "<|vision_end|>"
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
image = super()._preprocess_image(image, **kwargs)
......@@ -1415,10 +1488,8 @@ class Qwen2VLPlugin(BasePlugin):
return image
@override
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
results, fps_per_video, durations = [], [], []
for video in videos:
frames: list[ImageObject] = []
if _check_video_is_nested_images(video):
......@@ -1428,6 +1499,7 @@ class Qwen2VLPlugin(BasePlugin):
frames = video
fps_per_video.append(kwargs.get("video_fps", 2.0))
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
......@@ -1439,8 +1511,10 @@ class Qwen2VLPlugin(BasePlugin):
if video_stream.duration is None:
fps_per_video.append(kwargs.get("video_fps", 2.0))
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
durations.append(float(video_stream.duration * video_stream.time_base))
if len(frames) % 2 != 0:
frames.append(frames[-1])
......@@ -1448,7 +1522,7 @@ class Qwen2VLPlugin(BasePlugin):
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
@override
def _get_mm_inputs(
......@@ -1459,6 +1533,7 @@ class Qwen2VLPlugin(BasePlugin):
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
......@@ -1476,7 +1551,7 @@ class Qwen2VLPlugin(BasePlugin):
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names:
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
......@@ -1512,15 +1587,142 @@ class Qwen2VLPlugin(BasePlugin):
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1,
)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER,
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
1,
)
num_video_tokens += 1
message["content"] = content
return messages
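# Worked example (illustrative): with expand_mm_tokens=True, merge_size=2 and
# image_grid_thw[i] = (1, 16, 16), image_seqlen = 256 // 4 = 64, so one
# IMAGE_PLACEHOLDER expands to vision_bos + 64 image tokens + vision_eos.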
@dataclass
class Qwen3VLPlugin(Qwen2VLPlugin):
@override
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)["images"]
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_metadata = [
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
for video, duration in zip(videos["videos"], videos["durations"])
]
mm_inputs.update(
video_processor(
videos=videos["videos"],
video_metadata=video_metadata,
fps=getattr(processor, "video_fps", 2.0),
return_metadata=True,
)
)
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names:
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in videos["fps_per_video"]]
return mm_inputs
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
video_processor: BaseImageProcessor = getattr(processor, "video_processor")
image_merge_length: int = getattr(image_processor, "merge_size") ** 2
video_merge_length: int = getattr(video_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
video_metadata = mm_inputs.get("video_metadata", {})
else:
image_grid_thw = [None] * len(images)
video_grid_thw = [None] * len(videos)
num_frames = 0
timestamps = [0]
for idx, message in enumerate(messages):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_seqlen = (
image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
)
content = content.replace(
IMAGE_PLACEHOLDER,
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1,
)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
if self.expand_mm_tokens:
metadata = video_metadata[idx]
timestamps = processor._calculate_timestamps(
metadata.frames_indices,
metadata.fps,
video_processor.merge_size,
)
video_structure = ""
for frame_index in range(num_frames):
video_seqlen = (
video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
if self.expand_mm_tokens
else 1
)
timestamp_sec = timestamps[frame_index]
frame_structure = (
f"<{timestamp_sec:.1f} seconds>"
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
)
video_structure += frame_structure
else:
video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
num_video_tokens += 1
message["content"] = content
......@@ -1559,7 +1761,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
)
# prepare video metadata
video_metadata = [
{"fps": 2, "duration": duration, "total_frames": len(video)}
for video, duration in zip(video_data["videos"], video_data["durations"])
]
mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
......@@ -1630,6 +1833,9 @@ class GLM4VPlugin(Qwen2VLPlugin):
)
video_structure += frame_structure
if not self.expand_mm_tokens:
video_structure = self.video_token
content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1)
num_video_tokens += 1
......@@ -1655,7 +1861,11 @@ class GLM4VPlugin(Qwen2VLPlugin):
return mm_inputs
@dataclass
class Qwen2OmniPlugin(Qwen2VLPlugin):
audio_bos_token: str = "<|audio_start|>"
audio_eos_token: str = "<|audio_end|>"
@override
def _get_mm_inputs(
self,
......@@ -1665,6 +1875,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
......@@ -1683,7 +1894,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
mm_inputs["video_second_per_grid"] = torch.tensor(
[temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
......@@ -1729,8 +1940,14 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
if "feature_attention_mask" in mm_inputs:
if processor.__class__.__name__ == "Qwen3OmniMoeProcessor": # for qwen3omni
input_lengths = mm_inputs["feature_attention_mask"].sum(-1)
input_lengths_leave = input_lengths % 100
feature_lengths = (input_lengths_leave - 1) // 2 + 1
audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
else:
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
audio_lengths = (input_lengths - 2) // 2 + 1
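# Worked example (illustrative) for the Qwen3-Omni branch above, with
# input_lengths = 250:
#   input_lengths_leave = 250 % 100 = 50
#   feature_lengths = (50 - 1) // 2 + 1 = 25
#   audio_lengths = ((25 - 1) // 2 + 1 - 1) // 2 + 1 + (250 // 100) * 13 = 7 + 26 = 33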
else:
mm_inputs = {}
image_grid_thw = [None] * len(images)
......@@ -1742,7 +1959,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1,
)
num_image_tokens += 1
......@@ -1779,7 +1998,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""
placeholder_string += self.vision_bos_token + self.audio_bos_token
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
......@@ -1789,7 +2008,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if audio_chunk_index is not None:
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
placeholder_string += self.audio_eos_token + self.vision_eos_token
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
num_audio_tokens += 1
......@@ -1798,7 +2017,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
while AUDIO_PLACEHOLDER in content:
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
content = content.replace(
AUDIO_PLACEHOLDER,
f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}",
1,
)
num_audio_tokens += 1
......@@ -1807,7 +2028,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
)
content = content.replace(
VIDEO_PLACEHOLDER,
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
1,
)
num_video_tokens += 1
......@@ -1871,6 +2094,7 @@ class VideoLlavaPlugin(BasePlugin):
PLUGINS = {
"base": BasePlugin,
"ernie_vl": ErnieVLPlugin,
"gemma3": Gemma3Plugin,
"glm4v": GLM4VPlugin,
"gemma3n": Gemma3nPlugin,
......@@ -1887,6 +2111,7 @@ PLUGINS = {
"qwen2_audio": Qwen2AudioPlugin,
"qwen2_omni": Qwen2OmniPlugin,
"qwen2_vl": Qwen2VLPlugin,
"qwen3_vl": Qwen3VLPlugin,
"video_llava": VideoLlavaPlugin,
}
......@@ -1901,12 +2126,13 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
def get_mm_plugin(
name: str,
image_token: str | None = None,
video_token: str | None = None,
audio_token: str | None = None,
**kwargs,
) -> "BasePlugin":
r"""Get plugin for multimodal inputs."""
if name not in PLUGINS:
raise ValueError(f"Multimodal plugin `{name}` not found.")
return PLUGINS[name](image_token, video_token, audio_token, **kwargs)
......@@ -15,7 +15,7 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Literal
from huggingface_hub import hf_hub_download
......@@ -30,43 +30,43 @@ class DatasetAttr:
# basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
dataset_name: str
formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
ranking: bool = False
# extra configs
subset: str | None = None
split: str = "train"
folder: str | None = None
num_samples: int | None = None
# common columns
system: str | None = None
tools: str | None = None
images: str | None = None
videos: str | None = None
audios: str | None = None
# dpo columns
chosen: str | None = None
rejected: str | None = None
kto_tag: str | None = None
# alpaca columns
prompt: str | None = "instruction"
query: str | None = "input"
response: str | None = "output"
history: str | None = None
# sharegpt columns
messages: str | None = "conversations"
# sharegpt tags
role_tag: str | None = "from"
content_tag: str | None = "value"
user_tag: str | None = "human"
assistant_tag: str | None = "gpt"
observation_tag: str | None = "observation"
function_tag: str | None = "function_call"
system_tag: str | None = "system"
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
setattr(self, key, obj.get(key, default))
def join(self, attr: dict[str, Any]) -> None:
......@@ -90,12 +90,14 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"])
def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets."""
if dataset_names is None:
dataset_names = []
if dataset_dir == "ONLINE":
if isinstance(dataset_dir, dict):
dataset_info = dataset_dir
elif dataset_dir == "ONLINE":
dataset_info = None
else:
if dataset_dir.startswith("REMOTE:"):
......
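# Sketch of the new dict form of `dataset_dir` (hypothetical dataset entry): a dict
# is treated as an inline dataset_info.json, skipping the on-disk file lookup.
#   >>> inline_info = {"demo": {"file_name": "demo.json", "formatting": "sharegpt"}}
#   >>> get_dataset_list(["demo"], dataset_dir=inline_info)
#   [demo]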
......@@ -62,7 +62,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
if self.data_args.train_on_prompt:
source_label = source_ids
elif self.template.efficient_eos:
elif self.template.efficient_eos and turn_idx != 0:
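                # with efficient_eos, each turn after the first begins with the previous
                # turn's eos token, so keep its label; the first turn has no such token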
source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
......
......@@ -49,6 +49,7 @@ class Template:
default_system: str
stop_words: list[str]
thought_words: tuple[str, str]
tool_call_words: tuple[str, str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
......@@ -96,7 +97,7 @@ class Template:
def add_thought(self, content: str = "") -> str:
r"""Add empty thought to assistant message."""
return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
return f"{self.thought_words[0]}{self.thought_words[1]}" + content
def remove_thought(self, content: str) -> str:
r"""Remove thought from assistant message."""
......@@ -156,7 +157,9 @@ class Template:
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=message["content"])
elements += self.format_function.apply(
content=message["content"], thought_words=self.thought_words, tool_call_words=self.tool_call_words
)
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
......@@ -199,9 +202,12 @@ class Template:
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
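            # older `transformers` releases do not accept the
            # `replace_additional_special_tokens` kwarg, hence the fallback below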
try:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
except TypeError:
num_added_tokens = tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
......@@ -416,8 +422,8 @@ class ReasoningTemplate(Template):
prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
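        # thought_words now carry surrounding newlines, so compare the stripped tags
        # when checking whether the last message already holds a thought block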
if (
self.thought_words[0] not in messages[-1]["content"]
and self.thought_words[1] not in messages[-1]["content"]
self.thought_words[0].strip() not in messages[-1]["content"]
and self.thought_words[1].strip() not in messages[-1]["content"]
): # add empty cot
if not self.enable_thinking: # do not compute loss
prompt_ids += self.get_thought_word_ids(tokenizer)
......@@ -442,8 +448,8 @@ class ReasoningTemplate(Template):
encoded_messages = self._encode(tokenizer, messages, system, tools)
for i in range(0, len(messages), 2):
if (
self.thought_words[0] not in messages[i + 1]["content"]
and self.thought_words[1] not in messages[i + 1]["content"]
self.thought_words[0].strip() not in messages[i + 1]["content"]
and self.thought_words[1].strip() not in messages[i + 1]["content"]
): # add empty cot
if not self.enable_thinking: # do not compute loss
encoded_messages[i] += self.get_thought_word_ids(tokenizer)
......@@ -468,6 +474,7 @@ def register_template(
default_system: str = "",
stop_words: Optional[list[str]] = None,
thought_words: Optional[tuple[str, str]] = None,
tool_call_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
......@@ -518,7 +525,8 @@ def register_template(
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words or [],
thought_words=thought_words or ("<think>", "</think>"),
thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
tool_call_words=tool_call_words or ("<tool_call>", "</tool_call>"),
efficient_eos=efficient_eos,
replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
......@@ -579,7 +587,8 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
default_system=default_system,
stop_words=[],
thought_words=("<think>", "</think>"),
thought_words=("<think>\n", "\n</think>\n\n"),
tool_call_words=("<tool_call>", "</tool_call>"),
efficient_eos=False,
replace_eos=False,
replace_jinja_template=False,
......@@ -616,7 +625,14 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
logger.info_rank0(f"Using default system message: {data_args.default_system}.")
template.default_system = data_args.default_system
template.enable_thinking = data_args.enable_thinking
if isinstance(template, ReasoningTemplate):
logger.warning_rank0(
"You are using reasoning template, "
"please add `_nothink` suffix if the model is not a reasoning model. "
"e.g., qwen3_vl_nothink"
)
template.enable_thinking = data_args.enable_thinking
template.fix_special_tokens(tokenizer)
template.fix_jinja_template(tokenizer)
return template
......@@ -679,6 +695,23 @@ register_template(
)
register_template(
name="bailing_v2",
format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<|role_end|><role>ASSISTANT</role>"]),
format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}<|role_end|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]),
format_observation=StringFormatter(
slots=[
"<role>OBSERVATION</role>\n<tool_response>\n{{content}}\n</tool_response><|role_end|><role>ASSISTANT</role>"
]
),
format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"),
format_tools=ToolFormatter(tool_format="ling"),
stop_words=["<|endoftext|>"],
efficient_eos=True,
)
register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
......@@ -894,12 +927,64 @@ register_template(
)
register_template(
name="dots_ocr",
format_user=StringFormatter(slots=["<|user|>{{content}}<|endofuser|><|assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|endofassistant|>"]),
format_system=StringFormatter(slots=["<|system|>{{content}}<|endofsystem|>\n"]),
stop_words=["<|endofassistant|>"],
efficient_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_vl",
image_token="<|imgpad|>",
video_token="<|vidpad|>",
vision_bos_token="<|img|>",
vision_eos_token="<|endofimg|>",
),
)
register_template(
name="empty",
format_assistant=StringFormatter(slots=["{{content}}"]),
)
# copied from chatml template
register_template(
name="ernie",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]),
default_system="<global_setting>\nthink_mode=True\n</global_setting>",
stop_words=["<|im_end|>"],
)
register_template(
name="ernie_nothink",
format_user=StringFormatter(slots=["User: {{content}}\nAssistant: "]),
format_assistant=StringFormatter(slots=["{{content}}<|end_of_sentence|>"]),
format_system=StringFormatter(slots=["{{content}}\n"]),
format_prefix=EmptyFormatter(slots=["<|begin_of_sentence|>"]),
stop_words=["<|end_of_sentence|>"],
)
register_template(
name="ernie_vl",
format_user=StringFormatter(slots=["User: {{content}}"]),
format_assistant=StringFormatter(slots=["\nAssistant: {{content}}<|end_of_sentence|>"]),
format_system=StringFormatter(slots=["{{content}}\n"]),
stop_words=["<|end_of_sentence|>"],
replace_eos=True,
replace_jinja_template=True,
template_class=ReasoningTemplate,
mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
)
register_template(
name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
......@@ -1014,6 +1099,22 @@ register_template(
)
# copied from glm4 template
register_template(
name="glm4_moe",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4_moe"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
template_class=ReasoningTemplate,
)
# copied from glm4 template
register_template(
name="glm4v",
......@@ -1031,6 +1132,23 @@ register_template(
)
# copied from glm4 template
register_template(
name="glm4_5v",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4_moe"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>", "</answer>"],
efficient_eos=True,
mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
template_class=ReasoningTemplate,
)
# copied from glm4 template
register_template(
name="glmz1",
......@@ -1047,6 +1165,18 @@ register_template(
)
register_template(
name="gpt_oss",
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
default_system="You are ChatGPT, a large language model trained by OpenAI.",
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
efficient_eos=True,
template_class=ReasoningTemplate,
)
register_template(
name="granite3",
format_user=StringFormatter(
......@@ -1071,6 +1201,25 @@ register_template(
)
register_template(
name="granite4",
format_user=StringFormatter(
slots=[
"<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
]
),
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|end_of_text|>\n"], tool_format="default"),
format_observation=StringFormatter(
slots=["<|start_of_role|>tool<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="default"),
stop_words=["<|end_of_text|>"],
default_system="You are Granite, developed by IBM. You are a helpful AI assistant.",
)
register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
......@@ -1081,10 +1230,10 @@ register_template(
register_template(
name="hunyuan",
format_user=StringFormatter(slots=["<|bos|>user\n{{content}}<|eos|>\n<|bos|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|eos|>\n"]),
format_system=StringFormatter(slots=["<|bos|>system\n{{content}}<|eos|>\n"]),
format_prefix=EmptyFormatter(slots=["<|bos|>"]),
format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|eos|>"]),
format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]),
format_prefix=EmptyFormatter(slots=["<|startoftext|>"]),
stop_words=["<|eos|>"],
)
......@@ -1137,6 +1286,35 @@ register_template(
)
register_template(
name="intern_s1",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
)
# copied from qwen template
register_template(
name="keye_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
template_class=ReasoningTemplate,
)
register_template(
name="kimi_vl",
format_user=StringFormatter(
......@@ -1432,6 +1610,26 @@ register_template(
template_class=ReasoningTemplate,
)
# copied from qwen template
register_template(
name="mimo_v2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are MiMo, a helpful AI assistant engineered by Xiaomi.",
stop_words=["<|im_end|>"],
replace_eos=True,
thought_words=("<think>", "</think>"),
template_class=ReasoningTemplate,
)
# copied from qwen2vl
register_template(
name="mimo_vl",
......@@ -1470,11 +1668,48 @@ register_template(
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
default_system="You are a helpful assistant. You can accept audio and text input and output voice and text.",
mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
)
register_template(
name="minimax1",
format_user=StringFormatter(
slots=[
"<beginning_of_sentence>user name=user\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
]
),
format_assistant=StringFormatter(slots=["{{content}}<end_of_sentence>\n"]),
format_system=StringFormatter(
slots=["<beginning_of_sentence>system ai_setting=assistant\n{{content}}<end_of_sentence>\n"]
),
format_function=FunctionFormatter(slots=["{{content}}<end_of_sentence>\n"], tool_format="minimax1"),
format_observation=StringFormatter(
slots=[
"<beginning_of_sentence>tool name=tools\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
]
),
format_tools=ToolFormatter(tool_format="minimax1"),
default_system="You are a helpful assistant.",
stop_words=["<end_of_sentence>"],
)
register_template(
name="minimax2",
format_user=StringFormatter(slots=["]~b]user\n{{content}}[e~[\n]~b]ai\n"]),
format_assistant=StringFormatter(slots=["{{content}}[e~[\n"]),
format_system=StringFormatter(slots=["]~!b[]~b]system\n{{content}}[e~[\n"]),
format_function=FunctionFormatter(slots=["{{content}}[e~[\n"], tool_format="minimax2"),
format_observation=StringFormatter(slots=["]~b]tool\n<response>{{content}}</response>[e~[\n]~b]ai\n"]),
format_tools=ToolFormatter(tool_format="minimax2"),
default_system="You are a helpful assistant. Your name is MiniMax-M2.1 and is built by MiniMax.",
stop_words=["[e~["],
template_class=ReasoningTemplate,
)
# mistral tokenizer v3 tekken
register_template(
name="ministral",
......@@ -1515,6 +1750,19 @@ register_template(
)
register_template(
name="ministral3",
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=Llama2Template,
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
......@@ -1669,6 +1917,22 @@ register_template(
)
# copied from qwen template
register_template(
name="qwen3_nothink",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
replace_eos=True,
)
# copied from chatml template
register_template(
name="qwen2_audio",
......@@ -1697,10 +1961,55 @@ register_template(
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
name="qwen2_omni",
image_token="<|IMAGE|>",
video_token="<|VIDEO|>",
audio_token="<|AUDIO|>",
vision_bos_token="<|vision_bos|>",
vision_eos_token="<|vision_eos|>",
audio_bos_token="<|audio_bos|>",
audio_eos_token="<|audio_eos|>",
),
)
register_template(
name="qwen3_omni",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
),
template_class=ReasoningTemplate,
)
register_template(
name="qwen3_omni_nothink",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
),
)
# copied from qwen template
register_template(
name="qwen2_vl",
......@@ -1719,6 +2028,41 @@ register_template(
)
# copied from qwen template
register_template(
name="qwen3_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
template_class=ReasoningTemplate,
)
# copied from qwen template
register_template(
name="qwen3_vl_nothink",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
......@@ -1746,6 +2090,20 @@ register_template(
)
# copied from seed_coder
register_template(
name="seed_oss",
format_user=StringFormatter(
slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
),
format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
format_function=FunctionFormatter(slots=[{"bos_token"}, "\n{{content}}", {"eos_token"}], tool_format="seed_oss"),
format_tools=ToolFormatter(tool_format="seed_oss"),
template_class=ReasoningTemplate,
thought_words=("<seed:think>", "</seed:think>"),
)
# copied from llama3 template
register_template(
name="skywork_o1",
......
......@@ -42,6 +42,18 @@ GLM4_TOOL_PROMPT = (
"你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{tool_text}"
)
GLM4_MOE_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
"\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:"
"\n<tool_call>{{function-name}}"
"\n<arg_key>{{arg-key-1}}</arg_key>"
"\n<arg_value>{{arg-value-1}}</arg_value>"
"\n<arg_key>{{arg-key-2}}</arg_key>"
"\n<arg_value>{{arg-value-2}}</arg_value>"
"\n...\n</tool_call>\n"
)
LLAMA3_TOOL_PROMPT = (
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
......@@ -49,6 +61,21 @@ LLAMA3_TOOL_PROMPT = (
"Do not use variables.\n\n{tool_text}"
)
MINIMAX_M1_TOOL_PROMPT = (
"You are provided with these tools:\n<tools>\n{tool_text}</tools>\n\n"
"If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and "
"json-object of arguments, following the format below:\n<tool_calls>\n"
"""{{"name": <tool-name-1>, "arguments": <args-json-object-1>}}\n...\n</tool_calls>"""
)
MINIMAX_M2_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more tools to assist with the user query.\n"
"Here are the tools available in JSONSchema format:\n\n<tools>\n{tool_text}</tools>\n\n"
"When making tool calls, use XML format to invoke tools and pass parameters:\n"
"""\n<minimax:tool_call>\n<invoke name="tool-name-1">\n<parameter name="param-key-1">param-value-1</parameter>\n"""
"""<parameter name="param-key-2">param-value-2</parameter>\n...\n</invoke>\n</minimax:tool_call>"""
)
QWEN_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
......@@ -57,6 +84,23 @@ QWEN_TOOL_PROMPT = (
""""arguments": <args-json-object>}}\n</tool_call>"""
)
SEED_TOOL_PROMPT = (
"system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
"any task, you must decide how to call them based on the descriptions and parameters of these tools.{tool_text}\n"
"工具调用请遵循如下格式:\n<seed:tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>value_1"
"</parameter>\n<parameter=example_parameter_2>This is the value for the second parameter\nthat can span\nmultiple "
"lines</parameter>\n</function>\n</seed:tool_call>\n"
)
LING_TOOL_PROMPT = (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
""""arguments": <args-json-object>}}\n</tool_call>"""
)
@dataclass
class ToolUtils(ABC):
......@@ -224,6 +268,109 @@ class Llama3ToolUtils(ToolUtils):
return content
class MiniMaxM1ToolUtils(ToolUtils):
r"""MiniMax-M1 tool using template."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
tool_text += json.dumps(tool, ensure_ascii=False) + "\n"
return MINIMAX_M1_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for func in functions:
name, arguments = func.name, json.loads(func.arguments)
function_texts.append(json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False))
return "<tool_calls>\n" + "\n".join(function_texts) + "\n</tool_calls>"
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<tool_calls>\s*(.+?)\s*</tool_calls>", re.DOTALL)
tool_match = re.search(regex, content)
if not tool_match:
return content
tool_calls_content = tool_match.group(1)
results = []
for line in tool_calls_content.split("\n"):
line = line.strip()
if not line:
continue
try:
tool_call = json.loads(line)
results.append(FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError:
continue
return results
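# Round-trip sketch (illustrative values):
#   >>> call = FunctionCall("get_weather", json.dumps({"city": "Beijing"}))
#   >>> text = MiniMaxM1ToolUtils.function_formatter([call])
#   >>> MiniMaxM1ToolUtils.tool_extractor(text) == [call]
#   True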
class MiniMaxM2ToolUtils(ToolUtils):
r"""MiniMax-M2 tool using template."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
tool_text += "<tool>" + json.dumps(tool, ensure_ascii=False) + "</tool>\n"
return MINIMAX_M2_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for func in functions:
name, arguments = func.name, json.loads(func.arguments)
prompt = f'<invoke name="{name}">'
for key, value in arguments.items():
prompt += f'\n<parameter name="{key}">'
if not isinstance(value, str):
value = json.dumps(value, ensure_ascii=False)
prompt += value + "</parameter>"
prompt += "\n</invoke>"
            function_texts.append(prompt)

        # join and wrap the invokes as MINIMAX_M2_TOOL_PROMPT specifies
        return "<minimax:tool_call>\n" + "\n".join(function_texts) + "\n</minimax:tool_call>"
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<minimax:tool_call>\s*(.+?)\s*</minimax:tool_call>", re.DOTALL)
tool_match = re.search(regex, content)
if not tool_match:
return content
tool_calls_content = tool_match.group(1)
invoke_regex = re.compile(r"<invoke name=\"(.*?)\">(.*?)</invoke>", re.DOTALL)
results = []
for func_name, params_block in re.findall(invoke_regex, tool_calls_content):
args_dict = {}
param_pattern = re.compile(r"<parameter name=\"(.*?)\">(.*?)</parameter>", re.DOTALL)
for key, raw_value in re.findall(param_pattern, params_block):
value = raw_value.strip()
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
parsed_value = raw_value
args_dict[key] = parsed_value
results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
return results
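# Round-trip sketch (illustrative values):
#   >>> text = MiniMaxM2ToolUtils.function_formatter(
#   ...     [FunctionCall("search", json.dumps({"query": "llama"}))]
#   ... )
#   >>> MiniMaxM2ToolUtils.tool_extractor(text)
#   [FunctionCall(name='search', arguments='{"query": "llama"}')]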
class MistralToolUtils(ToolUtils):
r"""Mistral v0.3 tool using template."""
......@@ -303,12 +450,113 @@ class QwenToolUtils(ToolUtils):
return results
class GLM4MOEToolUtils(QwenToolUtils):
r"""GLM-4-MOE tool using template."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
return GLM4_MOE_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_json = [
{"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions
]
function_texts = []
for func in function_json:
prompt = "\n<tool_call>" + func["func_name"]
for key, value in func["func_key_values"].items():
prompt += "\n<arg_key>" + key + "</arg_key>"
if not isinstance(value, str):
value = json.dumps(value, ensure_ascii=False)
prompt += "\n<arg_value>" + value + "</arg_value>"
            prompt += "\n</tool_call>"  # close each call as GLM4_MOE_TOOL_PROMPT requires
            function_texts.append(prompt)
return "\n".join(function_texts)
class SeedToolUtils(ToolUtils):
r"""Seed tool using template."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
return SEED_TOOL_PROMPT.format(tool_text="\n" + json.dumps(tools, ensure_ascii=False))
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_json = [
{"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions
]
function_texts = []
for func in function_json:
prompt = "\n<seed:tool_call>\n<function=" + func["func_name"]
for key, value in func["func_key_values"].items():
prompt += "\n<parameter=" + key + ">"
if not isinstance(value, str):
value = json.dumps(value, ensure_ascii=False)
prompt += value + "</parameter>"
prompt += "\n</function>\n</seed:tool_call>"
function_texts.append(prompt)
return "\n".join(function_texts)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
results = []
regex = re.compile(
r"<seed:tool_call>\s*<function=\s*([^\s<]+)\s*(.*?)\s*</function>\s*</seed:tool_call>", re.DOTALL
)
for func_name, params_block in re.findall(regex, content):
args_dict = {}
param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)
for key, raw_value in re.findall(param_pattern, params_block.strip()):
value = raw_value.strip()
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
parsed_value = raw_value
args_dict[key] = parsed_value
results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
return results
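# Extraction sketch (illustrative values):
#   >>> text = SeedToolUtils.function_formatter([FunctionCall("ping", json.dumps({"host": "example.com"}))])
#   >>> SeedToolUtils.tool_extractor(text)
#   [FunctionCall(name='ping', arguments='{"host": "example.com"}')]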
class LingToolUtils(QwenToolUtils):
r"""Ling v2 tool using template."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"minimax1": MiniMaxM1ToolUtils(),
"minimax2": MiniMaxM2ToolUtils(),
"mistral": MistralToolUtils(),
"qwen": QwenToolUtils(),
"glm4_moe": GLM4MOEToolUtils(),
"seed_oss": SeedToolUtils(),
"ling": LingToolUtils(),
}
......
......@@ -15,7 +15,6 @@
import os
from collections import OrderedDict, defaultdict
from enum import Enum, unique
from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
......@@ -56,13 +55,27 @@ LAYERNORM_NAMES = {"norm", "ln"}
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
METHODS = ["full", "freeze", "lora"]
MCA_SUPPORTED_MODELS = {
"deepseek_v3",
"llama",
"mistral",
"mixtral",
"qwen2",
"qwen2_vl",
"qwen2_5_vl",
"qwen3_vl",
"qwen3",
"qwen3_moe",
"qwen3_next",
}
METHODS = ["full", "freeze", "lora", "oft"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
MULTIMODAL_SUPPORTED_MODELS = set()
PEFT_METHODS = {"lora"}
PEFT_METHODS = {"lora", "oft"}
RUNNING_LOG = "running_log.txt"
......@@ -101,12 +114,14 @@ class AttentionFunction(str, Enum):
DISABLED = "disabled"
SDPA = "sdpa"
FA2 = "fa2"
FA3 = "fa3"
class EngineName(str, Enum):
HF = "huggingface"
VLLM = "vllm"
SGLANG = "sglang"
KT = "ktransformers"
class DownloadSource(str, Enum):
......@@ -126,6 +141,8 @@ class QuantizationMethod(str, Enum):
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
MXFP4 = "mxfp4"
FP8 = "fp8"
class RopeScaling(str, Enum):
......@@ -137,13 +154,13 @@ class RopeScaling(str, Enum):
def register_model_group(
models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None,
template: str | None = None,
multimodal: bool = False,
) -> None:
for name, path in models.items():
SUPPORTED_MODELS[name] = path
if template is not None and (
any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal
any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct", "-Thinking")) or multimodal
):
DEFAULT_TEMPLATE[name] = template
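# Registration sketch (hypothetical model entry): names carrying a -Chat, -Distill,
# -Instruct, or -Thinking suffix, and any multimodal group, also receive a default
# template binding.
#   >>> register_model_group(
#   ...     models={"Acme-7B-Instruct": {DownloadSource.DEFAULT: "acme/acme-7b-instruct"}},
#   ...     template="qwen3",
#   ... )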
......@@ -276,7 +293,7 @@ register_model_group(
register_model_group(
models={
"ChatGLM2-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
DownloadSource.DEFAULT: "zai-org/chatglm2-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
}
},
......@@ -287,11 +304,11 @@ register_model_group(
register_model_group(
models={
"ChatGLM3-6B-Base": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
DownloadSource.DEFAULT: "zai-org/chatglm3-6b-base",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
},
"ChatGLM3-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
DownloadSource.DEFAULT: "zai-org/chatglm3-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
},
},
......@@ -333,7 +350,7 @@ register_model_group(
register_model_group(
models={
"CodeGeeX4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/codegeex4-all-9b",
DownloadSource.DEFAULT: "zai-org/codegeex4-all-9b",
DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
},
},
......@@ -600,6 +617,68 @@ register_model_group(
)
register_model_group(
models={
"dots.ocr": {
DownloadSource.DEFAULT: "rednote-hilab/dots.ocr",
DownloadSource.MODELSCOPE: "rednote-hilab/dots.ocr",
},
},
template="dots_ocr",
multimodal=True,
)
register_model_group(
models={
"ERNIE-4.5-21B-A3B-Thinking": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-21B-A3B-Thinking",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-21B-A3B-Thinking",
},
},
template="ernie",
)
register_model_group(
models={
"ERNIE-4.5-0.3B-PT": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-0.3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-0.3B-PT",
},
"ERNIE-4.5-21B-A3B-PT": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-21B-A3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-21B-A3B-PT",
},
"ERNIE-4.5-300B-A47B-PT": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-300B-A47B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-300B-A47B-PT",
},
},
template="ernie_nothink",
)
register_model_group(
models={
"ERNIE-4.5-VL-28B-A3B-PT": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT",
},
"ERNIE-4.5-VL-28B-A3B-Thinking": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-Thinking",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Thinking",
},
"ERNIE-4.5-VL-424B-A47B-Base-PT": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-424B-A47B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT",
},
},
template="ernie_vl",
multimodal=True,
)
register_model_group(
models={
"EXAONE-3.0-7.8B-Instruct": {
......@@ -644,6 +723,7 @@ register_model_group(
template="falcon",
)
register_model_group(
models={
"Falcon-H1-0.5B-Base": {
......@@ -756,10 +836,18 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
},
"Gemma-3-270M": {
DownloadSource.DEFAULT: "google/gemma-3-270m",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-270m",
},
"Gemma-3-1B": {
DownloadSource.DEFAULT: "google/gemma-3-1b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt",
},
"Gemma-3-270M-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-270m-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-270m-it",
},
"Gemma-3-1B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-1b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
......@@ -807,6 +895,10 @@ register_model_group(
DownloadSource.DEFAULT: "google/medgemma-4b-it",
DownloadSource.MODELSCOPE: "google/medgemma-4b-it",
},
"MedGemma-27B-Instruct": {
DownloadSource.DEFAULT: "google/medgemma-27b-text-it",
DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
},
},
template="gemma3",
multimodal=True,
......@@ -840,28 +932,28 @@ register_model_group(
register_model_group(
models={
"GLM-4-9B": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b",
DownloadSource.DEFAULT: "zai-org/glm-4-9b",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
},
"GLM-4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
DownloadSource.DEFAULT: "zai-org/glm-4-9b-chat",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
},
"GLM-4-9B-1M-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
DownloadSource.DEFAULT: "zai-org/glm-4-9b-chat-1m",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
},
"GLM-4-0414-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-4-9B-0414",
DownloadSource.DEFAULT: "zai-org/GLM-4-9B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-9B-0414",
},
"GLM-4-0414-32B-Base": {
DownloadSource.DEFAULT: "THUDM/GLM-4-32B-Base-0414",
DownloadSource.DEFAULT: "zai-org/GLM-4-32B-Base-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-Base-0414",
},
"GLM-4-0414-32B-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
DownloadSource.DEFAULT: "zai-org/GLM-4-32B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
},
},
......@@ -872,11 +964,11 @@ register_model_group(
register_model_group(
models={
"GLM-4.1V-9B-Base": {
DownloadSource.DEFAULT: "THUDM/GLM-4.1V-9B-Base",
DownloadSource.DEFAULT: "zai-org/GLM-4.1V-9B-Base",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.1V-9B-Base",
},
"GLM-4.1V-9B-Thinking": {
DownloadSource.DEFAULT: "THUDM/GLM-4.1V-9B-Thinking",
DownloadSource.DEFAULT: "zai-org/GLM-4.1V-9B-Thinking",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.1V-9B-Thinking",
},
},
......@@ -885,14 +977,57 @@ register_model_group(
)
register_model_group(
models={
"GLM-4.5-Air-Base": {
DownloadSource.DEFAULT: "zai-org/GLM-4.5-Air-Base",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5-Air-Base",
},
"GLM-4.5-Base": {
DownloadSource.DEFAULT: "zai-org/GLM-4.5-Base",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5-Base",
},
"GLM-4.5-Air-Thinking": {
DownloadSource.DEFAULT: "zai-org/GLM-4.5-Air",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5-Air",
},
"GLM-4.5-Thinking": {
DownloadSource.DEFAULT: "zai-org/GLM-4.5",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5",
},
},
template="glm4_moe",
)
register_model_group(
models={
"GLM-4.5V-Air-Thinking": {
DownloadSource.DEFAULT: "zai-org/GLM-4.5V",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5V",
},
"GLM-4.6V": {
DownloadSource.DEFAULT: "zai-org/GLM-4.6V",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V",
},
"GLM-4.6V-Flash": {
DownloadSource.DEFAULT: "zai-org/GLM-4.6V-Flash",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V-Flash",
},
},
template="glm4_5v",
multimodal=True,
)
register_model_group(
models={
"GLM-Z1-0414-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
DownloadSource.DEFAULT: "zai-org/GLM-Z1-9B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
},
"GLM-Z1-0414-32B-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-Z1-32B-0414",
DownloadSource.DEFAULT: "zai-org/GLM-Z1-32B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
},
},
......@@ -922,6 +1057,55 @@ register_model_group(
)
register_model_group(
models={
"GPT-OSS-20B-Thinking": {
DownloadSource.DEFAULT: "openai/gpt-oss-20b",
DownloadSource.MODELSCOPE: "openai/gpt-oss-20b",
},
"GPT-OSS-120B-Thinking": {
DownloadSource.DEFAULT: "openai/gpt-oss-120b",
DownloadSource.MODELSCOPE: "openai/gpt-oss-120b",
},
},
template="gpt_oss",
)
register_model_group(
models={
"MiniMax-Text-01-Instruct": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-Text-01-hf",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-Text-01",
},
"MiniMax-M1-40k-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-40k-hf",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-40k-hf",
},
"MiniMax-M1-80k-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-80k-hf",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-80k-hf",
},
},
template="minimax1",
)
register_model_group(
models={
"MiniMax-M2-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2",
},
"MiniMax-M2.1-Thinking": {
DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2.1",
DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2.1",
},
},
template="minimax2",
)
register_model_group(
models={
"Granite-3.0-1B-A400M-Base": {
......@@ -1029,12 +1213,27 @@ register_model_group(
)
register_model_group(
models={
"Granite-4.0-tiny-preview": {
DownloadSource.DEFAULT: "ibm-granite/granite-4.0-tiny-preview",
DownloadSource.MODELSCOPE: "ibm-granite/granite-4.0-tiny-preview",
},
},
template="granite4",
)
register_model_group(
models={
"Hunyuan-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
},
"Hunyuan-MT-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-MT-7B",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-MT-7B",
},
},
template="hunyuan",
)
......@@ -1185,12 +1384,52 @@ register_model_group(
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-78B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-78B-hf",
},
"InternVL3.5-1B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-1B-HF",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-1B-HF",
},
"InternVL3.5-2B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-2B-HF",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-2B-HF",
},
"InternVL3.5-4B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-4B-HF",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-4B-HF",
},
"InternVL3.5-8B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-8B-HF",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-8B-HF",
},
"InternVL3.5-14B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-14B-HF",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-14B-HF",
},
"InternVL3.5-30B-A3B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-30B-A3B-HF",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-30B-A3B-HF",
},
"InternVL3.5-38B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-38B-HF",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-38B-HF",
},
},
template="intern_vl",
multimodal=True,
)
register_model_group(
models={
"Intern-S1-mini": {
DownloadSource.DEFAULT: "internlm/Intern-S1-mini",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/Intern-S1-mini",
}
},
template="intern_s1",
multimodal=True,
)
register_model_group(
models={
"Jamba-v0.1": {
......@@ -1201,6 +1440,18 @@ register_model_group(
)
register_model_group(
models={
"Keye-VL-8B-Chat": {
DownloadSource.DEFAULT: "Kwai-Keye/Keye-VL-8B-Preview",
DownloadSource.MODELSCOPE: "Kwai-Keye/Keye-VL-8B-Preview",
},
},
template="keye_vl",
multimodal=True,
)
register_model_group(
models={
"Kimi-Dev-72B-Instruct": {
......@@ -1589,20 +1840,51 @@ register_model_group(
register_model_group(
models={
"MiMo-7B-VL-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
"MiMo-V2-Flash-Base": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash-Base",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash-Base",
},
"MiMo-V2-Flash": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash",
},
},
template="mimo_v2",
)
register_model_group(
models={
"MiMo-7B-VL-RL": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
},
"MiMo-VL-7B-RL-2508": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
},
},
template="mimo_vl",
multimodal=True,
)
register_model_group(
models={
"MiMo-7B-VL-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
},
"MiMo-VL-7B-SFT-2508": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
},
},
template="qwen2_vl",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
......@@ -1640,6 +1922,10 @@ register_model_group(
DownloadSource.DEFAULT: "openbmb/MiniCPM4-8B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM4-8B",
},
"MiniCPM4.1-8B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM4.1-8B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM4.1-8B",
},
},
template="cpm4",
)
......@@ -1647,7 +1933,7 @@ register_model_group(
register_model_group(
models={
"MiniCPM-o-2_6": {
"MiniCPM-o-2.6": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
},
......@@ -1659,7 +1945,7 @@ register_model_group(
register_model_group(
models={
"MiniCPM-V-2_6": {
"MiniCPM-V-2.6": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
},
......@@ -1669,6 +1955,30 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM-V-4": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4",
},
},
template="minicpm_v",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-V-4.5": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_5",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_5",
},
},
template="minicpm_v",
multimodal=True,
)
register_model_group(
models={
"Ministral-8B-Instruct-2410": {
......@@ -1718,6 +2028,37 @@ register_model_group(
template="mistral",
)
register_model_group(
models={
"Ministral-3-3B-Base-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Base-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Base-2512",
},
"Ministral-3-8B-Base-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Base-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Base-2512",
},
"Ministral-3-14B-Base-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Base-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Base-2512",
},
"Ministral-3-3B-Instruct-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Instruct-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Instruct-2512",
},
"Ministral-3-8B-Instruct-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Instruct-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Instruct-2512",
},
"Ministral-3-14B-Instruct-2512": {
DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Instruct-2512",
DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Instruct-2512",
},
},
template="ministral3",
multimodal=True,
)
register_model_group(
models={
......@@ -1777,6 +2118,37 @@ register_model_group(
)
register_model_group(
models={
"MobileLLM-R1-140M-Base": {
DownloadSource.DEFAULT: "facebook/MobileLLM-R1-140M-base",
DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-140M-base",
},
"MobileLLM-R1-360M-Base": {
DownloadSource.DEFAULT: "facebook/MobileLLM-R1-360M-base",
DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-360M-base",
},
"MobileLLM-R1-950M-Base": {
DownloadSource.DEFAULT: "facebook/MobileLLM-R1-950M-base",
DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-950M-base",
},
"MobileLLM-R1-140M-Instruct": {
DownloadSource.DEFAULT: "facebook/MobileLLM-R1-140M",
DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-140M",
},
"MobileLLM-R1-360M-Instruct": {
DownloadSource.DEFAULT: "facebook/MobileLLM-R1-360M",
DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-360M",
},
"MobileLLM-R1-950M-Instruct": {
DownloadSource.DEFAULT: "facebook/MobileLLM-R1-950M",
DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-950M",
},
},
template="llama3",
)
register_model_group(
models={
"Moonlight-16B-A3B": {
......@@ -2669,75 +3041,114 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
},
"Qwen3-0.6B-Instruct": {
"Qwen3-0.6B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
},
"Qwen3-1.7B-Instruct": {
"Qwen3-1.7B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
},
"Qwen3-4B-Instruct": {
"Qwen3-4B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
},
"Qwen3-8B-Instruct": {
"Qwen3-4B-Thinking-2507": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Thinking-2507",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Thinking-2507",
},
"Qwen3-8B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
},
"Qwen3-14B-Instruct": {
"Qwen3-14B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
},
"Qwen3-32B-Instruct": {
"Qwen3-32B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
},
"Qwen3-30B-A3B-Instruct": {
"Qwen3-30B-A3B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
},
"Qwen3-235B-A22B-Instruct": {
"Qwen3-30B-A3B-Thinking-2507": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Thinking-2507",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Thinking-2507",
},
"Qwen3-235B-A22B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
},
"Qwen3-0.6B-Instruct-GPTQ-Int8": {
"Qwen3-235B-A22B-Thinking-2507": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-Thinking-2507",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-Thinking-2507",
},
"Qwen3-0.6B-Thinking-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
},
"Qwen3-1.7B-Instruct-GPTQ-Int8": {
"Qwen3-1.7B-Thinking-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
},
"Qwen3-4B-Instruct-AWQ": {
"Qwen3-4B-Thinking-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
},
"Qwen3-8B-Instruct-AWQ": {
"Qwen3-8B-Thinking-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
},
"Qwen3-14B-Instruct-AWQ": {
"Qwen3-14B-Thinking-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
},
"Qwen3-32B-Instruct-AWQ": {
"Qwen3-32B-Thinking-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
},
"Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
"Qwen3-30B-A3B-Thinking-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
},
"Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
"Qwen3-235B-A22B-Thinking-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
},
"Qwen/Qwen3-Next-80B-A3B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-Next-80B-A3B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-Next-80B-A3B-Thinking",
},
},
template="qwen3",
)
register_model_group(
models={
"Qwen3-4B-Instruct-2507": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Instruct-2507",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Instruct-2507",
},
"Qwen3-30B-A3B-Instruct-2507": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Instruct-2507",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Instruct-2507",
},
"Qwen3-235B-A22B-Instruct-2507": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-Instruct-2507",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-Instruct-2507",
},
"Qwen3-Next-80B-A3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-Next-80B-A3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-Next-80B-A3B-Instruct",
},
},
template="qwen3_nothink",
)
register_model_group(
models={
"Qwen2-Audio-7B": {
......@@ -2778,6 +3189,34 @@ register_model_group(
)
register_model_group(
models={
"Qwen3-Omni-30B-A3B-Captioner": {
DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
},
"Qwen3-Omni-30B-A3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
},
},
template="qwen3_omni_nothink",
multimodal=True,
)
register_model_group(
models={
"Qwen3-Omni-30B-A3B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
},
},
template="qwen3_omni",
multimodal=True,
)
register_model_group(
models={
"Qwen2-VL-2B": {
......@@ -2880,22 +3319,108 @@ register_model_group(
)
register_model_group(
models={
"Qwen3-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-2B-Instruct",
},
"Qwen3-VL-4B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Instruct",
},
"Qwen3-VL-8B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Instruct",
},
"Qwen3-VL-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-32B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-32B-Instruct",
},
"Qwen3-VL-30B-A3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Instruct",
},
"Qwen3-VL-235B-A22B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Instruct",
},
},
template="qwen3_vl_nothink",
multimodal=True,
)
register_model_group(
models={
"Qwen3-VL-2B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-2B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-2B-Thinking",
},
"Qwen3-VL-4B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Thinking",
},
"Qwen3-VL-8B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Thinking",
},
"Qwen3-VL-32B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-32B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-32B-Thinking",
},
"Qwen3-VL-30B-A3B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Thinking",
},
"Qwen3-VL-235B-A22B-Thinking": {
DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Thinking",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Thinking",
},
},
template="qwen3_vl",
multimodal=True,
)
register_model_group(
models={
"Seed-Coder-8B-Base": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-Coder-8B-Base",
},
"Seed-Coder-8B-Instruct": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-Coder-8B-Instruct",
},
"Seed-Coder-8B-Instruct-Reasoning": {
"Seed-Coder-8B-Thinking": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
},
},
template="seed_coder",
)
register_model_group(
models={
"Seed-OSS-36B-Base": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-OSS-36B-Base",
DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-OSS-36B-Base",
},
"Seed-OSS-36B-Base-woSyn": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-OSS-36B-Base-woSyn",
DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-OSS-36B-Base-woSyn",
},
"Seed-OSS-36B-Instruct": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-OSS-36B-Instruct",
DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-OSS-36B-Instruct",
},
},
template="seed_oss",
)
register_model_group(
models={
"Skywork-13B-Base": {
......@@ -3057,6 +3582,17 @@ register_model_group(
)
register_model_group(
models={
"VibeThinker-1.5B": {
DownloadSource.DEFAULT: "WeiboAI/VibeThinker-1.5B",
DownloadSource.MODELSCOPE: "WeiboAI/VibeThinker-1.5B",
},
},
template="qwen3",
)
register_model_group(
models={
"Vicuna-v1.5-7B-Chat": {
......
......@@ -15,33 +15,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
import accelerate
import datasets
import peft
import torch
import transformers
import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
from collections import OrderedDict
VERSION = "0.9.4.dev0"
VERSION = "0.9.4"
def print_env() -> None:
info = {
"`llamafactory` version": VERSION,
"Platform": platform.platform(),
"Python version": platform.python_version(),
"PyTorch version": torch.__version__,
"Transformers version": transformers.__version__,
"Datasets version": datasets.__version__,
"Accelerate version": accelerate.__version__,
"PEFT version": peft.__version__,
"TRL version": trl.__version__,
}
import os
import platform
import accelerate
import datasets
import peft
import torch
import transformers
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
info = OrderedDict(
{
"`llamafactory` version": VERSION,
"Platform": platform.platform(),
"Python version": platform.python_version(),
"PyTorch version": torch.__version__,
"Transformers version": transformers.__version__,
"Datasets version": datasets.__version__,
"Accelerate version": accelerate.__version__,
"PEFT version": peft.__version__,
}
)
if is_torch_cuda_available():
info["PyTorch version"] += " (GPU)"
......@@ -54,6 +57,13 @@ def print_env() -> None:
info["NPU type"] = torch.npu.get_device_name()
info["CANN version"] = torch.version.cann
try:
import trl # type: ignore
info["TRL version"] = trl.__version__
except Exception:
pass
try:
import deepspeed # type: ignore
......
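The try/except blocks above make TRL and DeepSpeed optional at report time. A generic restatement of the same guarded-import pattern, shown only as an illustration (`_optional_version` is not part of this module):

def _optional_version(package: str) -> "str | None":
    # Illustrative generalization of the guarded imports above: report a
    # version when the package imports cleanly, stay silent otherwise.
    try:
        module = __import__(package)
        return getattr(module, "__version__", None)
    except Exception:
        return None

# e.g. info["TRL version"] = _optional_version("trl") or "not installed"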
......@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "_Logger":
def get_logger(name: str | None = None) -> "_Logger":
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None:
name = _get_library_name()
......
......@@ -18,7 +18,7 @@
import gc
import os
import socket
from typing import TYPE_CHECKING, Any, Literal, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch
import torch.distributed as dist
......@@ -94,11 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0")
check_version("datasets>=2.16.0,<=3.6.0")
check_version("accelerate>=1.3.0,<=1.7.0")
check_version("peft>=0.14.0,<=0.15.2")
check_version("trl>=0.8.6,<=0.9.6")
check_version("transformers>=4.51.0,<=4.57.1")
check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.14.0,<=0.17.1")
check_version("trl>=0.18.0,<=0.24.0")
def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
......@@ -211,9 +211,9 @@ def has_tokenized_data(path: "os.PathLike") -> bool:
return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def infer_optim_dtype(model_dtype: Optional["torch.dtype"]) -> "torch.dtype":
r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
if _is_bf16_available and model_dtype == torch.bfloat16:
if _is_bf16_available and (model_dtype == torch.bfloat16 or model_dtype is None):
return torch.bfloat16
elif _is_fp16_available:
return torch.float16
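The widened signature matters here: a model config with no dtype set (None) now resolves to bf16 on capable hardware instead of falling through to fp16. A standalone restatement of the branch, for illustration only:

import torch

def _pick_dtype(model_dtype, bf16_ok: bool, fp16_ok: bool) -> "torch.dtype":
    # Restates the updated logic: an unset dtype (None) is treated like an
    # explicit bf16 request when the device supports it.
    if bf16_ok and (model_dtype == torch.bfloat16 or model_dtype is None):
        return torch.bfloat16
    if fp16_ok:
        return torch.float16
    return torch.float32

assert _pick_dtype(None, bf16_ok=True, fp16_ok=True) is torch.bfloat16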
......@@ -313,6 +313,10 @@ def use_ray() -> bool:
return is_env_enabled("USE_RAY")
def use_kt() -> bool:
return is_env_enabled("USE_KT")
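is_env_enabled is defined elsewhere in this module; its assumed semantics (a case-insensitive truthiness check on the environment variable) are sketched below for context.

import os

def _env_enabled(env_var: str, default: str = "0") -> bool:
    # Assumed semantics of is_env_enabled (not shown in this hunk):
    # treat "1", "true", or "y" (any case) as enabled.
    return os.getenv(env_var, default).strip().lower() in ("true", "y", "1")

# USE_KT=1 -> use_kt() returns True and routes inference to KTransformers.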
def find_available_port() -> int:
r"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
......@@ -328,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
if ipv6_enabled:
os.environ.pop("http_proxy", None)
os.environ.pop("HTTP_PROXY", None)
os.environ.pop("https_proxy", None)
os.environ.pop("HTTPS_PROXY", None)
os.environ.pop("all_proxy", None)
os.environ.pop("ALL_PROXY", None)
......@@ -58,6 +58,10 @@ def is_apollo_available():
return _is_package_available("apollo_torch")
def is_jieba_available():
return _is_package_available("jieba")
def is_gradio_available():
return _is_package_available("gradio")
......@@ -66,6 +70,10 @@ def is_matplotlib_available():
return _is_package_available("matplotlib")
def is_mcore_adapter_available():
return _is_package_available("mcore_adapter")
def is_pillow_available():
return _is_package_available("PIL")
......@@ -74,6 +82,10 @@ def is_ray_available():
return _is_package_available("ray")
def is_kt_available():
return _is_package_available("ktransformers")
def is_requests_available():
return _is_package_available("requests")
......@@ -82,6 +94,14 @@ def is_rouge_available():
return _is_package_available("rouge_chinese")
def is_safetensors_available():
return _is_package_available("safetensors")
def is_sglang_available():
return _is_package_available("sglang")
def is_starlette_available():
return _is_package_available("sse_starlette")
......@@ -91,13 +111,14 @@ def is_transformers_version_greater_than(content: str):
return _get_package_version("transformers") >= version.parse(content)
@lru_cache
def is_torch_version_greater_than(content: str):
return _get_package_version("torch") >= version.parse(content)
def is_uvicorn_available():
return _is_package_available("uvicorn")
def is_vllm_available():
return _is_package_available("vllm")
def is_sglang_available():
return _is_package_available("sglang")
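All of the is_*_available helpers above delegate to _is_package_available, whose body is not shown in this diff. A plausible shape, assuming an import-spec probe:

import importlib.util

def _package_available(name: str) -> bool:
    # Assumed shape of _is_package_available: a find_spec probe is cheap
    # and avoids importing the package just to test for its presence.
    return importlib.util.find_spec(name) is not None

# e.g. _package_available("vllm") mirrors is_vllm_available().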