OpenDAS / LLaMA-Factory · Commits

Commit 7ea81099, authored Apr 07, 2025 by chenych

    update llama4

parent 84987715

Showing 20 changed files with 485 additions and 364 deletions (+485, -364)
src/llamafactory/data/template.py            +117  -73
src/llamafactory/data/tool_utils.py           +30  -46
src/llamafactory/eval/evaluator.py             +5   -5
src/llamafactory/eval/template.py              +8  -10
src/llamafactory/extras/constants.py          +94   -2
src/llamafactory/extras/env.py                 +2   -2
src/llamafactory/extras/logging.py             +8  -22
src/llamafactory/extras/misc.py               +43  -57
src/llamafactory/extras/packages.py            +5   -1
src/llamafactory/extras/ploting.py             +7  -13
src/llamafactory/hparams/data_args.py          +4   -6
src/llamafactory/hparams/evaluation_args.py    +1   -3
src/llamafactory/hparams/finetuning_args.py   +23  -29
src/llamafactory/hparams/generating_args.py    +6   -6
src/llamafactory/hparams/model_args.py        +61  -23
src/llamafactory/hparams/parser.py            +27  -31
src/llamafactory/hparams/training_args.py     +16   -6
src/llamafactory/model/__init__.py             +1   -1
src/llamafactory/model/adapter.py              +7  -10
src/llamafactory/model/loader.py              +20  -18
src/llamafactory/data/template.py  (+117 -73)

@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import TYPE_CHECKING, Optional, Union

 from typing_extensions import override
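Note: most of the churn in this commit outside the new templates is a mechanical migration from the `typing` aliases (`Dict`, `List`, `Tuple`, `Type`, `Sequence`) to the PEP 585 builtin generics, which are subscriptable on Python 3.9+. A minimal sketch of the pattern, with invented names rather than lines from this diff:

```python
# Before: subscripted generics required the typing aliases.
from typing import Dict, List, Tuple

def encode(messages: List[Dict[str, str]]) -> Tuple[List[int], List[int]]: ...

# After: builtin types are generic themselves (Python >= 3.9), so only
# Optional/Union/TYPE_CHECKING still need to come from typing.
def encode(messages: list[dict[str, str]]) -> tuple[list[int], list[int]]: ...
```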
@@ -46,8 +46,8 @@ class Template:
     format_tools: "Formatter"
     format_prefix: "Formatter"
     default_system: str
-    stop_words: List[str]
-    thought_words: Tuple[str, str]
+    stop_words: list[str]
+    thought_words: tuple[str, str]
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
@@ -56,13 +56,11 @@ class Template:
     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-    ) -> Tuple[List[int], List[int]]:
-        r"""
-        Returns a single pair of token ids representing prompt and response respectively.
-        """
+    ) -> tuple[list[int], list[int]]:
+        r"""Return a single pair of token ids representing prompt and response respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
@@ -74,36 +72,28 @@ class Template:
     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-    ) -> List[Tuple[List[int], List[int]]]:
-        r"""
-        Returns multiple pairs of token ids representing prompts and responses respectively.
-        """
+    ) -> list[tuple[list[int], list[int]]]:
+        r"""Return multiple pairs of token ids representing prompts and responses respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

-    def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extracts tool message.
-        """
+    def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract tool message."""
         return self.format_tools.extract(content)

-    def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
-        r"""
-        Returns stop token ids.
-        """
+    def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Return stop token ids."""
         stop_token_ids = {tokenizer.eos_token_id}
         for token in self.stop_words:
             stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))

         return list(stop_token_ids)

-    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
-        r"""
-        Converts elements to token ids.
-        """
+    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
+        r"""Convert elements to token ids."""
         token_ids = []
         for elem in elements:
             if isinstance(elem, str):
@@ -124,14 +114,14 @@ class Template:
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-    ) -> List[List[int]]:
-        r"""
-        Encodes formatted inputs to pairs of token ids.
-        Turn 0: prefix + system + query        resp
-        Turn t: query                          resp
-        """
+    ) -> list[list[int]]:
+        r"""Encode formatted inputs to pairs of token ids.
+
+        Turn 0: prefix + system + query        resp
+        Turn t: query                          resp.
+        """
         system = system or self.default_system
         encoded_messages = []
@@ -161,9 +151,7 @@ class Template:
     @staticmethod
     def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
-        r"""
-        Adds or replaces eos token to the tokenizer.
-        """
+        r"""Add or replace eos token to the tokenizer."""
         is_added = tokenizer.eos_token_id is None
         num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
@@ -176,9 +164,7 @@ class Template:
             logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

     def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
-        r"""
-        Adds eos token and pad token to the tokenizer.
-        """
+        r"""Add eos token and pad token to the tokenizer."""
         stop_words = self.stop_words
         if self.replace_eos:
             if not stop_words:
@@ -204,16 +190,12 @@ class Template:
     @staticmethod
     def _jinja_escape(content: str) -> str:
-        r"""
-        Escape single quotes in content.
-        """
+        r"""Escape single quotes in content."""
         return content.replace("'", r"\'")

     @staticmethod
     def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
-        r"""
-        Converts slots to jinja template.
-        """
+        r"""Convert slots to jinja template."""
         slot_items = []
         for slot in slots:
             if isinstance(slot, str):
@@ -235,9 +217,7 @@ class Template:
         return " + ".join(slot_items)

     def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""
-        Returns the jinja template.
-        """
+        r"""Return the jinja template."""
         prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
         system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
         user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
@@ -265,9 +245,7 @@ class Template:
         return jinja_template

     def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
-        r"""
-        Replaces the jinja template in the tokenizer.
-        """
+        r"""Replace the jinja template in the tokenizer."""
         if tokenizer.chat_template is None or self.replace_jinja_template:
             try:
                 tokenizer.chat_template = self._get_jinja_template(tokenizer)
@@ -278,9 +256,7 @@ class Template:
     def _convert_slots_to_ollama(
         slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
     ) -> str:
-        r"""
-        Converts slots to ollama template.
-        """
+        r"""Convert slots to ollama template."""
         slot_items = []
         for slot in slots:
             if isinstance(slot, str):
@@ -302,9 +278,7 @@ class Template:
         return "".join(slot_items)

     def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""
-        Returns the ollama template.
-        """
+        r"""Return the ollama template."""
         prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
         system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
         user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
@@ -316,8 +290,7 @@ class Template:
         )

     def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""
-        Returns the ollama modelfile.
+        r"""Return the ollama modelfile.

         TODO: support function calling.
         """
@@ -336,14 +309,16 @@ class Template:
 @dataclass
 class Llama2Template(Template):
+    r"""A template that fuse the system message to first user message."""
+
     @override
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: str,
         tools: str,
-    ) -> List[List[int]]:
+    ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
@@ -402,7 +377,7 @@ class Llama2Template(Template):
         return jinja_template


-TEMPLATES: Dict[str, "Template"] = {}
+TEMPLATES: dict[str, "Template"] = {}


 def register_template(
@@ -415,16 +390,15 @@ def register_template(
     format_tools: Optional["Formatter"] = None,
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
-    stop_words: Optional[Sequence[str]] = None,
-    thought_words: Optional[Tuple[str, str]] = None,
+    stop_words: Optional[list[str]] = None,
+    thought_words: Optional[tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
-    template_class: Type["Template"] = Template,
+    template_class: type["Template"] = Template,
 ) -> None:
-    r"""
-    Registers a chat template.
+    r"""Register a chat template.

     To add the following chat template:
     ```
@@ -472,9 +446,7 @@ def register_template(
 def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
-    r"""
-    Extracts a chat template from the tokenizer.
-    """
+    r"""Extract a chat template from the tokenizer."""

     def find_diff(short_str: str, long_str: str) -> str:
         i, j = 0, 0
@@ -532,9 +504,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
-    r"""
-    Gets chat template and fixes the tokenizer.
-    """
+    r"""Get chat template and fixes the tokenizer."""
     if data_args.template is None:
         if isinstance(tokenizer.chat_template, str):
             logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
@@ -807,15 +777,17 @@ register_template(
 register_template(
     name="default",
-    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
+    format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
-    format_system=StringFormatter(slots=["System: {{content}}\n"]),
+    format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
+    replace_jinja_template=True,
 )


 register_template(
     name="empty",
     format_assistant=StringFormatter(slots=["{{content}}"]),
+    replace_jinja_template=True,
 )
@@ -839,6 +811,7 @@ register_template(
     name="fewshot",
     format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
     efficient_eos=True,
+    replace_jinja_template=True,
 )
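For orientation, this is roughly what the updated "default" template renders for a one-turn exchange. The slot lists above now insert the tokenizer's eos token after the system and user segments too, not only after the assistant reply; the concrete eos spelling `</s>` below is an assumption for illustration:

```python
# Hypothetical rendering of the "default" template after this change,
# assuming tokenizer.eos_token == "</s>".
prompt = (
    "System: You are a helpful assistant.</s>\n"   # format_system slots
    "Human: What is 1 + 1?</s>\nAssistant:"        # format_user slots
)
response = "2</s>\n"                               # format_assistant slots
```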
@@ -846,10 +819,29 @@ register_template(
 register_template(
     name="gemma",
     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
     format_observation=StringFormatter(
         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    template_class=Llama2Template,
+)
+
+
+# copied from gemma template
+register_template(
+    name="gemma3",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_observation=StringFormatter(
+        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<end_of_turn>"],
+    mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
+    template_class=Llama2Template,
 )
@@ -887,6 +879,16 @@ register_template(
 )


+register_template(
+    name="hunyuan",
+    format_user=StringFormatter(slots=["<|bos|>user\n{{content}}<|eos|>\n<|bos|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|eos|>\n"]),
+    format_system=StringFormatter(slots=["<|bos|>system\n{{content}}<|eos|>\n"]),
+    format_prefix=EmptyFormatter(slots=["<|bos|>"]),
+    stop_words=["<|eos|>"],
+)
+
+
 register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
@@ -966,6 +968,26 @@ register_template(
 )


+register_template(
+    name="llama4",
+    format_user=StringFormatter(
+        slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
+    format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
+    format_observation=StringFormatter(
+        slots=["<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="llama3"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot|>", "<|eom|>"],
+    mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
+)
+
+
 # copied from llama3 template
 register_template(
     name="mllama",
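Substituting the slots of the new "llama4" template by hand, one user turn should render roughly as below. The header and `<|eot|>` markers come from the formatter slots in the hunk; the bos spelling `<|begin_of_text|>` is an assumption about the Llama 4 tokenizer, not something this diff states:

```python
# Hypothetical one-turn rendering of the "llama4" template.
prompt = (
    "<|begin_of_text|>"                       # format_prefix: {"bos_token"} (assumed spelling)
    "<|header_start|>system<|header_end|>\n\n"
    "You are a helpful assistant.<|eot|>"     # format_system
    "<|header_start|>user<|header_end|>\n\n"
    "Hello!<|eot|>"
    "<|header_start|>assistant<|header_end|>\n\n"
)
response = "Hi there!<|eot|>"                 # format_assistant; <|eot|> and <|eom|> are stop words
```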
@@ -1149,7 +1171,8 @@ register_template(
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     default_system=(
-        "你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
+        "你是一个经过良好训练的AI助手,你的名字是Marco-o1."
+        "由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
         "当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
         "<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
     ),
@@ -1273,6 +1296,7 @@ register_template(
     format_user=StringFormatter(slots=["{{content}}\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+    template_class=Llama2Template,
 )
@@ -1285,7 +1309,9 @@ register_template(
     format_observation=StringFormatter(
         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<end_of_turn>"],
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+    template_class=Llama2Template,
 )
@@ -1361,6 +1387,24 @@ register_template(
 )


+# copied from qwen template
+register_template(
+    name="qwen2_omni",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"),
+)
+
+
 # copied from qwen template
 register_template(
     name="qwen2_vl",
src/llamafactory/data/tool_utils.py  (+30 -46)

@@ -17,7 +17,7 @@ import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Dict, List, NamedTuple, Tuple, Union
+from typing import Any, NamedTuple, Union

 from typing_extensions import override
@@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (
 @dataclass
 class ToolUtils(ABC):
-    """
-    Base class for tool utilities.
-    """
+    """Base class for tool utilities."""

     @staticmethod
     @abstractmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
-        r"""
-        Generates the system message describing all the available tools.
-        """
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        r"""Generate the system message describing all the available tools."""
         ...

     @staticmethod
     @abstractmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
-        r"""
-        Generates the assistant message including all the tool calls.
-        """
+    def function_formatter(functions: list["FunctionCall"]) -> str:
+        r"""Generate the assistant message including all the tool calls."""
         ...

     @staticmethod
     @abstractmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extracts all the function calls from the assistant message.
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract all the function calls from the assistant message.

         It should be an inverse function of `function_formatter`.
         """
         ...
@@ -92,13 +85,11 @@ class ToolUtils(ABC):
 class DefaultToolUtils(ToolUtils):
-    r"""
-    Default tool using template.
-    """
+    r"""Default tool using template."""

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         tool_names = []
         for tool in tools:
@@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         function_text = ""
         for name, arguments in functions:
             function_text += f"Action: {name}\nAction Input: {arguments}\n"
@@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
-        action_match: List[Tuple[str, str]] = re.findall(regex, content)
+        action_match: list[tuple[str, str]] = re.findall(regex, content)
         if not action_match:
             return content
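The `Action:`/`Action Input:` regex above is easiest to read against a concrete assistant message. A quick standalone check of what `re.findall` returns (the message itself is invented for illustration):

```python
import re

regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
content = 'Action: get_weather\nAction Input: {"city": "Paris"}\nAction: get_time\nAction Input: {"tz": "CET"}'
print(re.findall(regex, content))
# [('get_weather', '{"city": "Paris"}'), ('get_time', '{"tz": "CET"}')]
```

The lazy `(.+?)` plus the lookahead on the next `Action:` (or end of string) is what lets several sequential calls be peeled off one message.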
@@ -161,13 +152,11 @@ class DefaultToolUtils(ToolUtils):
 class GLM4ToolUtils(ToolUtils):
-    r"""
-    GLM-4 tool using template.
-    """
+    r"""GLM-4 tool using template."""

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
             tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
@@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         if len(functions) > 1:
             raise ValueError("GLM-4 does not support parallel functions.")
@@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         if "\n" not in content:
             return content
@@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
 class Llama3ToolUtils(ToolUtils):
-    r"""
-    Llama 3.x tool using template with `tools_in_user_message=False`.
+    r"""Llama 3.x tool using template with `tools_in_user_message=False`.

     Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
     """

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         date = datetime.now().strftime("%d %b %Y")
         tool_text = ""
         for tool in tools:
@@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         if len(functions) > 1:
             raise ValueError("Llama-3 does not support parallel functions.")
@@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
             tool = json.loads(content.strip())
         except json.JSONDecodeError:
@@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
 class MistralToolUtils(ToolUtils):
-    r"""
-    Mistral v0.3 tool using template.
-    """
+    r"""Mistral v0.3 tool using template."""

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         wrapped_tools = []
         for tool in tools:
             wrapped_tools.append({"type": "function", "function": tool})
@@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         function_texts = []
         for name, arguments in functions:
             function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
@@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
             tools = json.loads(content.strip())
         except json.JSONDecodeError:
@@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
 class QwenToolUtils(ToolUtils):
-    r"""
-    Qwen 2.5 tool using template.
-    """
+    r"""Qwen 2.5 tool using template."""

     @override
     @staticmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
             wrapped_tool = {"type": "function", "function": tool}
@@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> str:
+    def function_formatter(functions: list["FunctionCall"]) -> str:
         function_texts = []
         for name, arguments in functions:
             function_texts.append(
@@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
-        tool_match: List[str] = re.findall(regex, content)
+        tool_match: list[str] = re.findall(regex, content)
         if not tool_match:
             return content
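Unlike the default template's `Action:` lines, the Qwen extractor pulls JSON payloads out of `<tool_call>` blocks. A short illustration of the regex; the sample message and the follow-up `json.loads` step are my own sketch of how the matches would likely be consumed, not code from this file:

```python
import json
import re

regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
content = '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>'
for payload in re.findall(regex, content):
    call = json.loads(payload.strip())      # one function call per <tool_call> block
    print(call["name"], call["arguments"])  # get_weather {'city': 'Paris'}
```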
src/llamafactory/eval/evaluator.py  (+5 -5)

@@ -39,7 +39,7 @@
 import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import numpy as np
 import torch
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
 class Evaluator:
-    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
@@ -69,7 +69,7 @@ class Evaluator:
         self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

     @torch.inference_mode()
-    def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
+    def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
         logits = self.model(**batch_input).logits
         lengths = torch.sum(batch_input["attention_mask"], dim=-1)
         word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@@ -88,7 +88,7 @@ class Evaluator:
         )
         with open(mapping, encoding="utf-8") as f:
-            categorys: Dict[str, Dict[str, str]] = json.load(f)
+            categorys: dict[str, dict[str, str]] = json.load(f)

         category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
         pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
@@ -136,7 +136,7 @@ class Evaluator:
         pbar.close()
         self._save_results(category_corrects, results)

-    def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
+    def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
         score_info = "\n".join(
             [
                 f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
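`batch_inference` scores multiple-choice questions by reading the logits at each sequence's last non-padded position (hence the right padding noted above) and comparing them only over the A/B/C/D token ids. A standalone sketch of the same trick, on toy tensors rather than the evaluator's real inputs:

```python
import torch

# Toy batch: 2 sequences, 5 positions, vocab of 10; the choice ids are invented.
logits = torch.randn(2, 5, 10)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
choice_inputs = [3, 5, 7, 9]  # hypothetical token ids of "A", "B", "C", "D"

lengths = attention_mask.sum(dim=-1)  # index of the last real token + 1, per row
last_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
preds = last_logits[:, choice_inputs].argmax(dim=-1)  # 0..3 maps to A..D
print(["ABCD"[p] for p in preds])
```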
src/llamafactory/eval/template.py  (+8 -10)

@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Dict, List, Sequence, Tuple

 from ..data import Role
 from ..extras.constants import CHOICES
@@ -25,20 +24,19 @@ class EvalTemplate:
     choice: str
     answer: str

-    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
-        r"""
-        input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
-        output: a tuple of (prompt, response)
-        """
+    def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
+        r"""Parse eval example.
+
+        input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
+        output: a tuple of (prompt, response).
+        """
         candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
         return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

     def format_example(
-        self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
-    ) -> List[Dict[str, str]]:
-        r"""
-        Converts dataset examples to messages.
-        """
+        self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str
+    ) -> list[dict[str, str]]:
+        r"""Convert dataset examples to messages."""
         messages = []
         for k in range(len(support_set)):
             prompt, response = self._parse_example(support_set[k])
@@ -52,7 +50,7 @@ class EvalTemplate:
         return messages


-eval_templates: Dict[str, "EvalTemplate"] = {}
+eval_templates: dict[str, "EvalTemplate"] = {}


 def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
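Concretely, `_parse_example` flattens one MMLU-style record into a (prompt, response) pair. A sketch of the result; the exact `choice` and `answer` strings come from the registered eval templates, which are not shown in this hunk, so the "\nA. {content}" / "\nAnswer:" shapes below are assumptions:

```python
example = {
    "question": "What is the capital of France?",
    "A": "Berlin", "B": "Paris", "C": "Rome", "D": "Madrid",
    "answer": "B",
}
# Assumed template strings: choice = "\n{choice}. {content}", answer = "\nAnswer:".
prompt = (
    "What is the capital of France?"
    "\nA. Berlin\nB. Paris\nC. Rome\nD. Madrid"
    "\nAnswer:"
)
response = "B"  # == example["answer"]
```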
src/llamafactory/extras/constants.py  (+94 -2)

@@ -15,7 +15,7 @@
 import os
 from collections import OrderedDict, defaultdict
 from enum import Enum
-from typing import Dict, Optional
+from typing import Optional

 from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
 from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -106,6 +106,7 @@ class AttentionFunction(str, Enum):
 class EngineName(str, Enum):
     HF = "huggingface"
     VLLM = "vllm"
+    SGLANG = "sglang"


 class DownloadSource(str, Enum):
@@ -122,7 +123,7 @@ class RopeScaling(str, Enum):
 def register_model_group(
-    models: Dict[str, Dict[DownloadSource, str]],
+    models: dict[str, dict[DownloadSource, str]],
     template: Optional[str] = None,
     multimodal: bool = False,
 ) -> None:
@@ -650,11 +651,51 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-2-27b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
         },
+        "Gemma-3-1B": {
+            DownloadSource.DEFAULT: "google/gemma-3-1b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt",
+        },
+        "Gemma-3-1B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-1b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
+        },
     },
     template="gemma",
 )


+register_model_group(
+    models={
+        "Gemma-3-4B": {
+            DownloadSource.DEFAULT: "google/gemma-3-4b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt",
+        },
+        "Gemma-3-12B": {
+            DownloadSource.DEFAULT: "google/gemma-3-12b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-pt",
+        },
+        "Gemma-3-27B": {
+            DownloadSource.DEFAULT: "google/gemma-3-27b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt",
+        },
+        "Gemma-3-4B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-4b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it",
+        },
+        "Gemma-3-12B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-12b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-it",
+        },
+        "Gemma-3-27B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-27b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
+        },
+    },
+    template="gemma3",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "GLM-4-9B": {
@@ -768,6 +809,17 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Hunyuan-7B-Instruct": {
+            DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
+        },
+    },
+    template="hunyuan",
+)
+
+
 register_model_group(
     models={
         "Index-1.9B-Base": {
@@ -1059,6 +1111,30 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Llama-4-Scout-17B-16E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
+        },
+        "Llama-4-Scout-17B-16E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
+        },
+        "Llama-4-Maverick-17B-128E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
+        },
+        "Llama-4-Maverick-17B-128E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
+        },
+    },
+    template="llama4",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {
@@ -2218,6 +2294,18 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Qwen2.5-Omni-7B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
+        }
+    },
+    template="qwen2_omni",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Qwen2-VL-2B": {
@@ -2294,6 +2382,10 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct",
         },
+        "Qwen2.5-VL-32B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-32B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-32B-Instruct",
+        },
         "Qwen2.5-VL-72B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct",
src/llamafactory/extras/env.py  (+2 -2)

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
@@ -26,7 +26,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available


-VERSION = "0.9.2"
+VERSION = "0.9.3.dev0"


 def print_env() -> None:
src/llamafactory/extras/logging.py  (+8 -22)

-# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
@@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO
 class LoggerHandler(logging.Handler):
-    r"""
-    Redirects the logging output to the logging file for LLaMA Board.
-    """
+    r"""Redirect the logging output to the logging file for LLaMA Board."""

     def __init__(self, output_dir: str) -> None:
         super().__init__()
@@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):
 class _Logger(logging.Logger):
-    r"""
-    A logger that supports rank0 logging.
-    """
+    r"""A logger that supports rank0 logging."""

     def info_rank0(self, *args, **kwargs) -> None:
         self.info(*args, **kwargs)
@@ -82,9 +78,7 @@ class _Logger(logging.Logger):
 def _get_default_logging_level() -> "logging._Level":
-    r"""
-    Returns the default logging level.
-    """
+    r"""Return the default logging level."""
     env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
     if env_level_str:
         if env_level_str.upper() in logging._nameToLevel:
@@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":
 def _configure_library_root_logger() -> None:
-    r"""
-    Configures root logger using a stdout stream handler with an explicit format.
-    """
+    r"""Configure root logger using a stdout stream handler with an explicit format."""
     global _default_handler
     with _thread_lock:
@@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:
 def get_logger(name: Optional[str] = None) -> "_Logger":
-    r"""
-    Returns a logger with the specified name. It it not supposed to be accessed externally.
-    """
+    r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
     if name is None:
         name = _get_library_name()
@@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":
 def add_handler(handler: "logging.Handler") -> None:
-    r"""
-    Adds a handler to the root logger.
-    """
+    r"""Add a handler to the root logger."""
     _configure_library_root_logger()
     _get_library_root_logger().addHandler(handler)


 def remove_handler(handler: logging.Handler) -> None:
-    r"""
-    Removes a handler to the root logger.
-    """
+    r"""Remove a handler to the root logger."""
     _configure_library_root_logger()
     _get_library_root_logger().removeHandler(handler)
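As the `_get_default_logging_level` hunk shows, the library's verbosity is driven by the `LLAMAFACTORY_VERBOSITY` environment variable, mapped through `logging._nameToLevel`. A small sketch of that lookup in isolation (the invocation in the comment is an assumed example):

```python
# e.g. LLAMAFACTORY_VERBOSITY=DEBUG llamafactory-cli train ...   (assumed invocation)
import logging
import os

env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str and env_level_str.upper() in logging._nameToLevel:
    level = logging._nameToLevel[env_level_str.upper()]
else:
    level = logging.INFO  # the module's default
print(logging.getLevelName(level))  # "DEBUG" when the variable is set as above
```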
src/llamafactory/extras/misc.py  (+43 -57)

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's PEFT library.
 # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
@@ -17,7 +17,8 @@
 import gc
 import os
-from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
+import socket
+from typing import TYPE_CHECKING, Any, Literal, Union

 import torch
 import torch.distributed as dist
@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
 class AverageMeter:
-    r"""
-    Computes and stores the average and current value.
-    """
+    r"""Compute and store the average and current value."""

     def __init__(self):
         self.reset()
@@ -75,9 +74,7 @@ class AverageMeter:
 def check_version(requirement: str, mandatory: bool = False) -> None:
-    r"""
-    Optionally checks the package version.
-    """
+    r"""Optionally check the package version."""
     if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return
@@ -91,22 +88,18 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
-    r"""
-    Checks the version of the required packages.
-    """
-    check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.2.0")
-    check_version("accelerate>=0.34.0,<=1.2.1")
-    check_version("peft>=0.11.1,<=0.12.0")
+    r"""Check the version of the required packages."""
+    check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("datasets>=2.16.0,<=3.4.1")
+    check_version("accelerate>=0.34.0,<=1.5.2")
+    check_version("peft>=0.14.0,<=0.15.0")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


-def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
-    r"""
-    Calculates effective tokens per second.
-    """
+def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
+    r"""Calculate effective tokens per second."""
     effective_token_num = 0
     for data in dataset:
         if stage == "sft":
@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
     return result / dist.get_world_size() if dist.is_initialized() else result


-def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
-    r"""
-    Returns the number of trainable parameters and number of all parameters in the model.
-    """
+def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
+    r"""Return the number of trainable parameters and number of all parameters in the model."""
     trainable_params, all_param = 0, 0
     for param in model.parameters():
         num_params = param.numel()
@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
 def get_current_device() -> "torch.device":
-    r"""
-    Gets the current available device.
-    """
+    r"""Get the current available device."""
     if is_torch_xpu_available():
         device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
     elif is_torch_npu_available():
@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
 def get_device_count() -> int:
-    r"""
-    Gets the number of available GPU or NPU devices.
-    """
+    r"""Get the number of available GPU or NPU devices."""
     if is_torch_xpu_available():
         return torch.xpu.device_count()
     elif is_torch_npu_available():
@@ -180,18 +167,14 @@ def get_device_count() -> int:
 def get_logits_processor() -> "LogitsProcessorList":
-    r"""
-    Gets logits processor that removes NaN and Inf logits.
-    """
+    r"""Get logits processor that removes NaN and Inf logits."""
     logits_processor = LogitsProcessorList()
     logits_processor.append(InfNanRemoveLogitsProcessor())
     return logits_processor


-def get_peak_memory() -> Tuple[int, int]:
-    r"""
-    Gets the peak memory usage for the current device (in Bytes).
-    """
+def get_peak_memory() -> tuple[int, int]:
+    r"""Get the peak memory usage for the current device (in Bytes)."""
     if is_torch_npu_available():
         return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
     elif is_torch_cuda_available():
@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
 def has_tokenized_data(path: "os.PathLike") -> bool:
-    r"""
-    Checks if the path has a tokenized dataset.
-    """
+    r"""Check if the path has a tokenized dataset."""
     return os.path.isdir(path) and len(os.listdir(path)) > 0


 def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
-    r"""
-    Infers the optimal dtype according to the model_dtype and device compatibility.
-    """
+    r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
     if _is_bf16_available and model_dtype == torch.bfloat16:
         return torch.bfloat16
     elif _is_fp16_available:
@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
...
@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def
is_gpu_or_npu_available
()
->
bool
:
def
is_gpu_or_npu_available
()
->
bool
:
r
"""
r
"""Check if the GPU or NPU is available."""
Checks if the GPU or NPU is available.
"""
return
is_torch_npu_available
()
or
is_torch_cuda_available
()
return
is_torch_npu_available
()
or
is_torch_cuda_available
()
def
is_env_enabled
(
env_var
:
str
,
default
:
str
=
"0"
)
->
bool
:
def
is_env_enabled
(
env_var
:
str
,
default
:
str
=
"0"
)
->
bool
:
r
"""
r
"""Check if the environment variable is enabled."""
Checks if the environment variable is enabled.
"""
return
os
.
getenv
(
env_var
,
default
).
lower
()
in
[
"true"
,
"y"
,
"1"
]
return
os
.
getenv
(
env_var
,
default
).
lower
()
in
[
"true"
,
"y"
,
"1"
]
def
numpify
(
inputs
:
Union
[
"NDArray"
,
"torch.Tensor"
])
->
"NDArray"
:
def
numpify
(
inputs
:
Union
[
"NDArray"
,
"torch.Tensor"
])
->
"NDArray"
:
r
"""
r
"""Cast a torch tensor or a numpy array to a numpy array."""
Casts a torch tensor or a numpy array to a numpy array.
"""
if
isinstance
(
inputs
,
torch
.
Tensor
):
if
isinstance
(
inputs
,
torch
.
Tensor
):
inputs
=
inputs
.
cpu
()
inputs
=
inputs
.
cpu
()
if
inputs
.
dtype
==
torch
.
bfloat16
:
# numpy does not support bfloat16 until 1.21.4
if
inputs
.
dtype
==
torch
.
bfloat16
:
# numpy does not support bfloat16 until 1.21.4
...
@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
...
@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def
skip_check_imports
()
->
None
:
def
skip_check_imports
()
->
None
:
r
"""
r
"""Avoid flash attention import error in custom model files."""
Avoids flash attention import error in custom model files.
"""
if
not
is_env_enabled
(
"FORCE_CHECK_IMPORTS"
):
if
not
is_env_enabled
(
"FORCE_CHECK_IMPORTS"
):
transformers
.
dynamic_module_utils
.
check_imports
=
get_relative_imports
transformers
.
dynamic_module_utils
.
check_imports
=
get_relative_imports
def
torch_gc
()
->
None
:
def
torch_gc
()
->
None
:
r
"""
r
"""Collect GPU or NPU memory."""
Collects GPU or NPU memory.
"""
gc
.
collect
()
gc
.
collect
()
if
is_torch_xpu_available
():
if
is_torch_xpu_available
():
torch
.
xpu
.
empty_cache
()
torch
.
xpu
.
empty_cache
()
...
@@ -306,3 +275,20 @@ def use_openmind() -> bool:
...
@@ -306,3 +275,20 @@ def use_openmind() -> bool:
def
use_ray
()
->
bool
:
def
use_ray
()
->
bool
:
return
is_env_enabled
(
"USE_RAY"
)
return
is_env_enabled
(
"USE_RAY"
)
def
find_available_port
()
->
int
:
"""Find an available port on the local machine."""
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
bind
((
""
,
0
))
port
=
sock
.
getsockname
()[
1
]
sock
.
close
()
return
port
def
fix_proxy
(
ipv6_enabled
:
bool
)
->
None
:
"""Fix proxy settings for gradio ui."""
os
.
environ
[
"no_proxy"
]
=
"localhost,127.0.0.1,0.0.0.0"
if
ipv6_enabled
:
for
name
in
(
"http_proxy"
,
"https_proxy"
,
"HTTP_PROXY"
,
"HTTPS_PROXY"
):
os
.
environ
.
pop
(
name
,
None
)
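
Note: the two helpers appended here back the web UI launcher. A standalone sketch of the same pattern follows (the printed message is illustrative only). Binding port 0 lets the OS pick any free port; there is a small window between close() and the server's own bind in which the port could be taken again, which is acceptable for a local UI.

```python
import os
import socket


def find_available_port() -> int:
    # Binding to port 0 asks the OS for any free port; we read it back, then
    # release the socket so the web server can bind the same port right after.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


if __name__ == "__main__":
    port = find_available_port()
    # Same bypass list fix_proxy writes, so a corporate proxy never intercepts
    # requests to the locally served Gradio UI.
    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
    print(f"a local UI could now bind 127.0.0.1:{port}")
```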
src/llamafactory/extras/packages.py
View file @ 7ea81099

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
...
@@ -97,3 +97,7 @@ def is_uvicorn_available():

 def is_vllm_available():
     return _is_package_available("vllm")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
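
Note: is_sglang_available follows the same probe pattern as the surrounding helpers. A minimal sketch, assuming _is_package_available is an importlib-based check (the actual helper is defined earlier in packages.py and may differ):

```python
import importlib.util


def _is_package_available(name: str) -> bool:
    # find_spec returns None when the package cannot be imported; this is the
    # usual shape of such probes, though the repo's exact helper may differ.
    return importlib.util.find_spec(name) is not None


def is_sglang_available() -> bool:
    return _is_package_available("sglang")


print(is_sglang_available())  # False unless sglang is installed
```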
src/llamafactory/extras/ploting.py
View file @ 7ea81099

@@ -15,7 +15,7 @@
 import json
 import math
 import os
-from typing import Any, Dict, List
+from typing import Any

 from transformers.trainer import TRAINER_STATE_NAME
...
@@ -31,10 +31,8 @@ if is_matplotlib_available():
 logger = logging.get_logger(__name__)


-def smooth(scalars: List[float]) -> List[float]:
-    r"""
-    EMA implementation according to TensorBoard.
-    """
+def smooth(scalars: list[float]) -> list[float]:
+    r"""EMA implementation according to TensorBoard."""
     if len(scalars) == 0:
         return []
...
@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
     return smoothed


-def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
-    r"""
-    Plots loss curves in LlamaBoard.
-    """
+def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
+    r"""Plot loss curves in LlamaBoard."""
     plt.close("all")
     plt.switch_backend("agg")
     fig = plt.figure()
...
@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
     return fig


-def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
-    r"""
-    Plots loss curves and saves the image.
-    """
+def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
+    r"""Plot loss curves and saves the image."""
     plt.switch_backend("agg")
     with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
         data = json.load(f)
...
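
Note: only the signature and docstring of smooth() change in this diff. For context, an illustrative EMA body in the TensorBoard style, with a fixed weight instead of the length-derived one the repo uses:

```python
def smooth(scalars: list[float]) -> list[float]:
    # Illustrative EMA body with a fixed weight of 0.9; the repo derives the
    # weight from the series length, which this sketch does not reproduce.
    if len(scalars) == 0:
        return []

    last = scalars[0]
    smoothed = []
    for value in scalars:
        last = last * 0.9 + 0.1 * value  # blend running average with raw point
        smoothed.append(last)
    return smoothed


print(smooth([1.0, 0.5, 2.0, 1.5]))  # values relax toward the series mean
```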
src/llamafactory/hparams/data_args.py
View file @ 7ea81099

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
@@ -16,14 +16,12 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Literal, Optional


 @dataclass
 class DataArguments:
-    r"""
-    Arguments pertaining to what data we are going to input our model for training and evaluation.
-    """
+    r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""

     template: Optional[str] = field(
         default=None,
...
@@ -162,5 +160,5 @@ class DataArguments:
         if self.mask_history and self.train_on_prompt:
             raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         return asdict(self)
src/llamafactory/hparams/evaluation_args.py
View file @ 7ea81099

@@ -21,9 +21,7 @@ from datasets import DownloadMode

 @dataclass
 class EvaluationArguments:
-    r"""
-    Arguments pertaining to specify the evaluation parameters.
-    """
+    r"""Arguments pertaining to specify the evaluation parameters."""

     task: str = field(
         metadata={"help": "Name of the evaluation task."},
...
src/llamafactory/hparams/finetuning_args.py
View file @ 7ea81099

@@ -13,14 +13,12 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Literal, Optional


 @dataclass
 class FreezeArguments:
-    r"""
-    Arguments pertaining to the freeze (partial-parameter) training.
-    """
+    r"""Arguments pertaining to the freeze (partial-parameter) training."""

     freeze_trainable_layers: int = field(
         default=2,
...
@@ -56,9 +54,7 @@ class FreezeArguments:

 @dataclass
 class LoraArguments:
-    r"""
-    Arguments pertaining to the LoRA training.
-    """
+    r"""Arguments pertaining to the LoRA training."""

     additional_target: Optional[str] = field(
         default=None,
...
@@ -128,9 +124,7 @@ class LoraArguments:

 @dataclass
 class RLHFArguments:
-    r"""
-    Arguments pertaining to the PPO, DPO and KTO training.
-    """
+    r"""Arguments pertaining to the PPO, DPO and KTO training."""

     pref_beta: float = field(
         default=0.1,
...
@@ -212,9 +206,7 @@ class RLHFArguments:

 @dataclass
 class GaloreArguments:
-    r"""
-    Arguments pertaining to the GaLore algorithm.
-    """
+    r"""Arguments pertaining to the GaLore algorithm."""

     use_galore: bool = field(
         default=False,
...
@@ -253,9 +245,7 @@ class GaloreArguments:

 @dataclass
 class ApolloArguments:
-    r"""
-    Arguments pertaining to the APOLLO algorithm.
-    """
+    r"""Arguments pertaining to the APOLLO algorithm."""

     use_apollo: bool = field(
         default=False,
...
@@ -306,9 +296,7 @@ class ApolloArguments:

 @dataclass
 class BAdamArgument:
-    r"""
-    Arguments pertaining to the BAdam optimizer.
-    """
+    r"""Arguments pertaining to the BAdam optimizer."""

     use_badam: bool = field(
         default=False,
...
@@ -387,15 +375,21 @@ class SwanLabArguments:
         default=None,
         metadata={"help": "The log directory for SwanLab."},
     )
+    swanlab_lark_webhook_url: Optional[str] = field(
+        default=None,
+        metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
+    )
+    swanlab_lark_secret: Optional[str] = field(
+        default=None,
+        metadata={"help": "The Lark(飞书) secret for SwanLab."},
+    )


 @dataclass
 class FinetuningArguments(
     SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
 ):
-    r"""
-    Arguments pertaining to which techniques we are going to fine-tuning with.
-    """
+    r"""Arguments pertaining to which techniques we are going to fine-tuning with."""

     pure_bf16: bool = field(
         default=False,
...
@@ -452,13 +446,13 @@ class FinetuningArguments(
                 return [item.strip() for item in arg.split(",")]
             return arg

-        self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
-        self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+        self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
+        self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
-        self.lora_target: List[str] = split_arg(self.lora_target)
-        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
-        self.galore_target: List[str] = split_arg(self.galore_target)
-        self.apollo_target: List[str] = split_arg(self.apollo_target)
+        self.lora_target: list[str] = split_arg(self.lora_target)
+        self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
+        self.galore_target: list[str] = split_arg(self.galore_target)
+        self.apollo_target: list[str] = split_arg(self.apollo_target)
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
...
@@ -499,7 +493,7 @@ class FinetuningArguments(
         if self.pissa_init:
             raise ValueError("`pissa_init` is only valid for LoRA training.")

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         args = asdict(self)
         args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
         return args
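
Note: to_dict masks any field whose name ends in api_key before the arguments are dumped. A minimal sketch with a hypothetical stand-in class; observe that the new swanlab_lark_secret does not end in api_key, so this particular filter leaves it untouched:

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Optional


@dataclass
class SwanLabDemoArguments:  # hypothetical stand-in, not the real argument group
    swanlab_api_key: Optional[str] = field(default=None)
    swanlab_lark_webhook_url: Optional[str] = field(default=None)
    swanlab_lark_secret: Optional[str] = field(default=None)

    def to_dict(self) -> dict[str, Any]:
        args = asdict(self)
        # Only keys ending in "api_key" are masked; other fields are dumped verbatim.
        return {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}


print(SwanLabDemoArguments(swanlab_api_key="sk-123").to_dict())
# {'swanlab_api_key': '<SWANLAB_API_KEY>', 'swanlab_lark_webhook_url': None, 'swanlab_lark_secret': None}
```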
src/llamafactory/hparams/generating_args.py
View file @ 7ea81099

@@ -13,16 +13,14 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from transformers import GenerationConfig


 @dataclass
 class GeneratingArguments:
-    r"""
-    Arguments pertaining to specify the decoding parameters.
-    """
+    r"""Arguments pertaining to specify the decoding parameters."""

     do_sample: bool = field(
         default=True,
...
@@ -35,7 +33,9 @@ class GeneratingArguments:
     top_p: float = field(
         default=0.7,
         metadata={
-            "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
+            "help": (
+                "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
+            )
         },
     )
     top_k: int = field(
...
@@ -71,7 +71,7 @@ class GeneratingArguments:
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
     )

-    def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
+    def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:
             args.pop("max_length", None)
...
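
Note: the to_dict change only touches the annotation, but the body shown in the hunk is worth a sketch: a positive max_new_tokens causes max_length to be dropped, since generate() treats the two as mutually exclusive. The class below is a hypothetical subset of GeneratingArguments:

```python
from dataclasses import asdict, dataclass
from typing import Any


@dataclass
class DemoGeneratingArguments:  # hypothetical subset of GeneratingArguments
    max_length: int = 1024
    max_new_tokens: int = 512

    def to_dict(self) -> dict[str, Any]:
        args = asdict(self)
        # generate() treats max_length and max_new_tokens as mutually exclusive;
        # a positive max_new_tokens wins and max_length is dropped.
        if args.get("max_new_tokens", -1) > 0:
            args.pop("max_length", None)
        return args


print(DemoGeneratingArguments().to_dict())  # {'max_new_tokens': 512}
```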
src/llamafactory/hparams/model_args.py
View file @ 7ea81099

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
@@ -17,7 +17,7 @@

 import json
 from dataclasses import asdict, dataclass, field, fields
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union

 import torch
 from transformers.training_args import _convert_str_dict
...
@@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling

 @dataclass
 class BaseModelArguments:
-    r"""
-    Arguments pertaining to the model.
-    """
+    r"""Arguments pertaining to the model."""

     model_name_or_path: Optional[str] = field(
         default=None,
...
@@ -184,9 +182,7 @@ class BaseModelArguments:

 @dataclass
 class QuantizationArguments:
-    r"""
-    Arguments pertaining to the quantization method.
-    """
+    r"""Arguments pertaining to the quantization method."""

     quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
         default="bitsandbytes",
...
@@ -212,9 +208,7 @@ class QuantizationArguments:

 @dataclass
 class ProcessorArguments:
-    r"""
-    Arguments pertaining to the image processor.
-    """
+    r"""Arguments pertaining to the image processor."""

     image_max_pixels: int = field(
         default=768 * 768,
...
@@ -224,6 +218,14 @@ class ProcessorArguments:
         default=32 * 32,
         metadata={"help": "The minimum number of pixels of image inputs."},
     )
+    image_do_pan_and_scan: bool = field(
+        default=False,
+        metadata={"help": "Use pan and scan to process image for gemma3."},
+    )
+    use_audio_in_video: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use audio in video inputs."},
+    )
     video_max_pixels: int = field(
         default=256 * 256,
         metadata={"help": "The maximum number of pixels of video inputs."},
...
@@ -240,13 +242,22 @@ class ProcessorArguments:
         default=128,
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
+    audio_sampling_rate: int = field(
+        default=16000,
+        metadata={"help": "The sampling rate of audio inputs."},
+    )
+
+    def __post_init__(self):
+        if self.image_max_pixels < self.image_min_pixels:
+            raise ValueError("`image_max_pixels` cannot be smaller than `image_min_pixels`.")
+
+        if self.video_max_pixels < self.video_min_pixels:
+            raise ValueError("`video_max_pixels` cannot be smaller than `video_min_pixels`.")


 @dataclass
 class ExportArguments:
-    r"""
-    Arguments pertaining to the model export.
-    """
+    r"""Arguments pertaining to the model export."""

     export_dir: Optional[str] = field(
         default=None,
...
@@ -292,16 +303,14 @@ class ExportArguments:

 @dataclass
 class VllmArguments:
-    r"""
-    Arguments pertaining to the vLLM worker.
-    """
+    r"""Arguments pertaining to the vLLM worker."""

     vllm_maxlen: int = field(
         default=4096,
         metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
     )
     vllm_gpu_util: float = field(
-        default=0.9,
+        default=0.7,
         metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
     )
     vllm_enforce_eager: bool = field(
...
@@ -323,9 +332,36 @@ class VllmArguments:

 @dataclass
-class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
-    r"""
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
+class SGLangArguments:
+    r"""Arguments pertaining to the SGLang worker."""
+
+    sglang_maxlen: int = field(
+        default=4096,
+        metadata={"help": "Maximum sequence (prompt + response) length of the SGLang engine."},
+    )
+    sglang_mem_fraction: float = field(
+        default=0.7,
+        metadata={"help": "The memory fraction (0-1) to be used for the SGLang engine."},
+    )
+    sglang_tp_size: int = field(
+        default=-1,
+        metadata={"help": "Tensor parallel size for the SGLang engine."},
+    )
+    sglang_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
+    )
+
+    def __post_init__(self):
+        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
+            self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
+
+
+@dataclass
+class ModelArguments(SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
+    r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.

     The class on the most right will be displayed first.
     """
...
@@ -335,7 +371,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
         init=False,
         metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
     )
-    device_map: Optional[Union[str, Dict[str, Any]]] = field(
+    device_map: Optional[Union[str, dict[str, Any]]] = field(
         default=None,
         init=False,
         metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
...
@@ -353,8 +389,10 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
     def __post_init__(self):
         BaseModelArguments.__post_init__(self)
+        ProcessorArguments.__post_init__(self)
         ExportArguments.__post_init__(self)
         VllmArguments.__post_init__(self)
+        SGLangArguments.__post_init__(self)

     @classmethod
     def copyfrom(cls, source: "Self", **kwargs) -> "Self":
...
@@ -372,7 +410,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
         return result

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         args = asdict(self)
         args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
         return args
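
Note: the new SGLangArguments accepts its engine config either as a dict or as a JSON string from the CLI, normalized in __post_init__ via _convert_str_dict. A self-contained sketch with a simplified stand-in for that transformers helper:

```python
import json
from dataclasses import dataclass, field
from typing import Optional, Union


def _convert_str_dict(data: dict) -> dict:
    # Simplified stand-in for transformers.training_args._convert_str_dict,
    # which normalizes string values such as "true" or "0.7" into Python types.
    for key, value in data.items():
        if isinstance(value, str):
            if value.lower() in ("true", "false"):
                data[key] = value.lower() == "true"
            elif value.replace(".", "", 1).isdigit():
                data[key] = float(value) if "." in value else int(value)
    return data


@dataclass
class SGLangArgumentsDemo:  # hypothetical stand-in for the new SGLangArguments
    sglang_config: Optional[Union[dict, str]] = field(default=None)

    def __post_init__(self):
        # CLI users pass a JSON string; programmatic users can pass a dict.
        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
            self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))


args = SGLangArgumentsDemo(sglang_config='{"attention_backend": "triton", "disable_cuda_graph": "true"}')
print(args.sglang_config)  # {'attention_backend': 'triton', 'disable_cuda_graph': True}
```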
src/llamafactory/hparams/parser.py
View file @ 7ea81099

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
@@ -19,7 +19,7 @@ import json
 import os
 import sys
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

 import torch
 import transformers
...
@@ -31,7 +31,7 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available

 from ..extras import logging
-from ..extras.constants import CHECKPOINT_NAMES
+from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
...
@@ -47,17 +47,15 @@ check_dependencies()

 _TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
-_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
+_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
-_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
+_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


-def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
-    r"""
-    Gets arguments from the command line or a config file.
-    """
+def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
+    r"""Get arguments from the command line or a config file."""
     if args is not None:
         return args
...
@@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[

 def _parse_args(
-    parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
-) -> Tuple[Any]:
+    parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
+) -> tuple[Any]:
     args = read_args(args)
     if isinstance(args, dict):
         return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
...
@@ -136,9 +134,12 @@ def _check_extra_dependencies(
     if model_args.mixture_of_depths is not None:
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

-    if model_args.infer_backend == "vllm":
-        check_version("vllm>=0.4.3,<=0.7.3")
+    if model_args.infer_backend == EngineName.VLLM:
+        check_version("vllm>=0.4.3,<=0.8.2")
         check_version("vllm", mandatory=True)
+    elif model_args.infer_backend == EngineName.SGLANG:
+        check_version("sglang>=0.4.4")
+        check_version("sglang", mandatory=True)

     if finetuning_args.use_galore:
         check_version("galore_torch", mandatory=True)
...
@@ -161,31 +162,31 @@ def _check_extra_dependencies(
         check_version("rouge_chinese", mandatory=True)


-def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


-def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
+def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


-def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


-def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
+def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
     parser = HfArgumentParser(RayArguments)
     (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
     return ray_args


-def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
+def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
     model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

     # Setup logging
...
@@ -364,9 +365,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
         and training_args.resume_from_checkpoint is not None
     ):
         logger.warning_rank0(
-            "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
-                training_args.resume_from_checkpoint
-            )
+            f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
         )

     # Post-process model arguments
...
@@ -382,20 +381,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
     # Log on each process the small summary
     logger.info(
-        "Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
-            training_args.process_index,
-            training_args.world_size,
-            training_args.device,
-            training_args.parallel_mode == ParallelMode.DISTRIBUTED,
-            str(model_args.compute_dtype),
-        )
+        f"Process rank: {training_args.process_index}, "
+        f"world size: {training_args.world_size}, device: {training_args.device}, "
+        f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
+        f"compute dtype: {str(model_args.compute_dtype)}"
     )

     transformers.set_seed(training_args.seed)

     return model_args, data_args, training_args, finetuning_args, generating_args


-def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
+def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

     _set_transformers_logging()
...
@@ -426,7 +422,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
     return model_args, data_args, finetuning_args, generating_args


-def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
+def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

     _set_transformers_logging()
...
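
Note: replacing the raw "vllm" string with EngineName members is backward compatible because str-valued enums compare equal to their raw strings. A sketch, with the enum values assumed to mirror extras.constants:

```python
from enum import Enum, unique


@unique
class EngineName(str, Enum):
    # Assumed values mirroring extras.constants; str-valued members compare
    # equal to their raw strings, so configs passing "vllm" keep working.
    HF = "huggingface"
    VLLM = "vllm"
    SGLANG = "sglang"


def pick_version_check(infer_backend: str) -> str:
    if infer_backend == EngineName.VLLM:
        return "vllm>=0.4.3,<=0.8.2"
    elif infer_backend == EngineName.SGLANG:
        return "sglang>=0.4.4"
    return "no engine-specific requirement"


print(pick_version_check("vllm"))    # vllm>=0.4.3,<=0.8.2
print(pick_version_check("sglang"))  # sglang>=0.4.4
```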
src/llamafactory/hparams/training_args.py
View file @ 7ea81099

+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 from dataclasses import dataclass, field
 from typing import Literal, Optional, Union
...
@@ -10,9 +24,7 @@ from ..extras.misc import use_ray

 @dataclass
 class RayArguments:
-    r"""
-    Arguments pertaining to the Ray training.
-    """
+    r"""Arguments pertaining to the Ray training."""

     ray_run_name: Optional[str] = field(
         default=None,
...
@@ -43,9 +55,7 @@ class RayArguments:

 @dataclass
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
-    r"""
-    Arguments pertaining to the trainer.
-    """
+    r"""Arguments pertaining to the trainer."""

     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
...
src/llamafactory/model/__init__.py
View file @ 7ea81099

@@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params

 __all__ = [
     "QuantizationMethod",
+    "find_all_linear_modules",
     "load_config",
     "load_model",
     "load_tokenizer",
-    "find_all_linear_modules",
     "load_valuehead_params",
 ]
src/llamafactory/model/adapter.py
View file @ 7ea81099

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING

 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import is_fsdp_enabled

 from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
...
@@ -81,9 +80,8 @@ def _setup_freeze_tuning(
     if finetuning_args.use_llama_pro:
         if num_layers % finetuning_args.freeze_trainable_layers != 0:
             raise ValueError(
-                "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
-                    num_layers, finetuning_args.freeze_trainable_layers
-                )
+                f"`num_layers` {num_layers} should be "
+                f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
             )

         stride = num_layers // finetuning_args.freeze_trainable_layers
...
@@ -178,7 +176,7 @@ def _setup_lora_tuning(
         }

         for adapter in adapter_to_merge:
-            model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
+            model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
             model = model.merge_and_unload()

         if len(adapter_to_merge) > 0:
...
@@ -263,8 +261,7 @@ def init_adapter(
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
 ) -> "PreTrainedModel":
-    r"""
-    Initializes the adapters.
+    r"""Initialize the adapters.

     Support full-parameter, freeze and LoRA training.
...
@@ -279,14 +276,14 @@ def init_adapter(
     # cast trainable parameters to float32 if:
     # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
-    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
     cast_trainable_params_to_fp32 = False
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
         logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
-    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+    elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
+        logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
     else:
         logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
...
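
Note: after this change the float32 upcast decision no longer special-cases FSDP. A condensed, hypothetical restatement of the decision ladder as a pure function:

```python
from typing import Optional


def should_cast_trainable_to_fp32(
    is_trainable: bool,
    pure_bf16: bool,
    use_badam: bool,
    quantization_bit: Optional[int],
    zero3_enabled: bool,
) -> bool:
    # Condensed restatement of the ladder after this commit: inference never
    # casts, pure bf16 / BAdam stay in half precision, and an unquantized
    # model under DeepSpeed ZeRO-3 is already kept in float32 by the engine.
    if not is_trainable:
        return False
    if pure_bf16 or use_badam:
        return False
    if quantization_bit is None and zero3_enabled:
        return False
    return True  # e.g. QLoRA: upcast trainable params to float32 for stability


print(should_cast_trainable_to_fp32(True, False, False, 4, False))    # True (QLoRA)
print(should_cast_trainable_to_fp32(True, False, False, None, True))  # False (ZeRO-3)
```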
src/llamafactory/model/loader.py
View file @ 7ea81099

@@ -13,13 +13,15 @@
 # limitations under the License.

 import os
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
+from typing import TYPE_CHECKING, Any, Optional, TypedDict

 import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
+    AutoModelForTextToWaveform,
     AutoModelForVision2Seq,
     AutoProcessor,
     AutoTokenizer,
...
@@ -51,9 +53,8 @@ class TokenizerModule(TypedDict):
     processor: Optional["ProcessorMixin"]


-def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
-    r"""
-    Gets arguments to load config/tokenizer/model.
+def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
+    r"""Get arguments to load config/tokenizer/model.

     Note: including inplace operation of model_args.
     """
...
@@ -68,13 +69,11 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:

 def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
-    r"""
-    Loads pretrained tokenizer and optionally loads processor.
+    r"""Load pretrained tokenizer and optionally loads processor.

     Note: including inplace operation of model_args.
     """
     init_kwargs = _get_init_kwargs(model_args)
-    config = load_config(model_args)
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
...
@@ -96,7 +95,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     patch_tokenizer(tokenizer, model_args)
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        patch_processor(processor, config, tokenizer, model_args)
+        patch_processor(processor, tokenizer, model_args)
     except Exception as e:
         logger.debug(f"Processor was not found: {e}.")
         processor = None
...
@@ -110,9 +109,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":

 def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
-    r"""
-    Loads model config.
-    """
+    r"""Load model config."""
     init_kwargs = _get_init_kwargs(model_args)
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
...
@@ -124,9 +121,7 @@ def load_model(
     is_trainable: bool = False,
     add_valuehead: bool = False,
 ) -> "PreTrainedModel":
-    r"""
-    Loads pretrained model.
-    """
+    r"""Load pretrained model."""
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
...
@@ -147,10 +142,14 @@ def load_model(
         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
         else:
-            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
+            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
                 load_class = AutoModelForVision2Seq
-            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
+            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+                load_class = AutoModelForImageTextToText
+            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
+            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
+                load_class = AutoModelForTextToWaveform
             else:
                 load_class = AutoModelForCausalLM
...
@@ -158,6 +157,8 @@ def load_model(
             model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
         else:
             model = load_class.from_pretrained(**init_kwargs)

+        if getattr(model.config, "model_type", None) == "qwen2_5_omni":
+            model = model.thinker  # use part of Omni model
     if model_args.mixture_of_depths == "convert":
         model = convert_pretrained_model_to_mod(model, config, model_args)
...
@@ -194,8 +195,9 @@ def load_model(
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
-        param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
-            trainable_params, all_param, 100 * trainable_params / all_param
+        param_stats = (
+            f"trainable params: {trainable_params:,} || "
+            f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
         )
     else:
         param_stats = f"all params: {all_param:,}"
...
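
Note: load_model now tries four auto classes in order before falling back to AutoModelForCausalLM, and splits off model.thinker for qwen2_5_omni. A sketch with hypothetical model_type sets standing in for the _model_mapping tables (the real code tests type(config) against each mapping):

```python
# Hypothetical model_type sets standing in for each auto class's _model_mapping;
# the real code tests type(config) against the mappings in exactly this order.
VISION2SEQ = {"llava"}
IMAGE_TEXT_TO_TEXT = {"llama4", "gemma3"}
SEQ2SEQ = {"whisper"}
TEXT_TO_WAVEFORM = {"qwen2_5_omni"}


def pick_load_class(model_type: str) -> str:
    if model_type in VISION2SEQ:  # image-text
        return "AutoModelForVision2Seq"
    elif model_type in IMAGE_TEXT_TO_TEXT:  # image-text
        return "AutoModelForImageTextToText"
    elif model_type in SEQ2SEQ:  # audio-text
        return "AutoModelForSeq2SeqLM"
    elif model_type in TEXT_TO_WAVEFORM:  # qwen2_5_omni; model.thinker is split off afterwards
        return "AutoModelForTextToWaveform"
    return "AutoModelForCausalLM"


print(pick_load_class("qwen2_5_omni"))  # AutoModelForTextToWaveform
print(pick_load_class("gpt2"))          # AutoModelForCausalLM
```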