Commit 7ea81099 authored by chenych

update llama4

parent 84987715
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union from typing import TYPE_CHECKING, Optional, Union
from typing_extensions import override from typing_extensions import override
...@@ -46,8 +46,8 @@ class Template: ...@@ -46,8 +46,8 @@ class Template:
format_tools: "Formatter" format_tools: "Formatter"
format_prefix: "Formatter" format_prefix: "Formatter"
default_system: str default_system: str
stop_words: List[str] stop_words: list[str]
thought_words: Tuple[str, str] thought_words: tuple[str, str]
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
replace_jinja_template: bool replace_jinja_template: bool
...@@ -56,13 +56,11 @@ class Template: ...@@ -56,13 +56,11 @@ class Template:
def encode_oneturn( def encode_oneturn(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
r""" r"""Return a single pair of token ids representing prompt and response respectively."""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_messages = self._encode(tokenizer, messages, system, tools) encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = [] prompt_ids = []
for encoded_ids in encoded_messages[:-1]: for encoded_ids in encoded_messages[:-1]:
...@@ -74,36 +72,28 @@ class Template: ...@@ -74,36 +72,28 @@ class Template:
def encode_multiturn( def encode_multiturn(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
) -> List[Tuple[List[int], List[int]]]: ) -> list[tuple[list[int], list[int]]]:
r""" r"""Return multiple pairs of token ids representing prompts and responses respectively."""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
encoded_messages = self._encode(tokenizer, messages, system, tools) encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
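For orientation, a minimal usage sketch of the two encoders above, assuming a `template` returned by `get_template_and_fix_tokenizer` and a loaded `tokenizer`; the chat content is illustrative:

# Hypothetical usage sketch, not part of this commit.
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Who are you?"},
    {"role": "assistant", "content": "An assistant."},
]
# encode_oneturn folds every message except the last into a single prompt.
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)
# encode_multiturn returns one (prompt_ids, response_ids) pair per user/assistant exchange.
pairs = template.encode_multiturn(tokenizer, messages)
assert len(pairs) == 2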
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]: def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
r""" r"""Extract tool message."""
Extracts tool message.
"""
return self.format_tools.extract(content) return self.format_tools.extract(content)
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]: def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r""" r"""Return stop token ids."""
Returns stop token ids.
"""
stop_token_ids = {tokenizer.eos_token_id} stop_token_ids = {tokenizer.eos_token_id}
for token in self.stop_words: for token in self.stop_words:
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token)) stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
return list(stop_token_ids) return list(stop_token_ids)
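A quick sanity-check sketch for the helper above, using the same assumed `template` and `tokenizer` objects:

# The returned list contains the eos token id plus the ids of any registered stop words.
stop_ids = template.get_stop_token_ids(tokenizer)
assert tokenizer.eos_token_id in stop_ids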
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]: def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r""" r"""Convert elements to token ids."""
Converts elements to token ids.
"""
token_ids = [] token_ids = []
for elem in elements: for elem in elements:
if isinstance(elem, str): if isinstance(elem, str):
...@@ -124,14 +114,14 @@ class Template: ...@@ -124,14 +114,14 @@ class Template:
def _encode( def _encode(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
) -> List[List[int]]: ) -> list[list[int]]:
r""" r"""Encode formatted inputs to pairs of token ids.
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query -> resp Turn 0: prefix + system + query -> resp
Turn t: query -> resp Turn t: query -> resp.
""" """
system = system or self.default_system system = system or self.default_system
encoded_messages = [] encoded_messages = []
...@@ -161,9 +151,7 @@ class Template: ...@@ -161,9 +151,7 @@ class Template:
@staticmethod @staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None: def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r""" r"""Add or replace eos token to the tokenizer."""
Adds or replaces eos token to the tokenizer.
"""
is_added = tokenizer.eos_token_id is None is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
...@@ -176,9 +164,7 @@ class Template: ...@@ -176,9 +164,7 @@ class Template:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None: def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
r""" r"""Add eos token and pad token to the tokenizer."""
Adds eos token and pad token to the tokenizer.
"""
stop_words = self.stop_words stop_words = self.stop_words
if self.replace_eos: if self.replace_eos:
if not stop_words: if not stop_words:
...@@ -204,16 +190,12 @@ class Template: ...@@ -204,16 +190,12 @@ class Template:
@staticmethod @staticmethod
def _jinja_escape(content: str) -> str: def _jinja_escape(content: str) -> str:
r""" r"""Escape single quotes in content."""
Escape single quotes in content.
"""
return content.replace("'", r"\'") return content.replace("'", r"\'")
@staticmethod @staticmethod
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
r""" r"""Convert slots to jinja template."""
Converts slots to jinja template.
"""
slot_items = [] slot_items = []
for slot in slots: for slot in slots:
if isinstance(slot, str): if isinstance(slot, str):
...@@ -235,9 +217,7 @@ class Template: ...@@ -235,9 +217,7 @@ class Template:
return " + ".join(slot_items) return " + ".join(slot_items)
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str: def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r""" r"""Return the jinja template."""
Returns the jinja template.
"""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer) prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message") system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
...@@ -265,9 +245,7 @@ class Template: ...@@ -265,9 +245,7 @@ class Template:
return jinja_template return jinja_template
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None: def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
r""" r"""Replace the jinja template in the tokenizer."""
Replaces the jinja template in the tokenizer.
"""
if tokenizer.chat_template is None or self.replace_jinja_template: if tokenizer.chat_template is None or self.replace_jinja_template:
try: try:
tokenizer.chat_template = self._get_jinja_template(tokenizer) tokenizer.chat_template = self._get_jinja_template(tokenizer)
...@@ -278,9 +256,7 @@ class Template: ...@@ -278,9 +256,7 @@ class Template:
def _convert_slots_to_ollama( def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content" slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str: ) -> str:
r""" r"""Convert slots to ollama template."""
Converts slots to ollama template.
"""
slot_items = [] slot_items = []
for slot in slots: for slot in slots:
if isinstance(slot, str): if isinstance(slot, str):
...@@ -302,9 +278,7 @@ class Template: ...@@ -302,9 +278,7 @@ class Template:
return "".join(slot_items) return "".join(slot_items)
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str: def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r""" r"""Return the ollama template."""
Returns the ollama template.
"""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer) prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System") system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content") user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
...@@ -316,8 +290,7 @@ class Template: ...@@ -316,8 +290,7 @@ class Template:
) )
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str: def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r""" r"""Return the ollama modelfile.
Returns the ollama modelfile.
TODO: support function calling. TODO: support function calling.
""" """
...@@ -336,14 +309,16 @@ class Template: ...@@ -336,14 +309,16 @@ class Template:
@dataclass @dataclass
class Llama2Template(Template): class Llama2Template(Template):
r"""A template that fuse the system message to first user message."""
@override @override
def _encode( def _encode(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]], messages: list[dict[str, str]],
system: str, system: str,
tools: str, tools: str,
) -> List[List[int]]: ) -> list[list[int]]:
system = system or self.default_system system = system or self.default_system
encoded_messages = [] encoded_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
...@@ -402,7 +377,7 @@ class Llama2Template(Template): ...@@ -402,7 +377,7 @@ class Llama2Template(Template):
return jinja_template return jinja_template
TEMPLATES: Dict[str, "Template"] = {} TEMPLATES: dict[str, "Template"] = {}
def register_template( def register_template(
...@@ -415,16 +390,15 @@ def register_template( ...@@ -415,16 +390,15 @@ def register_template(
format_tools: Optional["Formatter"] = None, format_tools: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None,
default_system: str = "", default_system: str = "",
stop_words: Optional[Sequence[str]] = None, stop_words: Optional[list[str]] = None,
thought_words: Optional[Tuple[str, str]] = None, thought_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
replace_jinja_template: bool = False, replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: Type["Template"] = Template, template_class: type["Template"] = Template,
) -> None: ) -> None:
r""" r"""Register a chat template.
Registers a chat template.
To add the following chat template: To add the following chat template:
``` ```
...@@ -472,9 +446,7 @@ def register_template( ...@@ -472,9 +446,7 @@ def register_template(
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
r""" r"""Extract a chat template from the tokenizer."""
Extracts a chat template from the tokenizer.
"""
def find_diff(short_str: str, long_str: str) -> str: def find_diff(short_str: str, long_str: str) -> str:
i, j = 0, 0 i, j = 0, 0
...@@ -532,9 +504,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": ...@@ -532,9 +504,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template": def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r""" r"""Get chat template and fixes the tokenizer."""
Gets chat template and fixes the tokenizer.
"""
if data_args.template is None: if data_args.template is None:
if isinstance(tokenizer.chat_template, str): if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.") logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
...@@ -807,15 +777,17 @@ register_template( ...@@ -807,15 +777,17 @@ register_template(
register_template( register_template(
name="default", name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]), format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
format_system=StringFormatter(slots=["System: {{content}}\n"]), format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
replace_jinja_template=True,
) )
register_template( register_template(
name="empty", name="empty",
format_assistant=StringFormatter(slots=["{{content}}"]), format_assistant=StringFormatter(slots=["{{content}}"]),
replace_jinja_template=True,
) )
...@@ -839,6 +811,7 @@ register_template( ...@@ -839,6 +811,7 @@ register_template(
name="fewshot", name="fewshot",
format_assistant=StringFormatter(slots=["{{content}}\n\n"]), format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
efficient_eos=True, efficient_eos=True,
replace_jinja_template=True,
) )
...@@ -846,10 +819,29 @@ register_template( ...@@ -846,10 +819,29 @@ register_template(
name="gemma", name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]), format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]), format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
template_class=Llama2Template,
)
# copied from gemma template
register_template(
name="gemma3",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"] slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
), ),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
template_class=Llama2Template,
) )
...@@ -887,6 +879,16 @@ register_template( ...@@ -887,6 +879,16 @@ register_template(
) )
register_template(
name="hunyuan",
format_user=StringFormatter(slots=["<|bos|>user\n{{content}}<|eos|>\n<|bos|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|eos|>\n"]),
format_system=StringFormatter(slots=["<|bos|>system\n{{content}}<|eos|>\n"]),
format_prefix=EmptyFormatter(slots=["<|bos|>"]),
stop_words=["<|eos|>"],
)
register_template( register_template(
name="intern", name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
...@@ -966,6 +968,26 @@ register_template( ...@@ -966,6 +968,26 @@ register_template(
) )
register_template(
name="llama4",
format_user=StringFormatter(
slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
),
format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
"<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot|>", "<|eom|>"],
mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
)
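For reference, the slots registered above render a single llama4 turn roughly as follows; the bos token string is an assumption about the Llama 4 tokenizer and the chat content is illustrative:

# Illustrative rendering of one llama4 turn, derived from the formatters above.
prompt = (
    "<|begin_of_text|>"  # bos_token slot (assumed token string)
    "<|header_start|>system<|header_end|>\n\nYou are a helpful assistant.<|eot|>"
    "<|header_start|>user<|header_end|>\n\nHello!<|eot|>"
    "<|header_start|>assistant<|header_end|>\n\n"
)
response = "Hi there!<|eot|>"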
# copied from llama3 template # copied from llama3 template
register_template( register_template(
name="mllama", name="mllama",
...@@ -1149,7 +1171,8 @@ register_template( ...@@ -1149,7 +1171,8 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=( default_system=(
"你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n" "你是一个经过良好训练的AI助手,你的名字是Marco-o1."
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n" "当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n" "<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
), ),
...@@ -1273,6 +1296,7 @@ register_template( ...@@ -1273,6 +1296,7 @@ register_template(
format_user=StringFormatter(slots=["{{content}}\n"]), format_user=StringFormatter(slots=["{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"), mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
template_class=Llama2Template,
) )
...@@ -1285,7 +1309,9 @@ register_template( ...@@ -1285,7 +1309,9 @@ register_template(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"] slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
), ),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"), mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
template_class=Llama2Template,
) )
...@@ -1361,6 +1387,24 @@ register_template( ...@@ -1361,6 +1387,24 @@ register_template(
) )
# copied from qwen template
register_template(
name="qwen2_omni",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
),
)
# copied from qwen template # copied from qwen template
register_template( register_template(
name="qwen2_vl", name="qwen2_vl",
......
...@@ -17,7 +17,7 @@ import re ...@@ -17,7 +17,7 @@ import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Tuple, Union from typing import Any, NamedTuple, Union
from typing_extensions import override from typing_extensions import override
...@@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = ( ...@@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (
@dataclass @dataclass
class ToolUtils(ABC): class ToolUtils(ABC):
""" """Base class for tool utilities."""
Base class for tool utilities.
"""
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
r""" r"""Generate the system message describing all the available tools."""
Generates the system message describing all the available tools.
"""
... ...
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
r""" r"""Generate the assistant message including all the tool calls."""
Generates the assistant message including all the tool calls.
"""
... ...
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
r""" r"""Extract all the function calls from the assistant message.
Extracts all the function calls from the assistant message.
It should be an inverse function of `function_formatter`. It should be an inverse function of `function_formatter`.
""" """
...@@ -92,13 +85,11 @@ class ToolUtils(ABC): ...@@ -92,13 +85,11 @@ class ToolUtils(ABC):
class DefaultToolUtils(ToolUtils): class DefaultToolUtils(ToolUtils):
r""" r"""Default tool using template."""
Default tool using template.
"""
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
tool_names = [] tool_names = []
for tool in tools: for tool in tools:
...@@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils): ...@@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = "" function_text = ""
for name, arguments in functions: for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n" function_text += f"Action: {name}\nAction Input: {arguments}\n"
...@@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils): ...@@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL) regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content) action_match: list[tuple[str, str]] = re.findall(regex, content)
if not action_match: if not action_match:
return content return content
...@@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils): ...@@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils): class GLM4ToolUtils(ToolUtils):
r""" r"""GLM-4 tool using template."""
GLM-4 tool using template.
"""
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
...@@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils): ...@@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1: if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.") raise ValueError("GLM-4 does not support parallel functions.")
...@@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils): ...@@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
if "\n" not in content: if "\n" not in content:
return content return content
...@@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils): ...@@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
class Llama3ToolUtils(ToolUtils): class Llama3ToolUtils(ToolUtils):
r""" r"""Llama 3.x tool using template with `tools_in_user_message=False`.
Llama 3.x tool using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
""" """
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
date = datetime.now().strftime("%d %b %Y") date = datetime.now().strftime("%d %b %Y")
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
...@@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils): ...@@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1: if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.") raise ValueError("Llama-3 does not support parallel functions.")
...@@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils): ...@@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try: try:
tool = json.loads(content.strip()) tool = json.loads(content.strip())
except json.JSONDecodeError: except json.JSONDecodeError:
...@@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils): ...@@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
class MistralToolUtils(ToolUtils): class MistralToolUtils(ToolUtils):
r""" r"""Mistral v0.3 tool using template."""
Mistral v0.3 tool using template.
"""
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = [] wrapped_tools = []
for tool in tools: for tool in tools:
wrapped_tools.append({"type": "function", "function": tool}) wrapped_tools.append({"type": "function", "function": tool})
...@@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils): ...@@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = [] function_texts = []
for name, arguments in functions: for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}') function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
...@@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils): ...@@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try: try:
tools = json.loads(content.strip()) tools = json.loads(content.strip())
except json.JSONDecodeError: except json.JSONDecodeError:
...@@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils): ...@@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
class QwenToolUtils(ToolUtils): class QwenToolUtils(ToolUtils):
r""" r"""Qwen 2.5 tool using template."""
Qwen 2.5 tool using template.
"""
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
wrapped_tool = {"type": "function", "function": tool} wrapped_tool = {"type": "function", "function": tool}
...@@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils): ...@@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = [] function_texts = []
for name, arguments in functions: for name, arguments in functions:
function_texts.append( function_texts.append(
...@@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils): ...@@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL) regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
tool_match: List[str] = re.findall(regex, content) tool_match: list[str] = re.findall(regex, content)
if not tool_match: if not tool_match:
return content return content
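As context for the regex above, a Qwen 2.5 style tool call embedded in the assistant message looks roughly like this; the JSON payload is illustrative:

# Example string matched by the <tool_call> regex above; the body is made up.
content = '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Beijing"}}\n</tool_call>'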
......
...@@ -39,7 +39,7 @@ ...@@ -39,7 +39,7 @@
import json import json
import os import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -59,7 +59,7 @@ if TYPE_CHECKING: ...@@ -59,7 +59,7 @@ if TYPE_CHECKING:
class Evaluator: class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args) self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"] self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
...@@ -69,7 +69,7 @@ class Evaluator: ...@@ -69,7 +69,7 @@ class Evaluator:
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES] self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode() @torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]: def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
logits = self.model(**batch_input).logits logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1) lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0) word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
...@@ -88,7 +88,7 @@ class Evaluator: ...@@ -88,7 +88,7 @@ class Evaluator:
) )
with open(mapping, encoding="utf-8") as f: with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: dict[str, dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS} category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0) pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
...@@ -136,7 +136,7 @@ class Evaluator: ...@@ -136,7 +136,7 @@ class Evaluator:
pbar.close() pbar.close()
self._save_results(category_corrects, results) self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None: def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
score_info = "\n".join( score_info = "\n".join(
[ [
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}" f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple
from ..data import Role from ..data import Role
from ..extras.constants import CHOICES from ..extras.constants import CHOICES
...@@ -25,20 +24,19 @@ class EvalTemplate: ...@@ -25,20 +24,19 @@ class EvalTemplate:
choice: str choice: str
answer: str answer: str
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
r""" r"""Parse eval example.
input: a dict with keys {"question", "A", "B", "C", "D", "answer"} input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response) output: a tuple of (prompt, response).
""" """
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"] return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example( def format_example(
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
r""" r"""Convert dataset examples to messages."""
Converts dataset examples to messages.
"""
messages = [] messages = []
for k in range(len(support_set)): for k in range(len(support_set)):
prompt, response = self._parse_example(support_set[k]) prompt, response = self._parse_example(support_set[k])
...@@ -52,7 +50,7 @@ class EvalTemplate: ...@@ -52,7 +50,7 @@ class EvalTemplate:
return messages return messages
eval_templates: Dict[str, "EvalTemplate"] = {} eval_templates: dict[str, "EvalTemplate"] = {}
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import os import os
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from enum import Enum from enum import Enum
from typing import Dict, Optional from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
...@@ -106,6 +106,7 @@ class AttentionFunction(str, Enum): ...@@ -106,6 +106,7 @@ class AttentionFunction(str, Enum):
class EngineName(str, Enum): class EngineName(str, Enum):
HF = "huggingface" HF = "huggingface"
VLLM = "vllm" VLLM = "vllm"
SGLANG = "sglang"
class DownloadSource(str, Enum): class DownloadSource(str, Enum):
...@@ -122,7 +123,7 @@ class RopeScaling(str, Enum): ...@@ -122,7 +123,7 @@ class RopeScaling(str, Enum):
def register_model_group( def register_model_group(
models: Dict[str, Dict[DownloadSource, str]], models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None, template: Optional[str] = None,
multimodal: bool = False, multimodal: bool = False,
) -> None: ) -> None:
...@@ -650,11 +651,51 @@ register_model_group( ...@@ -650,11 +651,51 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-2-27b-it", DownloadSource.DEFAULT: "google/gemma-2-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
}, },
"Gemma-3-1B": {
DownloadSource.DEFAULT: "google/gemma-3-1b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt",
},
"Gemma-3-1B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-1b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
},
}, },
template="gemma", template="gemma",
) )
register_model_group(
models={
"Gemma-3-4B": {
DownloadSource.DEFAULT: "google/gemma-3-4b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt",
},
"Gemma-3-12B": {
DownloadSource.DEFAULT: "google/gemma-3-12b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-pt",
},
"Gemma-3-27B": {
DownloadSource.DEFAULT: "google/gemma-3-27b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt",
},
"Gemma-3-4B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-4b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it",
},
"Gemma-3-12B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-12b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-it",
},
"Gemma-3-27B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
},
},
template="gemma3",
multimodal=True,
)
register_model_group( register_model_group(
models={ models={
"GLM-4-9B": { "GLM-4-9B": {
...@@ -768,6 +809,17 @@ register_model_group( ...@@ -768,6 +809,17 @@ register_model_group(
) )
register_model_group(
models={
"Hunyuan-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
},
},
template="hunyuan",
)
register_model_group( register_model_group(
models={ models={
"Index-1.9B-Base": { "Index-1.9B-Base": {
...@@ -1059,6 +1111,30 @@ register_model_group( ...@@ -1059,6 +1111,30 @@ register_model_group(
) )
register_model_group(
models={
"Llama-4-Scout-17B-16E": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
},
"Llama-4-Scout-17B-16E-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
},
"Llama-4-Maverick-17B-128E": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
},
"Llama-4-Maverick-17B-128E-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
},
},
template="llama4",
multimodal=True,
)
register_model_group( register_model_group(
models={ models={
"LLaVA-1.5-7B-Chat": { "LLaVA-1.5-7B-Chat": {
...@@ -2218,6 +2294,18 @@ register_model_group( ...@@ -2218,6 +2294,18 @@ register_model_group(
) )
register_model_group(
models={
"Qwen2.5-Omni-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
}
},
template="qwen2_omni",
multimodal=True,
)
register_model_group( register_model_group(
models={ models={
"Qwen2-VL-2B": { "Qwen2-VL-2B": {
...@@ -2294,6 +2382,10 @@ register_model_group( ...@@ -2294,6 +2382,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct", DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct",
}, },
"Qwen2.5-VL-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-32B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-32B-Instruct",
},
"Qwen2.5-VL-72B-Instruct": { "Qwen2.5-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct", DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct",
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's transformers library. # This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
...@@ -26,7 +26,7 @@ import trl ...@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available from transformers.utils import is_torch_cuda_available, is_torch_npu_available
VERSION = "0.9.2" VERSION = "0.9.3.dev0"
def print_env() -> None: def print_env() -> None:
......
# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's transformers library. # This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
...@@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO ...@@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO
class LoggerHandler(logging.Handler): class LoggerHandler(logging.Handler):
r""" r"""Redirect the logging output to the logging file for LLaMA Board."""
Redirects the logging output to the logging file for LLaMA Board.
"""
def __init__(self, output_dir: str) -> None: def __init__(self, output_dir: str) -> None:
super().__init__() super().__init__()
...@@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler): ...@@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger): class _Logger(logging.Logger):
r""" r"""A logger that supports rank0 logging."""
A logger that supports rank0 logging.
"""
def info_rank0(self, *args, **kwargs) -> None: def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs) self.info(*args, **kwargs)
...@@ -82,9 +78,7 @@ class _Logger(logging.Logger): ...@@ -82,9 +78,7 @@ class _Logger(logging.Logger):
def _get_default_logging_level() -> "logging._Level": def _get_default_logging_level() -> "logging._Level":
r""" r"""Return the default logging level."""
Returns the default logging level.
"""
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None) env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str: if env_level_str:
if env_level_str.upper() in logging._nameToLevel: if env_level_str.upper() in logging._nameToLevel:
...@@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger": ...@@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":
def _configure_library_root_logger() -> None: def _configure_library_root_logger() -> None:
r""" r"""Configure root logger using a stdout stream handler with an explicit format."""
Configures root logger using a stdout stream handler with an explicit format.
"""
global _default_handler global _default_handler
with _thread_lock: with _thread_lock:
...@@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None: ...@@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:
def get_logger(name: Optional[str] = None) -> "_Logger": def get_logger(name: Optional[str] = None) -> "_Logger":
r""" r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
Returns a logger with the specified name. It it not supposed to be accessed externally.
"""
if name is None: if name is None:
name = _get_library_name() name = _get_library_name()
...@@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger": ...@@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":
def add_handler(handler: "logging.Handler") -> None: def add_handler(handler: "logging.Handler") -> None:
r""" r"""Add a handler to the root logger."""
Adds a handler to the root logger.
"""
_configure_library_root_logger() _configure_library_root_logger()
_get_library_root_logger().addHandler(handler) _get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None: def remove_handler(handler: logging.Handler) -> None:
r""" r"""Remove a handler to the root logger."""
Removes a handler to the root logger.
"""
_configure_library_root_logger() _configure_library_root_logger()
_get_library_root_logger().removeHandler(handler) _get_library_root_logger().removeHandler(handler)
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's PEFT library. # This code is inspired by the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
import gc import gc
import os import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union import socket
from typing import TYPE_CHECKING, Any, Literal, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__) ...@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
class AverageMeter: class AverageMeter:
r""" r"""Compute and store the average and current value."""
Computes and stores the average and current value.
"""
def __init__(self): def __init__(self):
self.reset() self.reset()
...@@ -75,9 +74,7 @@ class AverageMeter: ...@@ -75,9 +74,7 @@ class AverageMeter:
def check_version(requirement: str, mandatory: bool = False) -> None: def check_version(requirement: str, mandatory: bool = False) -> None:
r""" r"""Optionally check the package version."""
Optionally checks the package version.
"""
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory: if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.") logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return return
...@@ -91,22 +88,18 @@ def check_version(requirement: str, mandatory: bool = False) -> None: ...@@ -91,22 +88,18 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None: def check_dependencies() -> None:
r""" r"""Check the version of the required packages."""
Checks the version of the required packages.
"""
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0") check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.2.0") check_version("datasets>=2.16.0,<=3.4.1")
check_version("accelerate>=0.34.0,<=1.2.1") check_version("accelerate>=0.34.0,<=1.5.2")
check_version("peft>=0.11.1,<=0.12.0") check_version("peft>=0.14.0,<=0.15.0")
check_version("trl>=0.8.6,<=0.9.6") check_version("trl>=0.8.6,<=0.9.6")
if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"): if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.") logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float: def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
r""" r"""Calculate effective tokens per second."""
Calculates effective tokens per second.
"""
effective_token_num = 0 effective_token_num = 0
for data in dataset: for data in dataset:
if stage == "sft": if stage == "sft":
...@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], ...@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
return result / dist.get_world_size() if dist.is_initialized() else result return result / dist.get_world_size() if dist.is_initialized() else result
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
r""" r"""Return the number of trainable parameters and number of all parameters in the model."""
Returns the number of trainable parameters and number of all parameters in the model.
"""
trainable_params, all_param = 0, 0 trainable_params, all_param = 0, 0
for param in model.parameters(): for param in model.parameters():
num_params = param.numel() num_params = param.numel()
...@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: ...@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
def get_current_device() -> "torch.device": def get_current_device() -> "torch.device":
r""" r"""Get the current available device."""
Gets the current available device.
"""
if is_torch_xpu_available(): if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available(): elif is_torch_npu_available():
...@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device": ...@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
def get_device_count() -> int: def get_device_count() -> int:
r""" r"""Get the number of available GPU or NPU devices."""
Gets the number of available GPU or NPU devices.
"""
if is_torch_xpu_available(): if is_torch_xpu_available():
return torch.xpu.device_count() return torch.xpu.device_count()
elif is_torch_npu_available(): elif is_torch_npu_available():
...@@ -180,18 +167,14 @@ def get_device_count() -> int: ...@@ -180,18 +167,14 @@ def get_device_count() -> int:
def get_logits_processor() -> "LogitsProcessorList": def get_logits_processor() -> "LogitsProcessorList":
r""" r"""Get logits processor that removes NaN and Inf logits."""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList() logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor()) logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor return logits_processor
def get_peak_memory() -> Tuple[int, int]: def get_peak_memory() -> tuple[int, int]:
r""" r"""Get the peak memory usage for the current device (in Bytes)."""
Gets the peak memory usage for the current device (in Bytes).
"""
if is_torch_npu_available(): if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved() return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available(): elif is_torch_cuda_available():
...@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]: ...@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
def has_tokenized_data(path: "os.PathLike") -> bool: def has_tokenized_data(path: "os.PathLike") -> bool:
r""" r"""Check if the path has a tokenized dataset."""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0 return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype": def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r""" r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if _is_bf16_available and model_dtype == torch.bfloat16: if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16 return torch.bfloat16
elif _is_fp16_available: elif _is_fp16_available:
...@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype": ...@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def is_gpu_or_npu_available() -> bool: def is_gpu_or_npu_available() -> bool:
r""" r"""Check if the GPU or NPU is available."""
Checks if the GPU or NPU is available.
"""
return is_torch_npu_available() or is_torch_cuda_available() return is_torch_npu_available() or is_torch_cuda_available()
def is_env_enabled(env_var: str, default: str = "0") -> bool: def is_env_enabled(env_var: str, default: str = "0") -> bool:
r""" r"""Check if the environment variable is enabled."""
Checks if the environment variable is enabled.
"""
return os.getenv(env_var, default).lower() in ["true", "y", "1"] return os.getenv(env_var, default).lower() in ["true", "y", "1"]
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray": def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r""" r"""Cast a torch tensor or a numpy array to a numpy array."""
Casts a torch tensor or a numpy array to a numpy array.
"""
if isinstance(inputs, torch.Tensor): if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu() inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4 if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
...@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray": ...@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def skip_check_imports() -> None: def skip_check_imports() -> None:
r""" r"""Avoid flash attention import error in custom model files."""
Avoids flash attention import error in custom model files.
"""
if not is_env_enabled("FORCE_CHECK_IMPORTS"): if not is_env_enabled("FORCE_CHECK_IMPORTS"):
transformers.dynamic_module_utils.check_imports = get_relative_imports transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None: def torch_gc() -> None:
r""" r"""Collect GPU or NPU memory."""
Collects GPU or NPU memory.
"""
gc.collect() gc.collect()
if is_torch_xpu_available(): if is_torch_xpu_available():
torch.xpu.empty_cache() torch.xpu.empty_cache()
...@@ -306,3 +275,20 @@ def use_openmind() -> bool: ...@@ -306,3 +275,20 @@ def use_openmind() -> bool:
def use_ray() -> bool: def use_ray() -> bool:
return is_env_enabled("USE_RAY") return is_env_enabled("USE_RAY")
def find_available_port() -> int:
"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
return port
def fix_proxy(ipv6_enabled: bool) -> None:
"""Fix proxy settings for gradio ui."""
os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
if ipv6_enabled:
for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
os.environ.pop(name, None)
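A small usage sketch for the two helpers added above; the URL is illustrative:

# Hypothetical usage: pick a free port and sanitize proxy variables before serving a local UI.
port = find_available_port()
fix_proxy(ipv6_enabled=False)
print(f"Serving on http://127.0.0.1:{port}")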
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's transformers library. # This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
...@@ -97,3 +97,7 @@ def is_uvicorn_available(): ...@@ -97,3 +97,7 @@ def is_uvicorn_available():
def is_vllm_available(): def is_vllm_available():
return _is_package_available("vllm") return _is_package_available("vllm")
def is_sglang_available():
return _is_package_available("sglang")
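Since SGLang is an optional dependency, a hedged sketch of how such an availability check is typically consumed; the error message below is illustrative, not taken from this commit.

# Hypothetical guard around the optional backend import.
if is_sglang_available():
    import sglang  # noqa: F401  (imported only when the package is installed)
else:
    raise ImportError("SGLang backend requested but the `sglang` package is not installed.")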
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import json import json
import math import math
import os import os
from typing import Any, Dict, List from typing import Any
from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer import TRAINER_STATE_NAME
...@@ -31,10 +31,8 @@ if is_matplotlib_available(): ...@@ -31,10 +31,8 @@ if is_matplotlib_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]: def smooth(scalars: list[float]) -> list[float]:
r""" r"""EMA implementation according to TensorBoard."""
EMA implementation according to TensorBoard.
"""
if len(scalars) == 0: if len(scalars) == 0:
return [] return []
...@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]: ...@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed return smoothed
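The loop body is collapsed in this hunk; below is a minimal sketch of the TensorBoard-style EMA the docstring refers to. The sigmoid-shaped weight is an assumption, not copied from the hidden lines.

import math

def smooth_sketch(scalars: list[float]) -> list[float]:
    # Exponential moving average, similar in spirit to TensorBoard's smoothing slider.
    if len(scalars) == 0:
        return []
    last = scalars[0]
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # longer runs smooth more
    smoothed = []
    for scalar in scalars:
        last = last * weight + (1 - weight) * scalar
        smoothed.append(last)
    return smoothed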
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure": def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
r""" r"""Plot loss curves in LlamaBoard."""
Plots loss curves in LlamaBoard.
"""
plt.close("all") plt.close("all")
plt.switch_backend("agg") plt.switch_backend("agg")
fig = plt.figure() fig = plt.figure()
...@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur ...@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
return fig return fig
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None: def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
r""" r"""Plot loss curves and saves the image."""
Plots loss curves and saves the image.
"""
plt.switch_backend("agg") plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f: with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
......
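For context, a hedged usage example of plot_loss: the function reads trainer_state.json from the given save directory. The path and key names below are illustrative.

# Hypothetical call after a finished training run.
plot_loss("saves/llama3-8b/lora/sft", keys=["loss", "eval_loss"])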
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's transformers library. # This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...@@ -16,14 +16,12 @@ ...@@ -16,14 +16,12 @@
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional from typing import Any, Literal, Optional
@dataclass @dataclass
class DataArguments: class DataArguments:
r""" r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: Optional[str] = field( template: Optional[str] = field(
default=None, default=None,
...@@ -162,5 +160,5 @@ class DataArguments: ...@@ -162,5 +160,5 @@ class DataArguments:
if self.mask_history and self.train_on_prompt: if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.") raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)
...@@ -21,9 +21,7 @@ from datasets import DownloadMode ...@@ -21,9 +21,7 @@ from datasets import DownloadMode
@dataclass @dataclass
class EvaluationArguments: class EvaluationArguments:
r""" r"""Arguments pertaining to specify the evaluation parameters."""
Arguments pertaining to specify the evaluation parameters.
"""
task: str = field( task: str = field(
metadata={"help": "Name of the evaluation task."}, metadata={"help": "Name of the evaluation task."},
......
...@@ -13,14 +13,12 @@ ...@@ -13,14 +13,12 @@
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional from typing import Any, Literal, Optional
@dataclass @dataclass
class FreezeArguments: class FreezeArguments:
r""" r"""Arguments pertaining to the freeze (partial-parameter) training."""
Arguments pertaining to the freeze (partial-parameter) training.
"""
freeze_trainable_layers: int = field( freeze_trainable_layers: int = field(
default=2, default=2,
...@@ -56,9 +54,7 @@ class FreezeArguments: ...@@ -56,9 +54,7 @@ class FreezeArguments:
@dataclass @dataclass
class LoraArguments: class LoraArguments:
r""" r"""Arguments pertaining to the LoRA training."""
Arguments pertaining to the LoRA training.
"""
additional_target: Optional[str] = field( additional_target: Optional[str] = field(
default=None, default=None,
...@@ -128,9 +124,7 @@ class LoraArguments: ...@@ -128,9 +124,7 @@ class LoraArguments:
@dataclass @dataclass
class RLHFArguments: class RLHFArguments:
r""" r"""Arguments pertaining to the PPO, DPO and KTO training."""
Arguments pertaining to the PPO, DPO and KTO training.
"""
pref_beta: float = field( pref_beta: float = field(
default=0.1, default=0.1,
...@@ -212,9 +206,7 @@ class RLHFArguments: ...@@ -212,9 +206,7 @@ class RLHFArguments:
@dataclass @dataclass
class GaloreArguments: class GaloreArguments:
r""" r"""Arguments pertaining to the GaLore algorithm."""
Arguments pertaining to the GaLore algorithm.
"""
use_galore: bool = field( use_galore: bool = field(
default=False, default=False,
...@@ -253,9 +245,7 @@ class GaloreArguments: ...@@ -253,9 +245,7 @@ class GaloreArguments:
@dataclass @dataclass
class ApolloArguments: class ApolloArguments:
r""" r"""Arguments pertaining to the APOLLO algorithm."""
Arguments pertaining to the APOLLO algorithm.
"""
use_apollo: bool = field( use_apollo: bool = field(
default=False, default=False,
...@@ -306,9 +296,7 @@ class ApolloArguments: ...@@ -306,9 +296,7 @@ class ApolloArguments:
@dataclass @dataclass
class BAdamArgument: class BAdamArgument:
r""" r"""Arguments pertaining to the BAdam optimizer."""
Arguments pertaining to the BAdam optimizer.
"""
use_badam: bool = field( use_badam: bool = field(
default=False, default=False,
...@@ -387,15 +375,21 @@ class SwanLabArguments: ...@@ -387,15 +375,21 @@ class SwanLabArguments:
default=None, default=None,
metadata={"help": "The log directory for SwanLab."}, metadata={"help": "The log directory for SwanLab."},
) )
swanlab_lark_webhook_url: Optional[str] = field(
default=None,
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
)
swanlab_lark_secret: Optional[str] = field(
default=None,
metadata={"help": "The Lark(飞书) secret for SwanLab."},
)
@dataclass @dataclass
class FinetuningArguments( class FinetuningArguments(
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
): ):
r""" r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
pure_bf16: bool = field( pure_bf16: bool = field(
default=False, default=False,
...@@ -452,13 +446,13 @@ class FinetuningArguments( ...@@ -452,13 +446,13 @@ class FinetuningArguments(
return [item.strip() for item in arg.split(",")] return [item.strip() for item in arg.split(",")]
return arg return arg
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules) self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules) self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: List[str] = split_arg(self.lora_target) self.lora_target: list[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target) self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target) self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target) self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"] self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
...@@ -499,7 +493,7 @@ class FinetuningArguments( ...@@ -499,7 +493,7 @@ class FinetuningArguments(
if self.pissa_init: if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.") raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
args = asdict(self) args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()} args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args return args
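The comprehension above redacts secrets before the arguments are logged; a small standalone check of the same pattern follows. The field names are illustrative.

# Any key ending in "api_key" is replaced by a placeholder.
args = {"swanlab_api_key": "secret", "lora_rank": 8}
masked = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
print(masked)  # {'swanlab_api_key': '<SWANLAB_API_KEY>', 'lora_rank': 8}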
...@@ -13,16 +13,14 @@ ...@@ -13,16 +13,14 @@
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional from typing import Any, Optional
from transformers import GenerationConfig from transformers import GenerationConfig
@dataclass @dataclass
class GeneratingArguments: class GeneratingArguments:
r""" r"""Arguments pertaining to specify the decoding parameters."""
Arguments pertaining to specify the decoding parameters.
"""
do_sample: bool = field( do_sample: bool = field(
default=True, default=True,
...@@ -35,7 +33,9 @@ class GeneratingArguments: ...@@ -35,7 +33,9 @@ class GeneratingArguments:
top_p: float = field( top_p: float = field(
default=0.7, default=0.7,
metadata={ metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." "help": (
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
)
}, },
) )
top_k: int = field( top_k: int = field(
...@@ -71,7 +71,7 @@ class GeneratingArguments: ...@@ -71,7 +71,7 @@ class GeneratingArguments:
metadata={"help": "Whether or not to remove special tokens in the decoding."}, metadata={"help": "Whether or not to remove special tokens in the decoding."},
) )
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]: def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
args = asdict(self) args = asdict(self)
if args.get("max_new_tokens", -1) > 0: if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None) args.pop("max_length", None)
......
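A hedged sketch of the precedence rule visible above: when max_new_tokens is positive, max_length is dropped from the decoding kwargs. This assumes GeneratingArguments exposes max_new_tokens and max_length fields as in the full file; only the pop() behavior is taken from the hunk.

# Hypothetical usage of GeneratingArguments defined above.
decoding_kwargs = GeneratingArguments(max_new_tokens=512).to_dict()
assert "max_length" not in decoding_kwargs
assert decoding_kwargs["max_new_tokens"] == 512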
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's transformers library. # This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import json import json
from dataclasses import asdict, dataclass, field, fields from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Literal, Optional, Union
import torch import torch
from transformers.training_args import _convert_str_dict from transformers.training_args import _convert_str_dict
...@@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling ...@@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling
@dataclass @dataclass
class BaseModelArguments: class BaseModelArguments:
r""" r"""Arguments pertaining to the model."""
Arguments pertaining to the model.
"""
model_name_or_path: Optional[str] = field( model_name_or_path: Optional[str] = field(
default=None, default=None,
...@@ -184,9 +182,7 @@ class BaseModelArguments: ...@@ -184,9 +182,7 @@ class BaseModelArguments:
@dataclass @dataclass
class QuantizationArguments: class QuantizationArguments:
r""" r"""Arguments pertaining to the quantization method."""
Arguments pertaining to the quantization method.
"""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field( quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes", default="bitsandbytes",
...@@ -212,9 +208,7 @@ class QuantizationArguments: ...@@ -212,9 +208,7 @@ class QuantizationArguments:
@dataclass @dataclass
class ProcessorArguments: class ProcessorArguments:
r""" r"""Arguments pertaining to the image processor."""
Arguments pertaining to the image processor.
"""
image_max_pixels: int = field( image_max_pixels: int = field(
default=768 * 768, default=768 * 768,
...@@ -224,6 +218,14 @@ class ProcessorArguments: ...@@ -224,6 +218,14 @@ class ProcessorArguments:
default=32 * 32, default=32 * 32,
metadata={"help": "The minimum number of pixels of image inputs."}, metadata={"help": "The minimum number of pixels of image inputs."},
) )
image_do_pan_and_scan: bool = field(
default=False,
metadata={"help": "Use pan and scan to process image for gemma3."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
video_max_pixels: int = field( video_max_pixels: int = field(
default=256 * 256, default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."}, metadata={"help": "The maximum number of pixels of video inputs."},
...@@ -240,13 +242,22 @@ class ProcessorArguments: ...@@ -240,13 +242,22 @@ class ProcessorArguments:
default=128, default=128,
metadata={"help": "The maximum number of sampled frames for video inputs."}, metadata={"help": "The maximum number of sampled frames for video inputs."},
) )
audio_sampling_rate: int = field(
default=16000,
metadata={"help": "The sampling rate of audio inputs."},
)
def __post_init__(self):
if self.image_max_pixels < self.image_min_pixels:
raise ValueError("`image_max_pixels` cannot be smaller than `image_min_pixels`.")
if self.video_max_pixels < self.video_min_pixels:
raise ValueError("`video_max_pixels` cannot be smaller than `video_min_pixels`.")
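The new __post_init__ turns inconsistent pixel budgets into an immediate error; a small hedged check of that behavior, using ProcessorArguments as defined above:

# Hypothetical: a maximum below the minimum should fail fast.
try:
    ProcessorArguments(image_max_pixels=32 * 32, image_min_pixels=64 * 64)
except ValueError as err:
    print(err)  # `image_max_pixels` cannot be smaller than `image_min_pixels`.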
@dataclass @dataclass
class ExportArguments: class ExportArguments:
r""" r"""Arguments pertaining to the model export."""
Arguments pertaining to the model export.
"""
export_dir: Optional[str] = field( export_dir: Optional[str] = field(
default=None, default=None,
...@@ -292,16 +303,14 @@ class ExportArguments: ...@@ -292,16 +303,14 @@ class ExportArguments:
@dataclass @dataclass
class VllmArguments: class VllmArguments:
r""" r"""Arguments pertaining to the vLLM worker."""
Arguments pertaining to the vLLM worker.
"""
vllm_maxlen: int = field( vllm_maxlen: int = field(
default=4096, default=4096,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."}, metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
) )
vllm_gpu_util: float = field( vllm_gpu_util: float = field(
default=0.9, default=0.7,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."}, metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
) )
vllm_enforce_eager: bool = field( vllm_enforce_eager: bool = field(
...@@ -323,9 +332,36 @@ class VllmArguments: ...@@ -323,9 +332,36 @@ class VllmArguments:
@dataclass @dataclass
class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments): class SGLangArguments:
r""" r"""Arguments pertaining to the SGLang worker."""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
sglang_maxlen: int = field(
default=4096,
metadata={"help": "Maximum sequence (prompt + response) length of the SGLang engine."},
)
sglang_mem_fraction: float = field(
default=0.7,
metadata={"help": "The memory fraction (0-1) to be used for the SGLang engine."},
)
sglang_tp_size: int = field(
default=-1,
metadata={"help": "Tensor parallel size for the SGLang engine."},
)
sglang_config: Optional[Union[dict, str]] = field(
default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
)
def __post_init__(self):
if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
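Because sglang_config accepts either a dict or a JSON string, the __post_init__ above normalizes string input; a hedged example follows. The engine option name is illustrative and not verified against SGLang.

# Hypothetical: a JSON string is parsed into a dict during __post_init__.
args = SGLangArguments(sglang_config='{"mem_fraction_static": 0.8}')
print(args.sglang_config)  # {'mem_fraction_static': 0.8}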
@dataclass
class ModelArguments(
SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
):
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
The rightmost class will be displayed first. The rightmost class will be displayed first.
""" """
...@@ -335,7 +371,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz ...@@ -335,7 +371,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
init=False, init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."}, metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
) )
device_map: Optional[Union[str, Dict[str, Any]]] = field( device_map: Optional[Union[str, dict[str, Any]]] = field(
default=None, default=None,
init=False, init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."}, metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
...@@ -353,8 +389,10 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz ...@@ -353,8 +389,10 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
def __post_init__(self): def __post_init__(self):
BaseModelArguments.__post_init__(self) BaseModelArguments.__post_init__(self)
ProcessorArguments.__post_init__(self)
ExportArguments.__post_init__(self) ExportArguments.__post_init__(self)
VllmArguments.__post_init__(self) VllmArguments.__post_init__(self)
SGLangArguments.__post_init__(self)
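Dataclasses do not chain __post_init__ across bases automatically, which is why each mixin's hook is invoked by name above. A reduced, self-contained sketch of that pattern (class names are illustrative):

from dataclasses import dataclass

@dataclass
class _Base:
    def __post_init__(self):
        print("base checked")

@dataclass
class _Extra:
    def __post_init__(self):
        print("extra checked")

@dataclass
class _Combined(_Extra, _Base):
    def __post_init__(self):
        _Base.__post_init__(self)
        _Extra.__post_init__(self)

_Combined()  # prints "base checked" then "extra checked"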
@classmethod @classmethod
def copyfrom(cls, source: "Self", **kwargs) -> "Self": def copyfrom(cls, source: "Self", **kwargs) -> "Self":
...@@ -372,7 +410,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz ...@@ -372,7 +410,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
return result return result
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
args = asdict(self) args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()} args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args return args
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's transformers library. # This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...@@ -19,7 +19,7 @@ import json ...@@ -19,7 +19,7 @@ import json
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
import torch import torch
import transformers import transformers
...@@ -31,7 +31,7 @@ from transformers.training_args import ParallelMode ...@@ -31,7 +31,7 @@ from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from ..extras import logging from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from .data_args import DataArguments from .data_args import DataArguments
from .evaluation_args import EvaluationArguments from .evaluation_args import EvaluationArguments
...@@ -47,17 +47,15 @@ check_dependencies() ...@@ -47,17 +47,15 @@ check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]: def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
r""" r"""Get arguments from the command line or a config file."""
Gets arguments from the command line or a config file.
"""
if args is not None: if args is not None:
return args return args
...@@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[ ...@@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[
def _parse_args( def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]: ) -> tuple[Any]:
args = read_args(args) args = read_args(args)
if isinstance(args, dict): if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys) return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
...@@ -136,9 +134,12 @@ def _check_extra_dependencies( ...@@ -136,9 +134,12 @@ def _check_extra_dependencies(
if model_args.mixture_of_depths is not None: if model_args.mixture_of_depths is not None:
check_version("mixture-of-depth>=1.1.6", mandatory=True) check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == "vllm": if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.7.3") check_version("vllm>=0.4.3,<=0.8.2")
check_version("vllm", mandatory=True) check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.4")
check_version("sglang", mandatory=True)
if finetuning_args.use_galore: if finetuning_args.use_galore:
check_version("galore_torch", mandatory=True) check_version("galore_torch", mandatory=True)
...@@ -161,31 +162,31 @@ def _check_extra_dependencies( ...@@ -161,31 +162,31 @@ def _check_extra_dependencies(
check_version("rouge_chinese", mandatory=True) check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS) parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS) parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS) parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments: def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments) parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True) (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args return ray_args
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging # Setup logging
...@@ -364,9 +365,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ ...@@ -364,9 +365,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
and training_args.resume_from_checkpoint is not None and training_args.resume_from_checkpoint is not None
): ):
logger.warning_rank0( logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
training_args.resume_from_checkpoint
)
) )
# Post-process model arguments # Post-process model arguments
...@@ -382,20 +381,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ ...@@ -382,20 +381,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
# Log on each process the small summary # Log on each process the small summary
logger.info( logger.info(
"Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format( f"Process rank: {training_args.process_index}, "
training_args.process_index, f"world size: {training_args.world_size}, device: {training_args.device}, "
training_args.world_size, f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
training_args.device, f"compute dtype: {str(model_args.compute_dtype)}"
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
) )
transformers.set_seed(training_args.seed) transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging() _set_transformers_logging()
...@@ -426,7 +422,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ ...@@ -426,7 +422,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
return model_args, data_args, finetuning_args, generating_args return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging() _set_transformers_logging()
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
...@@ -10,9 +24,7 @@ from ..extras.misc import use_ray ...@@ -10,9 +24,7 @@ from ..extras.misc import use_ray
@dataclass @dataclass
class RayArguments: class RayArguments:
r""" r"""Arguments pertaining to the Ray training."""
Arguments pertaining to the Ray training.
"""
ray_run_name: Optional[str] = field( ray_run_name: Optional[str] = field(
default=None, default=None,
...@@ -43,9 +55,7 @@ class RayArguments: ...@@ -43,9 +55,7 @@ class RayArguments:
@dataclass @dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments): class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r""" r"""Arguments pertaining to the trainer."""
Arguments pertaining to the trainer.
"""
def __post_init__(self): def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self) Seq2SeqTrainingArguments.__post_init__(self)
......
...@@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params ...@@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params
__all__ = [ __all__ = [
"QuantizationMethod", "QuantizationMethod",
"find_all_linear_modules",
"load_config", "load_config",
"load_model", "load_model",
"load_tokenizer", "load_tokenizer",
"find_all_linear_modules",
"load_valuehead_params", "load_valuehead_params",
] ]
...@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING ...@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING
import torch import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras import logging from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules from .model_utils.misc import find_all_linear_modules, find_expanded_modules
...@@ -81,9 +80,8 @@ def _setup_freeze_tuning( ...@@ -81,9 +80,8 @@ def _setup_freeze_tuning(
if finetuning_args.use_llama_pro: if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0: if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError( raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format( f"`num_layers` {num_layers} should be "
num_layers, finetuning_args.freeze_trainable_layers f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
)
) )
stride = num_layers // finetuning_args.freeze_trainable_layers stride = num_layers // finetuning_args.freeze_trainable_layers
...@@ -178,7 +176,7 @@ def _setup_lora_tuning( ...@@ -178,7 +176,7 @@ def _setup_lora_tuning(
} }
for adapter in adapter_to_merge: for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs) model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload() model = model.merge_and_unload()
if len(adapter_to_merge) > 0: if len(adapter_to_merge) > 0:
...@@ -263,8 +261,7 @@ def init_adapter( ...@@ -263,8 +261,7 @@ def init_adapter(
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: bool, is_trainable: bool,
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""Initialize the adapters.
Initializes the adapters.
Support full-parameter, freeze and LoRA training. Support full-parameter, freeze and LoRA training.
...@@ -279,14 +276,14 @@ def init_adapter( ...@@ -279,14 +276,14 @@ def init_adapter(
# cast trainable parameters to float32 if: # cast trainable parameters to float32 if:
# 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora) # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
# 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32) # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
cast_trainable_params_to_fp32 = False cast_trainable_params_to_fp32 = False
if not is_trainable: if not is_trainable:
pass pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam: elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.") logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()): elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.") logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
else: else:
logger.info_rank0("Upcasting trainable params to float32.") logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True cast_trainable_params_to_fp32 = True
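For reference, a hedged sketch of what the resulting flag typically implies downstream; the actual call sites are outside this hunk and the helper below is illustrative only.

import torch

def _upcast_trainable_params(model: "torch.nn.Module") -> None:
    # Illustrative only: upcast trainable parameters to float32 for stable mixed-precision training.
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)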
......
...@@ -13,13 +13,15 @@ ...@@ -13,13 +13,15 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch import torch
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
AutoModelForTextToWaveform,
AutoModelForVision2Seq, AutoModelForVision2Seq,
AutoProcessor, AutoProcessor,
AutoTokenizer, AutoTokenizer,
...@@ -51,9 +53,8 @@ class TokenizerModule(TypedDict): ...@@ -51,9 +53,8 @@ class TokenizerModule(TypedDict):
processor: Optional["ProcessorMixin"] processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
r""" r"""Get arguments to load config/tokenizer/model.
Gets arguments to load config/tokenizer/model.
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
...@@ -68,13 +69,11 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: ...@@ -68,13 +69,11 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r""" r"""Load pretrained tokenizer and optionally loads processor.
Loads pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
try: try:
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, model_args.model_name_or_path,
...@@ -96,7 +95,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": ...@@ -96,7 +95,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
patch_tokenizer(tokenizer, model_args) patch_tokenizer(tokenizer, model_args)
try: try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args) patch_processor(processor, tokenizer, model_args)
except Exception as e: except Exception as e:
logger.debug(f"Processor was not found: {e}.") logger.debug(f"Processor was not found: {e}.")
processor = None processor = None
...@@ -110,9 +109,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": ...@@ -110,9 +109,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
def load_config(model_args: "ModelArguments") -> "PretrainedConfig": def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r""" r"""Load model config."""
Loads model config.
"""
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
...@@ -124,9 +121,7 @@ def load_model( ...@@ -124,9 +121,7 @@ def load_model(
is_trainable: bool = False, is_trainable: bool = False,
add_valuehead: bool = False, add_valuehead: bool = False,
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""Load pretrained model."""
Loads pretrained model.
"""
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args) config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
...@@ -147,10 +142,14 @@ def load_model( ...@@ -147,10 +142,14 @@ def load_model(
if model_args.mixture_of_depths == "load": if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs) model = load_mod_pretrained_model(**init_kwargs)
else: else:
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM load_class = AutoModelForSeq2SeqLM
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni
load_class = AutoModelForTextToWaveform
else: else:
load_class = AutoModelForCausalLM load_class = AutoModelForCausalLM
...@@ -158,6 +157,8 @@ def load_model( ...@@ -158,6 +157,8 @@ def load_model(
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code) model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else: else:
model = load_class.from_pretrained(**init_kwargs) model = load_class.from_pretrained(**init_kwargs)
if getattr(model.config, "model_type", None) == "qwen2_5_omni":
model = model.thinker # use part of Omni model
if model_args.mixture_of_depths == "convert": if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args) model = convert_pretrained_model_to_mod(model, config, model_args)
...@@ -194,8 +195,9 @@ def load_model( ...@@ -194,8 +195,9 @@ def load_model(
trainable_params, all_param = count_parameters(model) trainable_params, all_param = count_parameters(model)
if is_trainable: if is_trainable:
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( param_stats = (
trainable_params, all_param, 100 * trainable_params / all_param f"trainable params: {trainable_params:,} || "
f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
) )
else: else:
param_stats = f"all params: {all_param:,}" param_stats = f"all params: {all_param:,}"
......