Commit 7ea81099 authored by chenych

update llama4

parent 84987715
......@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Optional, Union
from typing_extensions import override
......@@ -46,8 +46,8 @@ class Template:
format_tools: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
thought_words: Tuple[str, str]
stop_words: list[str]
thought_words: tuple[str, str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
......@@ -56,13 +56,11 @@ class Template:
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
......@@ -74,36 +72,28 @@ class Template:
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
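# Hedged example: with a registered template (say TEMPLATES["llama3"]) and its
# matching tokenizer, a two-message conversation encodes as
#   prompt_ids, response_ids = template.encode_oneturn(tokenizer, [
#       {"role": "user", "content": "What is 2 + 2?"},
#       {"role": "assistant", "content": "2 + 2 = 4."},
#   ])
# while encode_multiturn(tokenizer, messages) yields one (prompt_ids, response_ids)
# pair per user/assistant turn.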
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts tool message.
"""
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract tool message."""
return self.format_tools.extract(content)
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
r"""
Returns stop token ids.
"""
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r"""Return stop token ids."""
stop_token_ids = {tokenizer.eos_token_id}
for token in self.stop_words:
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
return list(stop_token_ids)
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r"""Convert elements to token ids."""
token_ids = []
for elem in elements:
if isinstance(elem, str):
......@@ -124,14 +114,14 @@ class Template:
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
) -> list[list[int]]:
r"""Encode formatted inputs to pairs of token ids.
Turn 0: prefix + system + query    resp
Turn t: query                      resp.
"""
system = system or self.default_system
encoded_messages = []
......@@ -161,9 +151,7 @@ class Template:
@staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r"""
Adds or replaces eos token to the tokenizer.
"""
r"""Add or replace eos token to the tokenizer."""
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
......@@ -176,9 +164,7 @@ class Template:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Adds eos token and pad token to the tokenizer.
"""
r"""Add eos token and pad token to the tokenizer."""
stop_words = self.stop_words
if self.replace_eos:
if not stop_words:
......@@ -204,16 +190,12 @@ class Template:
@staticmethod
def _jinja_escape(content: str) -> str:
r"""
Escape single quotes in content.
"""
r"""Escape single quotes in content."""
return content.replace("'", r"\'")
@staticmethod
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
r"""
Converts slots to jinja template.
"""
r"""Convert slots to jinja template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
......@@ -235,9 +217,7 @@ class Template:
return " + ".join(slot_items)
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
r"""Return the jinja template."""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
......@@ -265,9 +245,7 @@ class Template:
return jinja_template
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Replaces the jinja template in the tokenizer.
"""
r"""Replace the jinja template in the tokenizer."""
if tokenizer.chat_template is None or self.replace_jinja_template:
try:
tokenizer.chat_template = self._get_jinja_template(tokenizer)
......@@ -278,9 +256,7 @@ class Template:
def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
r"""
Converts slots to ollama template.
"""
r"""Convert slots to ollama template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
......@@ -302,9 +278,7 @@ class Template:
return "".join(slot_items)
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama template.
"""
r"""Return the ollama template."""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
......@@ -316,8 +290,7 @@ class Template:
)
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama modelfile.
r"""Return the ollama modelfile.
TODO: support function calling.
"""
......@@ -336,14 +309,16 @@ class Template:
@dataclass
class Llama2Template(Template):
r"""A template that fuse the system message to first user message."""
@override
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: str,
tools: str,
) -> List[List[int]]:
) -> list[list[int]]:
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
......@@ -402,7 +377,7 @@ class Llama2Template(Template):
return jinja_template
TEMPLATES: Dict[str, "Template"] = {}
TEMPLATES: dict[str, "Template"] = {}
def register_template(
......@@ -415,16 +390,15 @@ def register_template(
format_tools: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Optional[Sequence[str]] = None,
thought_words: Optional[Tuple[str, str]] = None,
stop_words: Optional[list[str]] = None,
thought_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: Type["Template"] = Template,
template_class: type["Template"] = Template,
) -> None:
r"""
Registers a chat template.
r"""Register a chat template.
To add the following chat template:
```
......@@ -472,9 +446,7 @@ def register_template(
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
r"""
Extracts a chat template from the tokenizer.
"""
r"""Extract a chat template from the tokenizer."""
def find_diff(short_str: str, long_str: str) -> str:
i, j = 0, 0
......@@ -532,9 +504,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
r"""Get chat template and fixes the tokenizer."""
if data_args.template is None:
if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
......@@ -807,15 +777,17 @@ register_template(
register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
format_system=StringFormatter(slots=["System: {{content}}\n"]),
format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
replace_jinja_template=True,
)
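# With the eos token now emitted after every message, one turn under the updated
# "default" template renders roughly as below (eos shown as </s> for illustration;
# note there is no space after "Assistant:"):
#   System: You are a helpful assistant.</s>
#   Human: Hello!</s>
#   Assistant:Hi, how can I help you?</s>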
register_template(
name="empty",
format_assistant=StringFormatter(slots=["{{content}}"]),
replace_jinja_template=True,
)
......@@ -839,6 +811,7 @@ register_template(
name="fewshot",
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
efficient_eos=True,
replace_jinja_template=True,
)
......@@ -846,10 +819,29 @@ register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
template_class=Llama2Template,
)
# copied from gemma template
register_template(
name="gemma3",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
template_class=Llama2Template,
)
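# Since gemma3 uses Llama2Template, the system message is fused into the first
# user turn; one turn renders roughly as (bos shown as <bos>):
#   <bos><start_of_turn>user
#   {system message}
#
#   {user message}<end_of_turn>
#   <start_of_turn>model
#   {assistant message}<end_of_turn>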
......@@ -887,6 +879,16 @@ register_template(
)
register_template(
name="hunyuan",
format_user=StringFormatter(slots=["<|bos|>user\n{{content}}<|eos|>\n<|bos|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|eos|>\n"]),
format_system=StringFormatter(slots=["<|bos|>system\n{{content}}<|eos|>\n"]),
format_prefix=EmptyFormatter(slots=["<|bos|>"]),
stop_words=["<|eos|>"],
)
register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
......@@ -966,6 +968,26 @@ register_template(
)
register_template(
name="llama4",
format_user=StringFormatter(
slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
),
format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
"<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot|>", "<|eom|>"],
mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
)
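# For reference, one turn under the new llama4 template renders roughly as below;
# <|begin_of_text|> stands in for the tokenizer's bos token (an assumption about
# the Llama 4 tokenizer):
#   <|begin_of_text|><|header_start|>system<|header_end|>
#
#   {system message}<|eot|><|header_start|>user<|header_end|>
#
#   {user message}<|eot|><|header_start|>assistant<|header_end|>
#
#   {assistant message}<|eot|>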
# copied from llama3 template
register_template(
name="mllama",
......@@ -1149,7 +1171,8 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=(
"你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"你是一个经过良好训练的AI助手,你的名字是Marco-o1."
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
),
......@@ -1273,6 +1296,7 @@ register_template(
format_user=StringFormatter(slots=["{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
template_class=Llama2Template,
)
......@@ -1285,7 +1309,9 @@ register_template(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
template_class=Llama2Template,
)
......@@ -1361,6 +1387,24 @@ register_template(
)
# copied from qwen template
register_template(
name="qwen2_omni",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
),
)
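# The text layout matches the regular qwen template; the qwen2_omni mm_plugin
# later injects <|AUDIO|>/<|IMAGE|>/<|VIDEO|> tokens where media placeholders
# appear. A plain-text turn renders roughly as:
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   {user message}<|im_end|>
#   <|im_start|>assistant
#   {assistant message}<|im_end|>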
# copied from qwen template
register_template(
name="qwen2_vl",
......
......@@ -17,7 +17,7 @@ import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing import Any, NamedTuple, Union
from typing_extensions import override
......@@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (
@dataclass
class ToolUtils(ABC):
"""
Base class for tool utilities.
"""
"""Base class for tool utilities."""
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
def tool_formatter(tools: list[dict[str, Any]]) -> str:
r"""Generate the system message describing all the available tools."""
...
@staticmethod
@abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
r"""
Generates the assistant message including all the tool calls.
"""
def function_formatter(functions: list["FunctionCall"]) -> str:
r"""Generate the assistant message including all the tool calls."""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the assistant message.
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract all the function calls from the assistant message.
It should be the inverse of `function_formatter`.
"""
......@@ -92,13 +85,11 @@ class ToolUtils(ABC):
class DefaultToolUtils(ToolUtils):
r"""
Default tool using template.
"""
r"""Default tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
......@@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
......@@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
action_match: list[tuple[str, str]] = re.findall(regex, content)
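# e.g. content = 'Action: get_weather\nAction Input: {"city": "Beijing"}'
#      yields action_match == [("get_weather", '{"city": "Beijing"}')]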
if not action_match:
return content
......@@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
r"""
GLM-4 tool using template.
"""
r"""GLM-4 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
......@@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
......@@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
if "\n" not in content:
return content
......@@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
class Llama3ToolUtils(ToolUtils):
r"""
Llama 3.x tool using template with `tools_in_user_message=False`.
r"""Llama 3.x tool using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
......@@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
......@@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tool = json.loads(content.strip())
except json.JSONDecodeError:
......@@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
class MistralToolUtils(ToolUtils):
r"""
Mistral v0.3 tool using template.
"""
r"""Mistral v0.3 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = []
for tool in tools:
wrapped_tools.append({"type": "function", "function": tool})
......@@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
......@@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tools = json.loads(content.strip())
except json.JSONDecodeError:
......@@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
class QwenToolUtils(ToolUtils):
r"""
Qwen 2.5 tool using template.
"""
r"""Qwen 2.5 tool using template."""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
......@@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
......@@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
tool_match: List[str] = re.findall(regex, content)
tool_match: list[str] = re.findall(regex, content)
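# e.g. content = '<tool_call>{"name": "get_weather", "arguments": {"city": "Beijing"}}</tool_call>'
#      yields tool_match == ['{"name": "get_weather", "arguments": {"city": "Beijing"}}']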
if not tool_match:
return content
......
......@@ -39,7 +39,7 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional
import numpy as np
import torch
......@@ -59,7 +59,7 @@ if TYPE_CHECKING:
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
......@@ -69,7 +69,7 @@ class Evaluator:
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
......@@ -88,7 +88,7 @@ class Evaluator:
)
with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
categorys: dict[str, dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
......@@ -136,7 +136,7 @@ class Evaluator:
pbar.close()
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
score_info = "\n".join(
[
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
......
......@@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple
from ..data import Role
from ..extras.constants import CHOICES
......@@ -25,20 +24,19 @@ class EvalTemplate:
choice: str
answer: str
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
r"""Parse eval example.
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
output: a tuple of (prompt, response).
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str
) -> list[dict[str, str]]:
r"""Convert dataset examples to messages."""
messages = []
for k in range(len(support_set)):
prompt, response = self._parse_example(support_set[k])
......@@ -52,7 +50,7 @@ class EvalTemplate:
return messages
eval_templates: Dict[str, "EvalTemplate"] = {}
eval_templates: dict[str, "EvalTemplate"] = {}
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
......
......@@ -15,7 +15,7 @@
import os
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
......@@ -106,6 +106,7 @@ class AttentionFunction(str, Enum):
class EngineName(str, Enum):
HF = "huggingface"
VLLM = "vllm"
SGLANG = "sglang"
class DownloadSource(str, Enum):
......@@ -122,7 +123,7 @@ class RopeScaling(str, Enum):
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None,
multimodal: bool = False,
) -> None:
......@@ -650,11 +651,51 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
},
"Gemma-3-1B": {
DownloadSource.DEFAULT: "google/gemma-3-1b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt",
},
"Gemma-3-1B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-1b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
},
},
template="gemma",
)
register_model_group(
models={
"Gemma-3-4B": {
DownloadSource.DEFAULT: "google/gemma-3-4b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt",
},
"Gemma-3-12B": {
DownloadSource.DEFAULT: "google/gemma-3-12b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-pt",
},
"Gemma-3-27B": {
DownloadSource.DEFAULT: "google/gemma-3-27b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt",
},
"Gemma-3-4B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-4b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it",
},
"Gemma-3-12B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-12b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-it",
},
"Gemma-3-27B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
},
},
template="gemma3",
multimodal=True,
)
register_model_group(
models={
"GLM-4-9B": {
......@@ -768,6 +809,17 @@ register_model_group(
)
register_model_group(
models={
"Hunyuan-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
},
},
template="hunyuan",
)
register_model_group(
models={
"Index-1.9B-Base": {
......@@ -1059,6 +1111,30 @@ register_model_group(
)
register_model_group(
models={
"Llama-4-Scout-17B-16E": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
},
"Llama-4-Scout-17B-16E-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
},
"Llama-4-Maverick-17B-128E": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
},
"Llama-4-Maverick-17B-128E-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
},
},
template="llama4",
multimodal=True,
)
register_model_group(
models={
"LLaVA-1.5-7B-Chat": {
......@@ -2218,6 +2294,18 @@ register_model_group(
)
register_model_group(
models={
"Qwen2.5-Omni-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
}
},
template="qwen2_omni",
multimodal=True,
)
register_model_group(
models={
"Qwen2-VL-2B": {
......@@ -2294,6 +2382,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct",
},
"Qwen2.5-VL-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-32B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-32B-Instruct",
},
"Qwen2.5-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct",
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
......@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
VERSION = "0.9.2"
VERSION = "0.9.3.dev0"
def print_env() -> None:
......
# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
......@@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO
class LoggerHandler(logging.Handler):
r"""
Redirects the logging output to the logging file for LLaMA Board.
"""
r"""Redirect the logging output to the logging file for LLaMA Board."""
def __init__(self, output_dir: str) -> None:
super().__init__()
......@@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger):
r"""
A logger that supports rank0 logging.
"""
r"""A logger that supports rank0 logging."""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
......@@ -82,9 +78,7 @@ class _Logger(logging.Logger):
def _get_default_logging_level() -> "logging._Level":
r"""
Returns the default logging level.
"""
r"""Return the default logging level."""
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
......@@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":
def _configure_library_root_logger() -> None:
r"""
Configures root logger using a stdout stream handler with an explicit format.
"""
r"""Configure root logger using a stdout stream handler with an explicit format."""
global _default_handler
with _thread_lock:
......@@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:
def get_logger(name: Optional[str] = None) -> "_Logger":
r"""
Returns a logger with the specified name. It it not supposed to be accessed externally.
"""
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None:
name = _get_library_name()
......@@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":
def add_handler(handler: "logging.Handler") -> None:
r"""
Adds a handler to the root logger.
"""
r"""Add a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""
Removes a handler to the root logger.
"""
r"""Remove a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
......@@ -17,7 +17,8 @@
import gc
import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
import socket
from typing import TYPE_CHECKING, Any, Literal, Union
import torch
import torch.distributed as dist
......@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
class AverageMeter:
r"""
Computes and stores the average and current value.
"""
r"""Compute and store the average and current value."""
def __init__(self):
self.reset()
......@@ -75,9 +74,7 @@ class AverageMeter:
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""
Optionally checks the package version.
"""
r"""Optionally check the package version."""
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
......@@ -91,22 +88,18 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.2.0")
check_version("accelerate>=0.34.0,<=1.2.1")
check_version("peft>=0.11.1,<=0.12.0")
r"""Check the version of the required packages."""
check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.4.1")
check_version("accelerate>=0.34.0,<=1.5.2")
check_version("peft>=0.14.0,<=0.15.0")
check_version("trl>=0.8.6,<=0.9.6")
if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
r"""
Calculates effective tokens per second.
"""
def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
r"""Calculate effective tokens per second."""
effective_token_num = 0
for data in dataset:
if stage == "sft":
......@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
return result / dist.get_world_size() if dist.is_initialized() else result
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
r"""Return the number of trainable parameters and number of all parameters in the model."""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
......@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
r"""Get the current available device."""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available():
......@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
def get_device_count() -> int:
r"""
Gets the number of available GPU or NPU devices.
"""
r"""Get the number of available GPU or NPU devices."""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
......@@ -180,18 +167,14 @@ def get_device_count() -> int:
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
r"""Get logits processor that removes NaN and Inf logits."""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def get_peak_memory() -> Tuple[int, int]:
r"""
Gets the peak memory usage for the current device (in Bytes).
"""
def get_peak_memory() -> tuple[int, int]:
r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
......@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
r"""Check if the path has a tokenized dataset."""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
......@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def is_gpu_or_npu_available() -> bool:
r"""
Checks if the GPU or NPU is available.
"""
r"""Check if the GPU or NPU is available."""
return is_torch_npu_available() or is_torch_cuda_available()
def is_env_enabled(env_var: str, default: str = "0") -> bool:
r"""
Checks if the environment variable is enabled.
"""
r"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
"""
r"""Cast a torch tensor or a numpy array to a numpy array."""
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
......@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
r"""Avoid flash attention import error in custom model files."""
if not is_env_enabled("FORCE_CHECK_IMPORTS"):
transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None:
r"""
Collects GPU or NPU memory.
"""
r"""Collect GPU or NPU memory."""
gc.collect()
if is_torch_xpu_available():
torch.xpu.empty_cache()
......@@ -306,3 +275,20 @@ def use_openmind() -> bool:
def use_ray() -> bool:
return is_env_enabled("USE_RAY")
def find_available_port() -> int:
"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
return port
def fix_proxy(ipv6_enabled: bool) -> None:
"""Fix proxy settings for gradio ui."""
os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
if ipv6_enabled:
for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
os.environ.pop(name, None)
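# Usage sketch for the two helpers above; a real caller would pass the port to
# Gradio's launch(). The import path is an assumption about the package layout.
from llamafactory.extras.misc import find_available_port, fix_proxy

fix_proxy(ipv6_enabled=False)  # keep localhost traffic off any configured proxy
port = find_available_port()   # note: the port could be re-taken before it is actually bound
print(f"LLaMA Board will listen on 127.0.0.1:{port}")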
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
......@@ -97,3 +97,7 @@ def is_uvicorn_available():
def is_vllm_available():
return _is_package_available("vllm")
def is_sglang_available():
return _is_package_available("sglang")
......@@ -15,7 +15,7 @@
import json
import math
import os
from typing import Any, Dict, List
from typing import Any
from transformers.trainer import TRAINER_STATE_NAME
......@@ -31,10 +31,8 @@ if is_matplotlib_available():
logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
def smooth(scalars: list[float]) -> list[float]:
r"""EMA implementation according to TensorBoard."""
if len(scalars) == 0:
return []
......@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""
Plots loss curves in LlamaBoard.
"""
def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""Plot loss curves in LlamaBoard."""
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
......@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
return fig
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""
def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
r"""Plot loss curves and saves the image."""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f)
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
......@@ -16,14 +16,12 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
from typing import Any, Literal, Optional
@dataclass
class DataArguments:
r"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
template: Optional[str] = field(
default=None,
......@@ -162,5 +160,5 @@ class DataArguments:
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)
......@@ -21,9 +21,7 @@ from datasets import DownloadMode
@dataclass
class EvaluationArguments:
r"""
Arguments pertaining to specify the evaluation parameters.
"""
r"""Arguments pertaining to specify the evaluation parameters."""
task: str = field(
metadata={"help": "Name of the evaluation task."},
......
......@@ -13,14 +13,12 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional
@dataclass
class FreezeArguments:
r"""
Arguments pertaining to the freeze (partial-parameter) training.
"""
r"""Arguments pertaining to the freeze (partial-parameter) training."""
freeze_trainable_layers: int = field(
default=2,
......@@ -56,9 +54,7 @@ class FreezeArguments:
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
r"""Arguments pertaining to the LoRA training."""
additional_target: Optional[str] = field(
default=None,
......@@ -128,9 +124,7 @@ class LoraArguments:
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO, DPO and KTO training.
"""
r"""Arguments pertaining to the PPO, DPO and KTO training."""
pref_beta: float = field(
default=0.1,
......@@ -212,9 +206,7 @@ class RLHFArguments:
@dataclass
class GaloreArguments:
r"""
Arguments pertaining to the GaLore algorithm.
"""
r"""Arguments pertaining to the GaLore algorithm."""
use_galore: bool = field(
default=False,
......@@ -253,9 +245,7 @@ class GaloreArguments:
@dataclass
class ApolloArguments:
r"""
Arguments pertaining to the APOLLO algorithm.
"""
r"""Arguments pertaining to the APOLLO algorithm."""
use_apollo: bool = field(
default=False,
......@@ -306,9 +296,7 @@ class ApolloArguments:
@dataclass
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
r"""Arguments pertaining to the BAdam optimizer."""
use_badam: bool = field(
default=False,
......@@ -387,15 +375,21 @@ class SwanLabArguments:
default=None,
metadata={"help": "The log directory for SwanLab."},
)
swanlab_lark_webhook_url: Optional[str] = field(
default=None,
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
)
swanlab_lark_secret: Optional[str] = field(
default=None,
metadata={"help": "The Lark(飞书) secret for SwanLab."},
)
@dataclass
class FinetuningArguments(
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
pure_bf16: bool = field(
default=False,
......@@ -452,13 +446,13 @@ class FinetuningArguments(
return [item.strip() for item in arg.split(",")]
return arg
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.lora_target: list[str] = split_arg(self.lora_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
......@@ -499,7 +493,7 @@ class FinetuningArguments(
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args
......@@ -13,16 +13,14 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
from typing import Any, Optional
from transformers import GenerationConfig
@dataclass
class GeneratingArguments:
r"""
Arguments pertaining to specify the decoding parameters.
"""
r"""Arguments pertaining to specify the decoding parameters."""
do_sample: bool = field(
default=True,
......@@ -35,7 +33,9 @@ class GeneratingArguments:
top_p: float = field(
default=0.7,
metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
"help": (
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
)
},
)
top_k: int = field(
......@@ -71,7 +71,7 @@ class GeneratingArguments:
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
......@@ -17,7 +17,7 @@
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
from typing import Any, Literal, Optional, Union
import torch
from transformers.training_args import _convert_str_dict
......@@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling
@dataclass
class BaseModelArguments:
r"""
Arguments pertaining to the model.
"""
r"""Arguments pertaining to the model."""
model_name_or_path: Optional[str] = field(
default=None,
......@@ -184,9 +182,7 @@ class BaseModelArguments:
@dataclass
class QuantizationArguments:
r"""
Arguments pertaining to the quantization method.
"""
r"""Arguments pertaining to the quantization method."""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
......@@ -212,9 +208,7 @@ class QuantizationArguments:
@dataclass
class ProcessorArguments:
r"""
Arguments pertaining to the image processor.
"""
r"""Arguments pertaining to the image processor."""
image_max_pixels: int = field(
default=768 * 768,
......@@ -224,6 +218,14 @@ class ProcessorArguments:
default=32 * 32,
metadata={"help": "The minimum number of pixels of image inputs."},
)
image_do_pan_and_scan: bool = field(
default=False,
metadata={"help": "Use pan and scan to process image for gemma3."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
video_max_pixels: int = field(
default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."},
......@@ -240,13 +242,22 @@ class ProcessorArguments:
default=128,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
audio_sampling_rate: int = field(
default=16000,
metadata={"help": "The sampling rate of audio inputs."},
)
def __post_init__(self):
if self.image_max_pixels < self.image_min_pixels:
raise ValueError("`image_max_pixels` cannot be smaller than `image_min_pixels`.")
if self.video_max_pixels < self.video_min_pixels:
raise ValueError("`video_max_pixels` cannot be smaller than `video_min_pixels`.")
@dataclass
class ExportArguments:
r"""
Arguments pertaining to the model export.
"""
r"""Arguments pertaining to the model export."""
export_dir: Optional[str] = field(
default=None,
......@@ -292,16 +303,14 @@ class ExportArguments:
@dataclass
class VllmArguments:
r"""
Arguments pertaining to the vLLM worker.
"""
r"""Arguments pertaining to the vLLM worker."""
vllm_maxlen: int = field(
default=4096,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
default=0.7,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
......@@ -323,9 +332,36 @@ class VllmArguments:
@dataclass
class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
class SGLangArguments:
r"""Arguments pertaining to the SGLang worker."""
sglang_maxlen: int = field(
default=4096,
metadata={"help": "Maximum sequence (prompt + response) length of the SGLang engine."},
)
sglang_mem_fraction: float = field(
default=0.7,
metadata={"help": "The memory fraction (0-1) to be used for the SGLang engine."},
)
sglang_tp_size: int = field(
default=-1,
metadata={"help": "Tensor parallel size for the SGLang engine."},
)
sglang_config: Optional[Union[dict, str]] = field(
default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
)
def __post_init__(self):
if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
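# Hedged sketch: selecting the new SGLang backend through these arguments.
# The model id is illustrative, and sglang must be installed to pass the version check.
from llamafactory.hparams import get_infer_args

model_args, data_args, finetuning_args, generating_args = get_infer_args({
    "model_name_or_path": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    "template": "llama4",
    "infer_backend": "sglang",
    "sglang_maxlen": 8192,
    "sglang_mem_fraction": 0.8,
})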
@dataclass
class ModelArguments(
SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
):
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
The rightmost base class is displayed first.
"""
......@@ -335,7 +371,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, Dict[str, Any]]] = field(
device_map: Optional[Union[str, dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
......@@ -353,8 +389,10 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
def __post_init__(self):
BaseModelArguments.__post_init__(self)
ProcessorArguments.__post_init__(self)
ExportArguments.__post_init__(self)
VllmArguments.__post_init__(self)
SGLangArguments.__post_init__(self)
@classmethod
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
......@@ -372,7 +410,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
return result
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
......@@ -19,7 +19,7 @@ import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
import transformers
......@@ -31,7 +31,7 @@ from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
......@@ -47,17 +47,15 @@ check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
r"""
Gets arguments from the command line or a config file.
"""
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
r"""Get arguments from the command line or a config file."""
if args is not None:
return args
......@@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[
def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
) -> tuple[Any]:
args = read_args(args)
if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
......@@ -136,9 +134,12 @@ def _check_extra_dependencies(
if model_args.mixture_of_depths is not None:
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == "vllm":
check_version("vllm>=0.4.3,<=0.7.3")
if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.8.2")
check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.4")
check_version("sglang", mandatory=True)
if finetuning_args.use_galore:
check_version("galore_torch", mandatory=True)
......@@ -161,31 +162,31 @@ def _check_extra_dependencies(
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
......@@ -364,9 +365,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
and training_args.resume_from_checkpoint is not None
):
logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
)
# Post-process model arguments
......@@ -382,20 +381,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
# Log on each process the small summary
logger.info(
"Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
training_args.process_index,
training_args.world_size,
training_args.device,
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
f"Process rank: {training_args.process_index}, "
f"world size: {training_args.world_size}, device: {training_args.device}, "
f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
f"compute dtype: {str(model_args.compute_dtype)}"
)
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
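Because `read_args` accepts a dict, `get_train_args` can be driven programmatically as well as from the CLI. A hedged usage example — the option values below are illustrative, and the accepted keys are whatever the argument dataclasses above define:

```python
# Hypothetical invocation with a dict of overrides instead of CLI flags.
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(
    {
        "model_name_or_path": "meta-llama/Llama-3.2-1B-Instruct",  # illustrative checkpoint
        "stage": "sft",
        "do_train": True,
        "finetuning_type": "lora",
        "dataset": "identity",
        "output_dir": "saves/llama3-lora",
    }
)
```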
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
......@@ -426,7 +422,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
......@@ -10,9 +24,7 @@ from ..extras.misc import use_ray
@dataclass
class RayArguments:
r"""
Arguments pertaining to the Ray training.
"""
r"""Arguments pertaining to the Ray training."""
ray_run_name: Optional[str] = field(
default=None,
......@@ -43,9 +55,7 @@ class RayArguments:
@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r"""
Arguments pertaining to the trainer.
"""
r"""Arguments pertaining to the trainer."""
def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)
......
......@@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params
__all__ = [
"QuantizationMethod",
"find_all_linear_modules",
"load_config",
"load_model",
"load_tokenizer",
"find_all_linear_modules",
"load_valuehead_params",
]
......@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
......@@ -81,9 +80,8 @@ def _setup_freeze_tuning(
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.freeze_trainable_layers
)
f"`num_layers` {num_layers} should be "
f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
)
stride = num_layers // finetuning_args.freeze_trainable_layers
......@@ -178,7 +176,7 @@ def _setup_lora_tuning(
}
for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
......@@ -263,8 +261,7 @@ def init_adapter(
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
r"""Initialize the adapters.
Support full-parameter, freeze, and LoRA training.
......@@ -279,14 +276,14 @@ def init_adapter(
# cast trainable parameters to float32 if:
# 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
# 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
# 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
cast_trainable_params_to_fp32 = False
if not is_trainable:
pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
else:
logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
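When `cast_trainable_params_to_fp32` ends up `True`, only the parameters that still require gradients (e.g. LoRA adapters) should be upcast, leaving the frozen base weights in half precision. A minimal sketch of that upcast, assuming it runs after the PEFT wrapper is attached:

```python
import torch


def upcast_trainable_params(model: "torch.nn.Module") -> None:
    """Hypothetical helper: move gradient-bearing parameters to float32 in place."""
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)
```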
......
......@@ -13,13 +13,15 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM,
AutoModelForTextToWaveform,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
......@@ -51,9 +53,8 @@ class TokenizerModule(TypedDict):
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
r"""Get arguments to load config/tokenizer/model.
Note: this function also modifies `model_args` in place.
"""
......@@ -68,13 +69,11 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer and optionally loads processor.
r"""Load pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
......@@ -96,7 +95,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
patch_tokenizer(tokenizer, model_args)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)
patch_processor(processor, tokenizer, model_args)
except Exception as e:
logger.debug(f"Processor was not found: {e}.")
processor = None
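A hedged usage example of the returned `TokenizerModule` mapping, assuming the package is importable as `llamafactory` and using an illustrative checkpoint name:

```python
from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer

# Hypothetical call; all other ModelArguments fields are left at their defaults.
tokenizer_module = load_tokenizer(ModelArguments(model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct"))
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]  # None when the checkpoint ships no processor
```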
......@@ -110,9 +109,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
r"""Load model config."""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
......@@ -124,9 +121,7 @@ def load_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model.
"""
r"""Load pretrained model."""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
......@@ -147,10 +142,14 @@ def load_model(
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
else:
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni
load_class = AutoModelForTextToWaveform
else:
load_class = AutoModelForCausalLM
......@@ -158,6 +157,8 @@ def load_model(
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else:
model = load_class.from_pretrained(**init_kwargs)
if getattr(model.config, "model_type", None) == "qwen2_5_omni":
model = model.thinker # use part of Omni model
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
......@@ -194,8 +195,9 @@ def load_model(
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
param_stats = (
f"trainable params: {trainable_params:,} || "
f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
)
else:
param_stats = f"all params: {all_param:,}"
......