OpenDAS / LLaMA-Factory / Commits

Commit 7ea81099, authored Apr 07, 2025 by chenych

    update llama4

parent 84987715

Changes 139 — showing 20 changed files with 485 additions and 364 deletions (+485 −364)
Changed files:

  src/llamafactory/data/template.py            +117  −73
  src/llamafactory/data/tool_utils.py           +30  −46
  src/llamafactory/eval/evaluator.py             +5   −5
  src/llamafactory/eval/template.py              +8  −10
  src/llamafactory/extras/constants.py          +94   −2
  src/llamafactory/extras/env.py                 +2   −2
  src/llamafactory/extras/logging.py             +8  −22
  src/llamafactory/extras/misc.py               +43  −57
  src/llamafactory/extras/packages.py            +5   −1
  src/llamafactory/extras/ploting.py             +7  −13
  src/llamafactory/hparams/data_args.py          +4   −6
  src/llamafactory/hparams/evaluation_args.py    +1   −3
  src/llamafactory/hparams/finetuning_args.py   +23  −29
  src/llamafactory/hparams/generating_args.py    +6   −6
  src/llamafactory/hparams/model_args.py        +61  −23
  src/llamafactory/hparams/parser.py            +27  −31
  src/llamafactory/hparams/training_args.py     +16   −6
  src/llamafactory/model/__init__.py             +1   −1
  src/llamafactory/model/adapter.py              +7  −10
  src/llamafactory/model/loader.py              +20  −18
src/llamafactory/data/template.py
View file @ 7ea81099

...
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import TYPE_CHECKING, Optional, Union

 from typing_extensions import override
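
Note: the recurring change throughout this commit swaps the `typing.Dict`/`List`/`Tuple`/`Type`/`Sequence` aliases for the builtin generics standardized by PEP 585, which are subscriptable at runtime on Python 3.9+. A minimal before/after sketch (toy function, not taken from the diff):

```python
from typing import Optional  # still needed; only the deprecated aliases go away

# Before: def head_tail(ids: List[int]) -> Tuple[int, int]: ...
# After (PEP 585 builtins, Python 3.9+):
def head_tail(ids: list[int]) -> tuple[int, int]:
    """Return the first and last element (toy example, not from the diff)."""
    return ids[0], ids[-1]
```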
...
@@ -46,8 +46,8 @@ class Template:
     format_tools: "Formatter"
     format_prefix: "Formatter"
     default_system: str
-    stop_words: List[str]
-    thought_words: Tuple[str, str]
+    stop_words: list[str]
+    thought_words: tuple[str, str]
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
...
@@ -56,13 +56,11 @@ class Template:
     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-    ) -> Tuple[List[int], List[int]]:
-        r"""
-        Returns a single pair of token ids representing prompt and response respectively.
-        """
+    ) -> tuple[list[int], list[int]]:
+        r"""Return a single pair of token ids representing prompt and response respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
...
@@ -74,36 +72,28 @@ class Template:
     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-    ) -> List[Tuple[List[int], List[int]]]:
-        r"""
-        Returns multiple pairs of token ids representing prompts and responses respectively.
-        """
+    ) -> list[tuple[list[int], list[int]]]:
+        r"""Return multiple pairs of token ids representing prompts and responses respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

-    def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extracts tool message.
-        """
+    def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract tool message."""
         return self.format_tools.extract(content)

-    def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
-        r"""
-        Returns stop token ids.
-        """
+    def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Return stop token ids."""
         stop_token_ids = {tokenizer.eos_token_id}
         for token in self.stop_words:
             stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))

         return list(stop_token_ids)

-    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
-        r"""
-        Converts elements to token ids.
-        """
+    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
+        r"""Convert elements to token ids."""
         token_ids = []
         for elem in elements:
             if isinstance(elem, str):
...
@@ -124,14 +114,14 @@ class Template:
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-    ) -> List[List[int]]:
+    ) -> list[list[int]]:
         r"""
-        Encodes formatted inputs to pairs of token ids.
+        Encode formatted inputs to pairs of token ids.
+
         Turn 0: prefix + system + query        resp
-        Turn t: query                          resp
+        Turn t: query                          resp.
         """
         system = system or self.default_system
         encoded_messages = []
@@ -161,9 +151,7 @@ class Template:
     @staticmethod
     def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
-        r"""
-        Adds or replaces eos token to the tokenizer.
-        """
+        r"""Add or replace eos token to the tokenizer."""
         is_added = tokenizer.eos_token_id is None
         num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
...
@@ -176,9 +164,7 @@ class Template:
             logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

     def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
-        r"""
-        Adds eos token and pad token to the tokenizer.
-        """
+        r"""Add eos token and pad token to the tokenizer."""
         stop_words = self.stop_words
         if self.replace_eos:
             if not stop_words:
...
@@ -204,16 +190,12 @@ class Template:
     @staticmethod
     def _jinja_escape(content: str) -> str:
-        r"""
-        Escape single quotes in content.
-        """
+        r"""Escape single quotes in content."""
         return content.replace("'", r"\'")

     @staticmethod
     def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
-        r"""
-        Converts slots to jinja template.
-        """
+        r"""Convert slots to jinja template."""
         slot_items = []
         for slot in slots:
             if isinstance(slot, str):
...
@@ -235,9 +217,7 @@ class Template:
         return " + ".join(slot_items)

     def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""
-        Returns the jinja template.
-        """
+        r"""Return the jinja template."""
         prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
         system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
         user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
...
@@ -265,9 +245,7 @@ class Template:
         return jinja_template

     def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
-        r"""
-        Replaces the jinja template in the tokenizer.
-        """
+        r"""Replace the jinja template in the tokenizer."""
         if tokenizer.chat_template is None or self.replace_jinja_template:
             try:
                 tokenizer.chat_template = self._get_jinja_template(tokenizer)
...
@@ -278,9 +256,7 @@ class Template:
     def _convert_slots_to_ollama(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
-        r"""
-        Converts slots to ollama template.
-        """
+        r"""Convert slots to ollama template."""
         slot_items = []
         for slot in slots:
             if isinstance(slot, str):
...
@@ -302,9 +278,7 @@ class Template:
         return "".join(slot_items)

     def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""
-        Returns the ollama template.
-        """
+        r"""Return the ollama template."""
         prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
         system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
         user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
...
@@ -316,8 +290,7 @@ class Template:
         )

     def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""
-        Returns the ollama modelfile.
+        r"""Return the ollama modelfile.

         TODO: support function calling.
         """
...
@@ -336,14 +309,16 @@ class Template:

 @dataclass
 class Llama2Template(Template):
+    r"""A template that fuse the system message to first user message."""
+
     @override
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: str,
         tools: str,
-    ) -> List[List[int]]:
+    ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
...
@@ -402,7 +377,7 @@ class Llama2Template(Template):
         return jinja_template


-TEMPLATES: Dict[str, "Template"] = {}
+TEMPLATES: dict[str, "Template"] = {}


 def register_template(
...
@@ -415,16 +390,15 @@ def register_template(
     format_tools: Optional["Formatter"] = None,
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
-    stop_words: Optional[Sequence[str]] = None,
-    thought_words: Optional[Tuple[str, str]] = None,
+    stop_words: Optional[list[str]] = None,
+    thought_words: Optional[tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
-    template_class: Type["Template"] = Template,
+    template_class: type["Template"] = Template,
 ) -> None:
-    r"""
-    Registers a chat template.
+    r"""Register a chat template.

     To add the following chat template:
     ```
...
@@ -472,9 +446,7 @@ def register_template(

 def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
-    r"""
-    Extracts a chat template from the tokenizer.
-    """
+    r"""Extract a chat template from the tokenizer."""

     def find_diff(short_str: str, long_str: str) -> str:
         i, j = 0, 0
...
@@ -532,9 +504,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":

 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
-    r"""
-    Gets chat template and fixes the tokenizer.
-    """
+    r"""Get chat template and fixes the tokenizer."""
     if data_args.template is None:
         if isinstance(tokenizer.chat_template, str):
             logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
...
@@ -807,15 +777,17 @@ register_template(

 register_template(
     name="default",
-    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
+    format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
-    format_system=StringFormatter(slots=["System: {{content}}\n"]),
+    format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
+    replace_jinja_template=True,
 )


 register_template(
     name="empty",
     format_assistant=StringFormatter(slots=["{{content}}"]),
+    replace_jinja_template=True,
 )
...
@@ -839,6 +811,7 @@ register_template(
     name="fewshot",
     format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
     efficient_eos=True,
+    replace_jinja_template=True,
 )
...
@@ -846,10 +819,29 @@ register_template(
     name="gemma",
     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
     format_observation=StringFormatter(
         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    template_class=Llama2Template,
+)
+
+
+# copied from gemma template
+register_template(
+    name="gemma3",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_observation=StringFormatter(
+        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<end_of_turn>"],
+    mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
+    template_class=Llama2Template,
 )
...
@@ -887,6 +879,16 @@ register_template(
 )


+register_template(
+    name="hunyuan",
+    format_user=StringFormatter(slots=["<|bos|>user\n{{content}}<|eos|>\n<|bos|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|eos|>\n"]),
+    format_system=StringFormatter(slots=["<|bos|>system\n{{content}}<|eos|>\n"]),
+    format_prefix=EmptyFormatter(slots=["<|bos|>"]),
+    stop_words=["<|eos|>"],
+)
+
+
 register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
...
@@ -966,6 +968,26 @@ register_template(
 )


+register_template(
+    name="llama4",
+    format_user=StringFormatter(
+        slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
+    format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
+    format_observation=StringFormatter(
+        slots=[
+            "<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
+        ]
+    ),
+    format_tools=ToolFormatter(tool_format="llama3"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot|>", "<|eom|>"],
+    mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
+)
+
+
 # copied from llama3 template
 register_template(
     name="mllama",
...
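
Note: a hedged sketch of how the new `llama4` template registered above might be exercised through `encode_oneturn`; the model id is illustrative and the exact call sequence is an assumption, not part of the diff:

```python
# Hedged sketch (not part of the diff): encode one turn with the new template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")  # illustrative
template = TEMPLATES["llama4"]
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)
```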
@@ -1149,7 +1171,8 @@ register_template(
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     default_system=(
-        "你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
+        "你是一个经过良好训练的AI助手,你的名字是Marco-o1."
+        "由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
         "当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
         "<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
     ),
...
@@ -1273,6 +1296,7 @@ register_template(
     format_user=StringFormatter(slots=["{{content}}\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+    template_class=Llama2Template,
 )
...
@@ -1285,7 +1309,9 @@ register_template(
     format_observation=StringFormatter(
         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+    template_class=Llama2Template,
 )
...
@@ -1361,6 +1387,24 @@ register_template(
 )


+# copied from qwen template
+register_template(
+    name="qwen2_omni",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(
+        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
+    ),
+)
+
+
 # copied from qwen template
 register_template(
     name="qwen2_vl",
...
src/llamafactory/data/tool_utils.py
View file @ 7ea81099

...
@@ -17,7 +17,7 @@ import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Dict, List, NamedTuple, Tuple, Union
+from typing import Any, NamedTuple, Union

 from typing_extensions import override
...
@@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (

 @dataclass
 class ToolUtils(ABC):
-    """
-    Base class for tool utilities.
-    """
+    """Base class for tool utilities."""

     @staticmethod
     @abstractmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
-        r"""
-        Generates the system message describing all the available tools.
-        """
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        r"""Generate the system message describing all the available tools."""
         ...

     @staticmethod
     @abstractmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
-        r"""
-        Generates the assistant message including all the tool calls.
-        """
+    def function_formatter(functions: list["FunctionCall"]) -> str:
+        r"""Generate the assistant message including all the tool calls."""
         ...

     @staticmethod
     @abstractmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extracts all the function calls from the assistant message.
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract all the function calls from the assistant message.

         It should be an inverse function of `function_formatter`.
         """
...
@@ -92,13 +85,11 @@ class ToolUtils(ABC):
 class DefaultToolUtils(ToolUtils):
-    r"""
-    Default tool using template.
-    """
+    r"""Default tool using template."""

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         tool_names = []
         for tool in tools:
...
@@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         function_text = ""
         for name, arguments in functions:
             function_text += f"Action: {name}\nAction Input: {arguments}\n"
...
@@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
-        action_match: List[Tuple[str, str]] = re.findall(regex, content)
+        action_match: list[tuple[str, str]] = re.findall(regex, content)
         if not action_match:
             return content
...
@@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
 class GLM4ToolUtils(ToolUtils):
-    r"""
-    GLM-4 tool using template.
-    """
+    r"""GLM-4 tool using template."""

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
             tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
...
@@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         if len(functions) > 1:
             raise ValueError("GLM-4 does not support parallel functions.")
...
@@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         if "\n" not in content:
             return content
...
@@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
 class Llama3ToolUtils(ToolUtils):
-    r"""
-    Llama 3.x tool using template with `tools_in_user_message=False`.
+    r"""Llama 3.x tool using template with `tools_in_user_message=False`.

     Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
     """

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         date = datetime.now().strftime("%d %b %Y")
         tool_text = ""
         for tool in tools:
...
@@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         if len(functions) > 1:
             raise ValueError("Llama-3 does not support parallel functions.")
...
@@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
             tool = json.loads(content.strip())
         except json.JSONDecodeError:
...
@@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
 class MistralToolUtils(ToolUtils):
-    r"""
-    Mistral v0.3 tool using template.
-    """
+    r"""Mistral v0.3 tool using template."""

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         wrapped_tools = []
         for tool in tools:
             wrapped_tools.append({"type": "function", "function": tool})
...
@@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         function_texts = []
         for name, arguments in functions:
             function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
...
@@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
             tools = json.loads(content.strip())
         except json.JSONDecodeError:
...
@@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
 class QwenToolUtils(ToolUtils):
-    r"""
-    Qwen 2.5 tool using template.
-    """
+    r"""Qwen 2.5 tool using template."""

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
             wrapped_tool = {"type": "function", "function": tool}
...
@@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         function_texts = []
         for name, arguments in functions:
             function_texts.append(
...
@@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
-        tool_match: List[str] = re.findall(regex, content)
+        tool_match: list[str] = re.findall(regex, content)
         if not tool_match:
             return content
...
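
Note: a quick standalone check (not part of the diff) of the default `Action:`/`Action Input:` extraction regex shown above; `re.findall` returns one `(name, arguments)` tuple per call:

```python
import re

# Standalone check (not part of the diff) of the default extraction regex.
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
content = 'Action: get_weather\nAction Input: {"city": "Paris"}\nAction: get_time\nAction Input: {"tz": "UTC"}'
print(re.findall(regex, content))
# [('get_weather', '{"city": "Paris"}'), ('get_time', '{"tz": "UTC"}')]
```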
src/llamafactory/eval/evaluator.py
View file @ 7ea81099

...
@@ -39,7 +39,7 @@
 import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import numpy as np
 import torch
...
@@ -59,7 +59,7 @@ if TYPE_CHECKING:

 class Evaluator:
-    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
...
@@ -69,7 +69,7 @@ class Evaluator:
         self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

     @torch.inference_mode()
-    def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
+    def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
         logits = self.model(**batch_input).logits
         lengths = torch.sum(batch_input["attention_mask"], dim=-1)
         word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
...
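
Note: `batch_inference` gathers the logits of the last real token in each row; with `padding_side = "right"` that position is `lengths[i] - 1`. A minimal sketch (not part of the diff):

```python
import torch

# Minimal sketch (not part of the diff): with right padding, the last real
# token of row i sits at index lengths[i] - 1.
logits = torch.randn(2, 5, 32)                       # (batch, seq_len, vocab)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
lengths = torch.sum(attention_mask, dim=-1)          # tensor([3, 5])
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
print(word_probs.shape)                              # torch.Size([2, 32])
```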
@@ -88,7 +88,7 @@ class Evaluator:
         )
         with open(mapping, encoding="utf-8") as f:
-            categorys: Dict[str, Dict[str, str]] = json.load(f)
+            categorys: dict[str, dict[str, str]] = json.load(f)

         category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
         pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
...
@@ -136,7 +136,7 @@ class Evaluator:
         pbar.close()
         self._save_results(category_corrects, results)

-    def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
+    def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
         score_info = "\n".join(
             [
                 f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
...
src/llamafactory/eval/template.py
View file @ 7ea81099

...
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Dict, List, Sequence, Tuple

 from ..data import Role
 from ..extras.constants import CHOICES
...
@@ -25,20 +24,19 @@ class EvalTemplate:
     choice: str
     answer: str

-    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
-        r"""
+    def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
+        r"""Parse eval example.
+
         input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
-        output: a tuple of (prompt, response)
+        output: a tuple of (prompt, response).
         """
         candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
         return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

     def format_example(
-        self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
-    ) -> List[Dict[str, str]]:
-        r"""
-        Converts dataset examples to messages.
-        """
+        self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str
+    ) -> list[dict[str, str]]:
+        r"""Convert dataset examples to messages."""
         messages = []
         for k in range(len(support_set)):
             prompt, response = self._parse_example(support_set[k])
...
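
Note: a worked illustration (template strings assumed, not taken from the diff) of what `_parse_example` builds; it mirrors the comprehension and join shown above:

```python
# Worked illustration (template strings assumed, not from the diff): the real
# code iterates CHOICES and uses self.choice / self.answer.
example = {"question": "2 + 2 = ?", "A": "3", "B": "4", "C": "5", "D": "6", "answer": "B"}
choice, answer = "\n{choice}. {content}", "\nAnswer:"
candidates = [choice.format(choice=ch, content=example[ch]) for ch in "ABCD" if ch in example]
prompt = "".join([example["question"]] + candidates + [answer])
# prompt == "2 + 2 = ?\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer:", response == "B"
```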
@@ -52,7 +50,7 @@ class EvalTemplate:
         return messages


-eval_templates: Dict[str, "EvalTemplate"] = {}
+eval_templates: dict[str, "EvalTemplate"] = {}


 def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
...
src/llamafactory/extras/constants.py
View file @ 7ea81099

...
@@ -15,7 +15,7 @@
 import os
 from collections import OrderedDict, defaultdict
 from enum import Enum
-from typing import Dict, Optional
+from typing import Optional

 from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
 from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
...
@@ -106,6 +106,7 @@ class AttentionFunction(str, Enum):
 class EngineName(str, Enum):
     HF = "huggingface"
     VLLM = "vllm"
+    SGLANG = "sglang"


 class DownloadSource(str, Enum):
...
@@ -122,7 +123,7 @@ class RopeScaling(str, Enum):

 def register_model_group(
-    models: Dict[str, Dict[DownloadSource, str]],
+    models: dict[str, dict[DownloadSource, str]],
     template: Optional[str] = None,
     multimodal: bool = False,
 ) -> None:
...
@@ -650,11 +651,51 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-2-27b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
         },
+        "Gemma-3-1B": {
+            DownloadSource.DEFAULT: "google/gemma-3-1b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt",
+        },
+        "Gemma-3-1B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-1b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
+        },
     },
     template="gemma",
 )


+register_model_group(
+    models={
+        "Gemma-3-4B": {
+            DownloadSource.DEFAULT: "google/gemma-3-4b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt",
+        },
+        "Gemma-3-12B": {
+            DownloadSource.DEFAULT: "google/gemma-3-12b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-pt",
+        },
+        "Gemma-3-27B": {
+            DownloadSource.DEFAULT: "google/gemma-3-27b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt",
+        },
+        "Gemma-3-4B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-4b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it",
+        },
+        "Gemma-3-12B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-12b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-it",
+        },
+        "Gemma-3-27B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-27b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
+        },
+    },
+    template="gemma3",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "GLM-4-9B": {
...
@@ -768,6 +809,17 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Hunyuan-7B-Instruct": {
+            DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
+        },
+    },
+    template="hunyuan",
+)
+
+
 register_model_group(
     models={
         "Index-1.9B-Base": {
...
@@ -1059,6 +1111,30 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Llama-4-Scout-17B-16E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
+        },
+        "Llama-4-Scout-17B-16E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
+        },
+        "Llama-4-Maverick-17B-128E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
+        },
+        "Llama-4-Maverick-17B-128E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
+        },
+    },
+    template="llama4",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {
...
@@ -2218,6 +2294,18 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Qwen2.5-Omni-7B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
+        }
+    },
+    template="qwen2_omni",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Qwen2-VL-2B": {
...
@@ -2294,6 +2382,10 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct",
         },
+        "Qwen2.5-VL-32B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-32B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-32B-Instruct",
+        },
         "Qwen2.5-VL-72B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct",
...
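
Note: the hub-selection logic is not part of this diff; elsewhere in the repo it is keyed off environment flags such as `USE_MODELSCOPE_HUB`. A hedged sketch of how a registered entry might be resolved:

```python
import os

# Hedged sketch (not part of the diff): resolve a registered model entry to a
# concrete repo id; the real selection logic lives elsewhere in the repo.
def resolve_model_path(model: dict) -> str:
    if os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "y", "1"]:
        return model.get(DownloadSource.MODELSCOPE, model[DownloadSource.DEFAULT])
    return model[DownloadSource.DEFAULT]
```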
src/llamafactory/extras/env.py
View file @ 7ea81099

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
...
@@ -26,7 +26,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available


-VERSION = "0.9.2"
+VERSION = "0.9.3.dev0"


 def print_env() -> None:
...
src/llamafactory/extras/logging.py
View file @ 7ea81099

-# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
...
@@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO
 class LoggerHandler(logging.Handler):
-    r"""
-    Redirects the logging output to the logging file for LLaMA Board.
-    """
+    r"""Redirect the logging output to the logging file for LLaMA Board."""

     def __init__(self, output_dir: str) -> None:
         super().__init__()
...
@@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):
 class _Logger(logging.Logger):
-    r"""
-    A logger that supports rank0 logging.
-    """
+    r"""A logger that supports rank0 logging."""

     def info_rank0(self, *args, **kwargs) -> None:
         self.info(*args, **kwargs)
...
@@ -82,9 +78,7 @@ class _Logger(logging.Logger):
 def _get_default_logging_level() -> "logging._Level":
-    r"""
-    Returns the default logging level.
-    """
+    r"""Return the default logging level."""
     env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
     if env_level_str:
         if env_level_str.upper() in logging._nameToLevel:
...
@@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":
 def _configure_library_root_logger() -> None:
-    r"""
-    Configures root logger using a stdout stream handler with an explicit format.
-    """
+    r"""Configure root logger using a stdout stream handler with an explicit format."""
     global _default_handler

     with _thread_lock:
...
@@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:
 def get_logger(name: Optional[str] = None) -> "_Logger":
-    r"""
-    Returns a logger with the specified name. It it not supposed to be accessed externally.
-    """
+    r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
     if name is None:
         name = _get_library_name()
...
@@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":
 def add_handler(handler: "logging.Handler") -> None:
-    r"""
-    Adds a handler to the root logger.
-    """
+    r"""Add a handler to the root logger."""
     _configure_library_root_logger()
     _get_library_root_logger().addHandler(handler)


 def remove_handler(handler: logging.Handler) -> None:
-    r"""
-    Removes a handler to the root logger.
-    """
+    r"""Remove a handler to the root logger."""
     _configure_library_root_logger()
     _get_library_root_logger().removeHandler(handler)
...
src/llamafactory/extras/misc.py
View file @ 7ea81099

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's PEFT library.
 # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
...
@@ -17,7 +17,8 @@
 import gc
 import os
-from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
+import socket
+from typing import TYPE_CHECKING, Any, Literal, Union

 import torch
 import torch.distributed as dist
...
@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
 class AverageMeter:
-    r"""
-    Computes and stores the average and current value.
-    """
+    r"""Compute and store the average and current value."""

     def __init__(self):
         self.reset()
...
@@ -75,9 +74,7 @@ class AverageMeter:
 def check_version(requirement: str, mandatory: bool = False) -> None:
-    r"""
-    Optionally checks the package version.
-    """
+    r"""Optionally check the package version."""
     if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return
...
@@ -91,22 +88,18 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
-    r"""
-    Checks the version of the required packages.
-    """
-    check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.2.0")
-    check_version("accelerate>=0.34.0,<=1.2.1")
-    check_version("peft>=0.11.1,<=0.12.0")
+    r"""Check the version of the required packages."""
+    check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("datasets>=2.16.0,<=3.4.1")
+    check_version("accelerate>=0.34.0,<=1.5.2")
+    check_version("peft>=0.14.0,<=0.15.0")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


-def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
-    r"""
-    Calculates effective tokens per second.
-    """
+def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
+    r"""Calculate effective tokens per second."""
     effective_token_num = 0
     for data in dataset:
         if stage == "sft":
...
@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
     return result / dist.get_world_size() if dist.is_initialized() else result


-def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
-    r"""
-    Returns the number of trainable parameters and number of all parameters in the model.
-    """
+def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
+    r"""Return the number of trainable parameters and number of all parameters in the model."""
     trainable_params, all_param = 0, 0
     for param in model.parameters():
         num_params = param.numel()
...
@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
 def get_current_device() -> "torch.device":
-    r"""
-    Gets the current available device.
-    """
+    r"""Get the current available device."""
     if is_torch_xpu_available():
         device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
     elif is_torch_npu_available():
...
@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
 def get_device_count() -> int:
-    r"""
-    Gets the number of available GPU or NPU devices.
-    """
+    r"""Get the number of available GPU or NPU devices."""
     if is_torch_xpu_available():
         return torch.xpu.device_count()
     elif is_torch_npu_available():
...
@@ -180,18 +167,14 @@ def get_device_count() -> int:
 def get_logits_processor() -> "LogitsProcessorList":
-    r"""
-    Gets logits processor that removes NaN and Inf logits.
-    """
+    r"""Get logits processor that removes NaN and Inf logits."""
     logits_processor = LogitsProcessorList()
     logits_processor.append(InfNanRemoveLogitsProcessor())
     return logits_processor


-def get_peak_memory() -> Tuple[int, int]:
-    r"""
-    Gets the peak memory usage for the current device (in Bytes).
-    """
+def get_peak_memory() -> tuple[int, int]:
+    r"""Get the peak memory usage for the current device (in Bytes)."""
     if is_torch_npu_available():
         return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
     elif is_torch_cuda_available():
...
@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
 def has_tokenized_data(path: "os.PathLike") -> bool:
-    r"""
-    Checks if the path has a tokenized dataset.
-    """
+    r"""Check if the path has a tokenized dataset."""
     return os.path.isdir(path) and len(os.listdir(path)) > 0


 def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
-    r"""
-    Infers the optimal dtype according to the model_dtype and device compatibility.
-    """
+    r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
     if _is_bf16_available and model_dtype == torch.bfloat16:
         return torch.bfloat16
     elif _is_fp16_available:
...
@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
 def is_gpu_or_npu_available() -> bool:
-    r"""
-    Checks if the GPU or NPU is available.
-    """
+    r"""Check if the GPU or NPU is available."""
     return is_torch_npu_available() or is_torch_cuda_available()


 def is_env_enabled(env_var: str, default: str = "0") -> bool:
-    r"""
-    Checks if the environment variable is enabled.
-    """
+    r"""Check if the environment variable is enabled."""
     return os.getenv(env_var, default).lower() in ["true", "y", "1"]


 def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
-    r"""
-    Casts a torch tensor or a numpy array to a numpy array.
-    """
+    r"""Cast a torch tensor or a numpy array to a numpy array."""
     if isinstance(inputs, torch.Tensor):
         inputs = inputs.cpu()
         if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4
...
@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
 def skip_check_imports() -> None:
-    r"""
-    Avoids flash attention import error in custom model files.
-    """
+    r"""Avoid flash attention import error in custom model files."""
     if not is_env_enabled("FORCE_CHECK_IMPORTS"):
         transformers.dynamic_module_utils.check_imports = get_relative_imports


 def torch_gc() -> None:
-    r"""
-    Collects GPU or NPU memory.
-    """
+    r"""Collect GPU or NPU memory."""
     gc.collect()
     if is_torch_xpu_available():
         torch.xpu.empty_cache()
...
@@ -306,3 +275,20 @@ def use_openmind() -> bool:

 def use_ray() -> bool:
     return is_env_enabled("USE_RAY")
+
+
+def find_available_port() -> int:
+    """Find an available port on the local machine."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.bind(("", 0))
+    port = sock.getsockname()[1]
+    sock.close()
+    return port
+
+
+def fix_proxy(ipv6_enabled: bool) -> None:
+    """Fix proxy settings for gradio ui."""
+    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
+    if ipv6_enabled:
+        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
+            os.environ.pop(name, None)
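
Note: a usage sketch (not part of the diff) for the two helpers added above; `find_available_port` binds port 0 so the OS picks a free port, and `fix_proxy` scrubs proxy variables before the gradio UI starts:

```python
# Usage sketch (not part of the diff): helper names are from the code above.
port = find_available_port()      # OS-assigned free TCP port
fix_proxy(ipv6_enabled=True)      # clears http(s)_proxy and sets no_proxy
print(f"launching LLaMA Board on 0.0.0.0:{port}")
```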
src/llamafactory/extras/packages.py
View file @ 7ea81099

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
...
@@ -97,3 +97,7 @@ def is_uvicorn_available():

 def is_vllm_available():
     return _is_package_available("vllm")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
src/llamafactory/extras/ploting.py
View file @ 7ea81099

...
@@ -15,7 +15,7 @@
 import json
 import math
 import os
-from typing import Any, Dict, List
+from typing import Any

 from transformers.trainer import TRAINER_STATE_NAME
...
@@ -31,10 +31,8 @@ if is_matplotlib_available():
 logger = logging.get_logger(__name__)


-def smooth(scalars: List[float]) -> List[float]:
-    r"""
-    EMA implementation according to TensorBoard.
-    """
+def smooth(scalars: list[float]) -> list[float]:
+    r"""EMA implementation according to TensorBoard."""
     if len(scalars) == 0:
         return []
...
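
Note: the body of `smooth` is collapsed in this diff; a TensorBoard-style EMA typically looks like the following sketch (the weighting scheme is an assumption, not the diff's code):

```python
# Hedged sketch of a TensorBoard-style EMA; the actual body of `smooth` is
# collapsed in this diff, so the weighting below is an assumption.
def smooth_sketch(scalars: list[float], weight: float = 0.9) -> list[float]:
    smoothed, last = [], scalars[0] if scalars else 0.0
    for point in scalars:
        last = last * weight + (1 - weight) * point  # exponential moving average
        smoothed.append(last)
    return smoothed
```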
@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
     return smoothed


-def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
-    r"""
-    Plots loss curves in LlamaBoard.
-    """
+def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
+    r"""Plot loss curves in LlamaBoard."""
     plt.close("all")
     plt.switch_backend("agg")
     fig = plt.figure()
...
@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
     return fig


-def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
-    r"""
-    Plots loss curves and saves the image.
-    """
+def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
+    r"""Plot loss curves and saves the image."""
     plt.switch_backend("agg")
     with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
         data = json.load(f)
...
src/llamafactory/hparams/data_args.py
View file @ 7ea81099

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
@@ -16,14 +16,12 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Literal, Optional


 @dataclass
 class DataArguments:
-    r"""
-    Arguments pertaining to what data we are going to input our model for training and evaluation.
-    """
+    r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""

     template: Optional[str] = field(
         default=None,
...
@@ -162,5 +160,5 @@ class DataArguments:
         if self.mask_history and self.train_on_prompt:
             raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         return asdict(self)
src/llamafactory/hparams/evaluation_args.py
View file @ 7ea81099

...
@@ -21,9 +21,7 @@ from datasets import DownloadMode

 @dataclass
 class EvaluationArguments:
-    r"""
-    Arguments pertaining to specify the evaluation parameters.
-    """
+    r"""Arguments pertaining to specify the evaluation parameters."""

     task: str = field(
         metadata={"help": "Name of the evaluation task."},
...
src/llamafactory/hparams/finetuning_args.py
View file @ 7ea81099

...
@@ -13,14 +13,12 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Literal, Optional


 @dataclass
 class FreezeArguments:
-    r"""
-    Arguments pertaining to the freeze (partial-parameter) training.
-    """
+    r"""Arguments pertaining to the freeze (partial-parameter) training."""

     freeze_trainable_layers: int = field(
         default=2,
...
@@ -56,9 +54,7 @@ class FreezeArguments:

 @dataclass
 class LoraArguments:
-    r"""
-    Arguments pertaining to the LoRA training.
-    """
+    r"""Arguments pertaining to the LoRA training."""

     additional_target: Optional[str] = field(
         default=None,
...
@@ -128,9 +124,7 @@ class LoraArguments:

 @dataclass
 class RLHFArguments:
-    r"""
-    Arguments pertaining to the PPO, DPO and KTO training.
-    """
+    r"""Arguments pertaining to the PPO, DPO and KTO training."""

     pref_beta: float = field(
         default=0.1,
...
@@ -212,9 +206,7 @@ class RLHFArguments:

 @dataclass
 class GaloreArguments:
-    r"""
-    Arguments pertaining to the GaLore algorithm.
-    """
+    r"""Arguments pertaining to the GaLore algorithm."""

     use_galore: bool = field(
         default=False,
...
@@ -253,9 +245,7 @@ class GaloreArguments:

 @dataclass
 class ApolloArguments:
-    r"""
-    Arguments pertaining to the APOLLO algorithm.
-    """
+    r"""Arguments pertaining to the APOLLO algorithm."""

     use_apollo: bool = field(
         default=False,
...
@@ -306,9 +296,7 @@ class ApolloArguments:

 @dataclass
 class BAdamArgument:
-    r"""
-    Arguments pertaining to the BAdam optimizer.
-    """
+    r"""Arguments pertaining to the BAdam optimizer."""

     use_badam: bool = field(
         default=False,
...
@@ -387,15 +375,21 @@ class SwanLabArguments:
         default=None,
         metadata={"help": "The log directory for SwanLab."},
     )
+    swanlab_lark_webhook_url: Optional[str] = field(
+        default=None,
+        metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
+    )
+    swanlab_lark_secret: Optional[str] = field(
+        default=None,
+        metadata={"help": "The Lark(飞书) secret for SwanLab."},
+    )


 @dataclass
 class FinetuningArguments(
     SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
 ):
-    r"""
-    Arguments pertaining to which techniques we are going to fine-tuning with.
-    """
+    r"""Arguments pertaining to which techniques we are going to fine-tuning with."""

     pure_bf16: bool = field(
         default=False,
...
@@ -452,13 +446,13 @@ class FinetuningArguments(
                 return [item.strip() for item in arg.split(",")]
             return arg

-        self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
-        self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+        self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
+        self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
-        self.lora_target: List[str] = split_arg(self.lora_target)
-        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
-        self.galore_target: List[str] = split_arg(self.galore_target)
-        self.apollo_target: List[str] = split_arg(self.apollo_target)
+        self.lora_target: list[str] = split_arg(self.lora_target)
+        self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
+        self.galore_target: list[str] = split_arg(self.galore_target)
+        self.apollo_target: list[str] = split_arg(self.apollo_target)
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
...
@@ -499,7 +493,7 @@ class FinetuningArguments(
         if self.pissa_init:
             raise ValueError("`pissa_init` is only valid for LoRA training.")

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         args = asdict(self)
         args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
         return args
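
Note: the `to_dict` above masks any field ending in `api_key` before the arguments are logged. A standalone illustration of that comprehension (sample values assumed):

```python
# Standalone illustration (sample values assumed, not from the diff).
args = {"learning_rate": 5e-5, "swanlab_api_key": "secret"}
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
print(args)  # {'learning_rate': 5e-05, 'swanlab_api_key': '<SWANLAB_API_KEY>'}
```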
src/llamafactory/hparams/generating_args.py
View file @ 7ea81099

...
@@ -13,16 +13,14 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from transformers import GenerationConfig


 @dataclass
 class GeneratingArguments:
-    r"""
-    Arguments pertaining to specify the decoding parameters.
-    """
+    r"""Arguments pertaining to specify the decoding parameters."""

     do_sample: bool = field(
         default=True,
...
@@ -35,7 +33,9 @@ class GeneratingArguments:
     top_p: float = field(
         default=0.7,
         metadata={
-            "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
+            "help": (
+                "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
+            )
         },
     )
     top_k: int = field(
...
@@ -71,7 +71,7 @@ class GeneratingArguments:
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
     )

-    def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
+    def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:
             args.pop("max_length", None)
...
src/llamafactory/hparams/model_args.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
...
@@ -17,7 +17,7 @@
import
json
from
dataclasses
import
asdict
,
dataclass
,
field
,
fields
from
typing
import
Any
,
Dict
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Literal
,
Optional
,
Union
import
torch
from
transformers.training_args
import
_convert_str_dict
...
...
@@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling
@
dataclass
class
BaseModelArguments
:
r
"""
Arguments pertaining to the model.
"""
r
"""Arguments pertaining to the model."""
model_name_or_path
:
Optional
[
str
]
=
field
(
default
=
None
,
...
...
@@ -184,9 +182,7 @@ class BaseModelArguments:
@
dataclass
class
QuantizationArguments
:
r
"""
Arguments pertaining to the quantization method.
"""
r
"""Arguments pertaining to the quantization method."""
quantization_method
:
Literal
[
"bitsandbytes"
,
"hqq"
,
"eetq"
]
=
field
(
default
=
"bitsandbytes"
,
...
...
@@ -212,9 +208,7 @@ class QuantizationArguments:
@
dataclass
class
ProcessorArguments
:
r
"""
Arguments pertaining to the image processor.
"""
r
"""Arguments pertaining to the image processor."""
image_max_pixels
:
int
=
field
(
default
=
768
*
768
,
...
...
@@ -224,6 +218,14 @@ class ProcessorArguments:
default
=
32
*
32
,
metadata
=
{
"help"
:
"The minimum number of pixels of image inputs."
},
)
image_do_pan_and_scan
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Use pan and scan to process image for gemma3."
},
)
use_audio_in_video
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to use audio in video inputs."
},
)
video_max_pixels
:
int
=
field
(
default
=
256
*
256
,
metadata
=
{
"help"
:
"The maximum number of pixels of video inputs."
},
...
...
@@ -240,13 +242,22 @@ class ProcessorArguments:
default
=
128
,
metadata
=
{
"help"
:
"The maximum number of sampled frames for video inputs."
},
)
audio_sampling_rate
:
int
=
field
(
default
=
16000
,
metadata
=
{
"help"
:
"The sampling rate of audio inputs."
},
)
def
__post_init__
(
self
):
if
self
.
image_max_pixels
<
self
.
image_min_pixels
:
raise
ValueError
(
"`image_max_pixels` cannot be smaller than `image_min_pixels`."
)
if
self
.
video_max_pixels
<
self
.
video_min_pixels
:
raise
ValueError
(
"`video_max_pixels` cannot be smaller than `video_min_pixels`."
)
@
dataclass
class
ExportArguments
:
r
"""
Arguments pertaining to the model export.
"""
r
"""Arguments pertaining to the model export."""
export_dir
:
Optional
[
str
]
=
field
(
default
=
None
,
...
...
@@ -292,16 +303,14 @@ class ExportArguments:
@
dataclass
class
VllmArguments
:
r
"""
Arguments pertaining to the vLLM worker.
"""
r
"""Arguments pertaining to the vLLM worker."""
vllm_maxlen
:
int
=
field
(
default
=
4096
,
metadata
=
{
"help"
:
"Maximum sequence (prompt + response) length of the vLLM engine."
},
)
vllm_gpu_util
:
float
=
field
(
default
=
0.
9
,
default
=
0.
7
,
metadata
=
{
"help"
:
"The fraction of GPU memory in (0,1) to be used for the vLLM engine."
},
)
vllm_enforce_eager
:
bool
=
field
(
...
...
@@ -323,9 +332,36 @@ class VllmArguments:
 @dataclass
-class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
-    r"""
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
+class SGLangArguments:
+    r"""Arguments pertaining to the SGLang worker."""
+
+    sglang_maxlen: int = field(
+        default=4096,
+        metadata={"help": "Maximum sequence (prompt + response) length of the SGLang engine."},
+    )
+    sglang_mem_fraction: float = field(
+        default=0.7,
+        metadata={"help": "The memory fraction (0-1) to be used for the SGLang engine."},
+    )
+    sglang_tp_size: int = field(
+        default=-1,
+        metadata={"help": "Tensor parallel size for the SGLang engine."},
+    )
+    sglang_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
+    )
+
+    def __post_init__(self):
+        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
+            self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
+
+
+@dataclass
+class ModelArguments(SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
+    r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.

     The class on the most right will be displayed first.
     """
...
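Because `SGLangArguments.__post_init__` feeds a JSON string through `json.loads` and `_convert_str_dict`, `sglang_config` can be supplied as a string on the command line and still arrive as a dict. A minimal sketch (the config key is illustrative, not a documented SGLang option):

    from llamafactory.hparams.model_args import SGLangArguments

    args = SGLangArguments(sglang_config='{"attention_backend": "flashinfer"}')
    print(type(args.sglang_config))  # <class 'dict'>

Note also the base-class order of `ModelArguments`: per the docstring, the rightmost base's fields are displayed first, so `SGLangArguments` is prepended on the left.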
@@ -335,7 +371,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
         init=False,
         metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
     )
-    device_map: Optional[Union[str, Dict[str, Any]]] = field(
+    device_map: Optional[Union[str, dict[str, Any]]] = field(
         default=None,
         init=False,
         metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
...
@@ -353,8 +389,10 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
     def __post_init__(self):
         BaseModelArguments.__post_init__(self)
+        ProcessorArguments.__post_init__(self)
         ExportArguments.__post_init__(self)
         VllmArguments.__post_init__(self)
+        SGLangArguments.__post_init__(self)

     @classmethod
     def copyfrom(cls, source: "Self", **kwargs) -> "Self":
...
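The explicit chaining is necessary because `dataclasses` resolves `__post_init__` through the MRO and runs only the first match; without these calls the parents' hooks would be silently skipped. A self-contained sketch of the pattern:

    from dataclasses import dataclass

    @dataclass
    class Base:
        def __post_init__(self):
            print("Base checks")

    @dataclass
    class Child(Base):
        def __post_init__(self):
            Base.__post_init__(self)  # without this line, Base's hook never runs
            print("Child checks")

    Child()  # prints "Base checks" then "Child checks"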
@@ -372,7 +410,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
         return result

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         args = asdict(self)
         args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
         return args
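The comprehension in `to_dict` scrubs credentials: any field whose name ends in `token` (e.g. `hf_hub_token`) is replaced by a placeholder before the arguments are logged. The same transformation in isolation:

    args = {"model_name_or_path": "meta-llama/Llama-2-7b-hf", "hf_hub_token": "hf_secret"}
    masked = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
    print(masked["hf_hub_token"])  # <HF_HUB_TOKEN>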
src/llamafactory/hparams/parser.py View file @ 7ea81099
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
@@ -19,7 +19,7 @@ import json
 import os
 import sys
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

 import torch
 import transformers
...
@@ -31,7 +31,7 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available

 from ..extras import logging
-from ..extras.constants import CHECKPOINT_NAMES
+from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
...
@@ -47,17 +47,15 @@ check_dependencies()
 _TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
-_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
+_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
-_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
+_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


-def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
-    r"""
-    Gets arguments from the command line or a config file.
-    """
+def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
+    r"""Get arguments from the command line or a config file."""
     if args is not None:
         return args
...
@@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[
 def _parse_args(
-    parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
-) -> Tuple[Any]:
+    parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
+) -> tuple[Any]:
     args = read_args(args)
     if isinstance(args, dict):
         return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
...
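`_parse_args` therefore accepts either a dict (routed to `HfArgumentParser.parse_dict`, as shown above) or an argv-style list handled by the regular command-line parsing path. A minimal sketch of the dict path with a toy dataclass:

    from dataclasses import dataclass, field

    from transformers import HfArgumentParser

    @dataclass
    class DemoArgs:
        learning_rate: float = field(default=5e-5)

    parser = HfArgumentParser([DemoArgs])
    (demo_args,) = parser.parse_dict({"learning_rate": 1e-4}, allow_extra_keys=False)
    print(demo_args.learning_rate)  # 0.0001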
@@ -136,9 +134,12 @@ def _check_extra_dependencies(
     if model_args.mixture_of_depths is not None:
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

-    if model_args.infer_backend == "vllm":
-        check_version("vllm>=0.4.3,<=0.7.3")
+    if model_args.infer_backend == EngineName.VLLM:
+        check_version("vllm>=0.4.3,<=0.8.2")
         check_version("vllm", mandatory=True)
+    elif model_args.infer_backend == EngineName.SGLANG:
+        check_version("sglang>=0.4.4")
+        check_version("sglang", mandatory=True)

     if finetuning_args.use_galore:
         check_version("galore_torch", mandatory=True)
...
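`check_version` is the project's own helper from `..extras.misc`; the shape of such a dependency gate can be approximated with the standard library alone. An illustrative sketch, not the project's implementation:

    from importlib.metadata import PackageNotFoundError, version

    def require_backend(name: str) -> None:
        try:
            version(name)  # raises if the distribution is not installed
        except PackageNotFoundError:
            raise RuntimeError(f"Please install `{name}` to use this inference backend.") from None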
@@ -161,31 +162,31 @@ def _check_extra_dependencies(
         check_version("rouge_chinese", mandatory=True)


-def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


-def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
+def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


-def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


-def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
+def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
     parser = HfArgumentParser(RayArguments)
     (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
     return ray_args


-def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
+def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
     model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

     # Setup logging
...
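For orientation, `get_train_args` takes either an argv-style list or a plain dict; a sketch of the dict form (the keys shown are only illustrative, and a real run needs a complete training config, so a partial one like this will raise during validation):

    from llamafactory.hparams.parser import get_train_args

    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(
        {"model_name_or_path": "meta-llama/Llama-2-7b-hf", "output_dir": "saves/demo", "do_train": True}
    )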
@@ -364,9 +365,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
         and training_args.resume_from_checkpoint is not None
     ):
         logger.warning_rank0(
-            "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
-                training_args.resume_from_checkpoint
-            )
+            f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
         )

     # Post-process model arguments
...
@@ -382,20 +381,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
     # Log on each process the small summary
     logger.info(
-        "Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
-            training_args.process_index,
-            training_args.world_size,
-            training_args.device,
-            training_args.parallel_mode == ParallelMode.DISTRIBUTED,
-            str(model_args.compute_dtype),
-        )
+        f"Process rank: {training_args.process_index}, "
+        f"world size: {training_args.world_size}, device: {training_args.device}, "
+        f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
+        f"compute dtype: {str(model_args.compute_dtype)}"
     )
     transformers.set_seed(training_args.seed)

     return model_args, data_args, training_args, finetuning_args, generating_args


-def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
+def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
     _set_transformers_logging()
...
@@ -426,7 +422,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
     return model_args, data_args, finetuning_args, generating_args


-def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
+def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
     _set_transformers_logging()
...
src/llamafactory/hparams/training_args.py View file @ 7ea81099
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
...
@@ -10,9 +24,7 @@ from ..extras.misc import use_ray
 @dataclass
 class RayArguments:
-    r"""
-    Arguments pertaining to the Ray training.
-    """
+    r"""Arguments pertaining to the Ray training."""

     ray_run_name: Optional[str] = field(
         default=None,
...
@@ -43,9 +55,7 @@ class RayArguments:
 @dataclass
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
-    r"""
-    Arguments pertaining to the trainer.
-    """
+    r"""Arguments pertaining to the trainer."""

     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
...
src/llamafactory/model/__init__.py View file @ 7ea81099
...
@@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params
 __all__ = [
     "QuantizationMethod",
+    "find_all_linear_modules",
     "load_config",
     "load_model",
     "load_tokenizer",
-    "find_all_linear_modules",
     "load_valuehead_params",
 ]
src/llamafactory/model/adapter.py View file @ 7ea81099
...
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import is_fsdp_enabled

 from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
...
@@ -81,9 +80,8 @@ def _setup_freeze_tuning(
     if finetuning_args.use_llama_pro:
         if num_layers % finetuning_args.freeze_trainable_layers != 0:
             raise ValueError(
-                "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
-                    num_layers, finetuning_args.freeze_trainable_layers
-                )
+                f"`num_layers` {num_layers} should be "
+                f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
             )

         stride = num_layers // finetuning_args.freeze_trainable_layers
...
@@ -178,7 +176,7 @@ def _setup_lora_tuning(
         }

         for adapter in adapter_to_merge:
-            model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
+            model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
             model = model.merge_and_unload()

         if len(adapter_to_merge) > 0:
...
@@ -263,8 +261,7 @@ def init_adapter(
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
 ) -> "PreTrainedModel":
-    r"""
-    Initializes the adapters.
+    r"""Initialize the adapters.

     Support full-parameter, freeze and LoRA training.
...
@@ -279,14 +276,14 @@ def init_adapter(
     # cast trainable parameters to float32 if:
     # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
-    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
     cast_trainable_params_to_fp32 = False
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
         logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
-    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+    elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
+        logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
     else:
         logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
...
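Downstream, the `cast_trainable_params_to_fp32` flag results in upcasting only the parameters that require gradients, leaving frozen (possibly quantized) weights untouched. A sketch of that effect, assuming a plain `torch.nn.Module` (the actual cast happens in the model patching utilities, not in this function):

    import torch

    def upcast_trainable_params(model: torch.nn.Module) -> None:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(torch.float32)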
src/llamafactory/model/loader.py View file @ 7ea81099
...
@@ -13,13 +13,15 @@
 # limitations under the License.

 import os
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
+from typing import TYPE_CHECKING, Any, Optional, TypedDict

 import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
+    AutoModelForTextToWaveform,
     AutoModelForVision2Seq,
     AutoProcessor,
     AutoTokenizer,
...
@@ -51,9 +53,8 @@ class TokenizerModule(TypedDict):
     processor: Optional["ProcessorMixin"]


-def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
-    r"""
-    Gets arguments to load config/tokenizer/model.
+def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
+    r"""Get arguments to load config/tokenizer/model.

     Note: including inplace operation of model_args.
     """
...
@@ -68,13 +69,11 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
 def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
-    r"""
-    Loads pretrained tokenizer and optionally loads processor.
+    r"""Load pretrained tokenizer and optionally loads processor.

     Note: including inplace operation of model_args.
     """
     init_kwargs = _get_init_kwargs(model_args)
-    config = load_config(model_args)
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
...
@@ -96,7 +95,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     patch_tokenizer(tokenizer, model_args)
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        patch_processor(processor, config, tokenizer, model_args)
+        patch_processor(processor, tokenizer, model_args)
     except Exception as e:
         logger.debug(f"Processor was not found: {e}.")
         processor = None
...
@@ -110,9 +109,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
 def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
-    r"""
-    Loads model config.
-    """
+    r"""Load model config."""
     init_kwargs = _get_init_kwargs(model_args)
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
...
@@ -124,9 +121,7 @@ def load_model(
     is_trainable: bool = False,
     add_valuehead: bool = False,
 ) -> "PreTrainedModel":
-    r"""
-    Loads pretrained model.
-    """
+    r"""Load pretrained model."""
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
...
@@ -147,10 +142,14 @@ def load_model(
         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
         else:
-            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
+            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
                 load_class = AutoModelForVision2Seq
-            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
+            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+                load_class = AutoModelForImageTextToText
+            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
+            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
+                load_class = AutoModelForTextToWaveform
             else:
                 load_class = AutoModelForCausalLM
...
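The dispatch leans on the `transformers` auto-class registries: each config class is registered with at most one concrete model class per task, so membership in `_model_mapping` picks the right loader. This is the branch that matters for this commit, since llama4 configs resolve through the image-text-to-text mapping. A sketch (the checkpoint id is illustrative and gated on the Hub):

    from transformers import AutoConfig, AutoModelForImageTextToText

    config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
    if type(config) in AutoModelForImageTextToText._model_mapping.keys():
        load_class = AutoModelForImageTextToText  # llama4 resolves here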
@@ -158,6 +157,8 @@ def load_model(
             model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
         else:
             model = load_class.from_pretrained(**init_kwargs)

+        if getattr(model.config, "model_type", None) == "qwen2_5_omni":
+            model = model.thinker  # use part of Omni model

     if model_args.mixture_of_depths == "convert":
         model = convert_pretrained_model_to_mod(model, config, model_args)
...
@@ -194,8 +195,9 @@ def load_model(
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
-        param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
-            trainable_params, all_param, 100 * trainable_params / all_param
+        param_stats = (
+            f"trainable params: {trainable_params:,} || "
+            f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
         )
     else:
         param_stats = f"all params: {all_param:,}"
...
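For reference, the trainable/total counts in this summary come from `count_parameters` in `..extras.misc`; a simplified sketch of the computation (the real helper additionally corrects for low-bit quantized weights):

    def count_parameters(model) -> tuple[int, int]:
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        return trainable, total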