OpenDAS / LLaMA-Factory · Commits

Commit 24534501, authored May 21, 2025 by mashun1

parallel_tool

parent c4ba4563
Showing 20 changed files with 491 additions and 293 deletions.
src/llamafactory/chat/sglang_engine.py                +15   -1
src/llamafactory/chat/vllm_engine.py                  +0    -1
src/llamafactory/cli.py                               +1    -1
src/llamafactory/data/data_utils.py                   +13   -2
src/llamafactory/data/loader.py                       +10   -5
src/llamafactory/data/mm_plugin.py                    +57   -151
src/llamafactory/data/parser.py                       +1    -6
src/llamafactory/data/template.py                     +172  -26
src/llamafactory/data/tool_utils.py                   +26   -39
src/llamafactory/extras/constants.py                  +112  -10
src/llamafactory/extras/misc.py                       +13   -6
src/llamafactory/hparams/data_args.py                 +12   -0
src/llamafactory/hparams/generating_args.py           +1    -5
src/llamafactory/hparams/model_args.py                +10   -4
src/llamafactory/hparams/parser.py                    +2    -2
src/llamafactory/hparams/training_args.py             +1    -0
src/llamafactory/model/model_utils/attention.py       +2    -4
src/llamafactory/model/model_utils/liger_kernel.py    +12   -2
src/llamafactory/model/model_utils/moe.py             +25   -23
src/llamafactory/model/model_utils/quantization.py    +6    -5
src/llamafactory/chat/sglang_engine.py

```diff
@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False
         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
```
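The first two hunks gate SGLang's LoRA serving flags on whether an adapter path was supplied. Below is a minimal standalone sketch of how the launch command comes together; the adapter path and backend value are hypothetical placeholders, not values from this commit:

```python
# Sketch of the command assembly in the patched SGLangEngine.
# The adapter directory and backend name are made-up examples.
adapter_name_or_path = ["saves/llama3-8b/lora/sft"]  # hypothetical adapter dir
sglang_lora_backend = "triton"  # hypothetical backend value

launch_cmd = [
    "python3 -m sglang.launch_server",
    "--log-level error",
]
if adapter_name_or_path is not None:  # mirrors self.lora_request = True
    launch_cmd.extend(
        [
            "--max-loras-per-batch 1",
            f"--lora-backend {sglang_lora_backend}",
            f"--lora-paths lora0={adapter_name_or_path[0]}",
            "--disable-radix-cache",
        ]
    )

print(" ".join(launch_cmd))
```

At request time the same flag selects the adapter: the engine adds `"lora_request": ["lora0"]` to the JSON payload posted to `/generate`, matching the `lora0=` name registered at launch.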
src/llamafactory/chat/vllm_engine.py

```diff
@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
```
src/llamafactory/cli.py

```diff
@@ -73,7 +73,7 @@ def main():
         "help": partial(print, USAGE),
     }
-    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
```
src/llamafactory/data/data_utils.py

```diff
@@ -169,11 +169,22 @@ def read_cloud_json(cloud_path):
     try:
         # Try with anonymous access first
        fs = setup_fs(cloud_path, anon=True)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
     except Exception:
         # Try again with credentials
         fs = setup_fs(cloud_path)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+
+    if fs.isdir(cloud_path):
+        files = [x["Key"] for x in fs.listdir(cloud_path)]
+    else:
+        files = [cloud_path]
+
+    # filter out non-JSON files
+    files = [file for file in files if file.endswith(".json") or file.endswith(".jsonl")]
+    if not files:
+        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}")
+
+    data = []
+    for file in files:
+        data.extend(_read_json_with_fs(fs, file, lines=file.endswith(".jsonl")))
+
+    return data
 
 def _read_json_with_fs(fs, path, lines=True):
```
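With this change, `read_cloud_json` accepts a directory as well as a single file and skips non-JSON entries. A self-contained sketch of just the new filtering step; the listing below is a stand-in for `fs.listdir()` and the names are hypothetical:

```python
# Keep only .json/.jsonl entries and fail loudly when none remain.
cloud_path = "s3://bucket/datasets/"          # hypothetical path
listing = ["a.jsonl", "b.json", "README.md"]  # stand-in for fs.listdir() keys

files = [f for f in listing if f.endswith(".json") or f.endswith(".jsonl")]
if not files:
    raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}")
print(files)  # ['a.jsonl', 'b.json']
```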
src/llamafactory/data/loader.py

```diff
@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
 
-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
 
 def _get_dataset_processor(
@@ -303,7 +303,12 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
         )
 
     with training_args.main_process_first(desc="pre-process dataset"):
```
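The `merge` flag is replaced by `return_dict`, so `_get_merged_dataset` now returns a merged dataset by default and a per-dataset dict only on request, which is what the new `eval_on_each_dataset` option uses. A toy sketch of the control flow, with plain lists standing in for HF datasets and concatenation for `merge_dataset`:

```python
# Toy illustration of the new return_dict control flow.
def get_merged(datasets: dict[str, list], return_dict: bool = False):
    if return_dict:
        return datasets  # one entry per eval dataset
    else:
        merged = []
        for ds in datasets.values():
            merged += ds  # stand-in for merge_dataset()
        return merged

data = {"alpaca": [1, 2], "sharegpt": [3]}
print(get_merged(data))                    # [1, 2, 3]
print(get_merged(data, return_dict=True))  # {'alpaca': [1, 2], 'sharegpt': [3]}
```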
src/llamafactory/data/mm_plugin.py

This diff is collapsed.
src/llamafactory/data/parser.py

```diff
@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
     dataset_list: list[DatasetAttr] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            if use_modelscope():
-                load_from = "ms_hub"
-            elif use_openmind():
-                load_from = "om_hub"
-            else:
-                load_from = "hf_hub"
+            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue
```
src/llamafactory/data/template.py

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
@@ -51,6 +52,7 @@ class Template:
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
+    enable_thinking: Optional[bool]
     mm_plugin: "BasePlugin"
 
     def encode_oneturn(
@@ -61,7 +63,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
@@ -77,7 +79,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 
     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
@@ -92,6 +94,19 @@ class Template:
         return list(stop_token_ids)
 
+    def add_thought(self, content: str = "") -> str:
+        r"""Add empty thought to assistant message."""
+        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
+
+    def remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
+    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Get the token ids of thought words."""
+        return tokenizer.encode(self.add_thought(), add_special_tokens=False)
+
     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
         r"""Convert elements to token ids."""
         token_ids = []
@@ -111,18 +126,12 @@ class Template:
         return token_ids
 
-    def _remove_thought(self, content: str) -> str:
-        r"""Remove thought from assistant message."""
-        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
-        return re.sub(pattern, "", content).lstrip("\n")
-
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.
@@ -140,18 +149,14 @@ class Template:
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 elements += self.format_system.apply(content=(system + tool_text))
 
-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=content, idx=str(i // 2))
+                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -162,6 +167,9 @@ class Template:
     @staticmethod
     def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
         r"""Add or replace eos token to the tokenizer."""
+        if tokenizer.eos_token == eos_token:
+            return
+
         is_added = tokenizer.eos_token_id is None
         num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
@@ -328,7 +336,6 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
-        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
@@ -342,18 +349,14 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]
 
-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=system_text + content)
+                elements += self.format_user.apply(content=system_text + message["content"])
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -392,6 +395,64 @@ class Llama2Template(Template):
         return jinja_template
 
+@dataclass
+class ReasoningTemplate(Template):
+    r"""A template that add thought to assistant message."""
+
+    @override
+    def encode_oneturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> tuple[list[int], list[int]]:
+        messages = deepcopy(messages)
+        for i in range(1, len(messages) - 2, 2):
+            messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        if self.enable_thinking is False:  # remove all cot
+            messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
+        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+        if (
+            self.thought_words[0] not in messages[-1]["content"]
+            and self.thought_words[1] not in messages[-1]["content"]
+        ):  # add empty cot
+            if not self.enable_thinking:  # do not compute loss
+                prompt_ids += self.get_thought_word_ids(tokenizer)
+            else:  # do compute loss
+                response_ids = self.get_thought_word_ids(tokenizer) + response_ids
+
+        return prompt_ids, response_ids
+
+    @override
+    def encode_multiturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> list[tuple[list[int], list[int]]]:
+        messages = deepcopy(messages)
+        if self.enable_thinking is False:  # remove all cot
+            for i in range(1, len(messages), 2):
+                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        for i in range(0, len(messages), 2):
+            if (
+                self.thought_words[0] not in messages[i + 1]["content"]
+                and self.thought_words[1] not in messages[i + 1]["content"]
+            ):  # add empty cot
+                if not self.enable_thinking:  # do not compute loss
+                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                else:  # do compute loss
+                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
+
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
 TEMPLATES: dict[str, "Template"] = {}
@@ -410,6 +471,7 @@ def register_template(
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
+    enable_thinking: Optional[bool] = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
 ) -> None:
@@ -456,6 +518,7 @@ def register_template(
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
+        enable_thinking=enable_thinking,
         mm_plugin=mm_plugin,
     )
@@ -492,6 +555,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
     assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
 
     if len(user_slot) > len(user_slot_empty_system):
@@ -501,7 +565,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     else:  # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
         default_system = ""
 
-    return Template(
+    return template_class(
         format_user=StringFormatter(slots=[user_slot]),
         format_assistant=StringFormatter(slots=[assistant_slot]),
         format_system=StringFormatter(slots=[system_slot]),
@@ -515,6 +579,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
         efficient_eos=False,
         replace_eos=False,
         replace_jinja_template=False,
+        enable_thinking=True,
         mm_plugin=get_mm_plugin(name="base"),
     )
@@ -543,6 +608,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 
+    if data_args.default_system is not None:
+        logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+        template.default_system = data_args.default_system
+
+    template.enable_thinking = data_args.enable_thinking
     template.fix_special_tokens(tokenizer)
     template.fix_jinja_template(tokenizer)
     return template
@@ -756,6 +826,7 @@ register_template(
         "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
     ),
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
@@ -774,6 +845,15 @@ register_template(
 )
 
+# copied from deepseek3 template
+register_template(
+    name="deepseekr1",
+    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=ReasoningTemplate,
+)
+
 register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@@ -838,6 +918,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    replace_eos=True,
     template_class=Llama2Template,
 )
@@ -853,6 +934,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
     template_class=Llama2Template,
 )
@@ -872,6 +954,22 @@ register_template(
 )
 
+# copied from glm4 template
+register_template(
+    name="glmz1",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    template_class=ReasoningTemplate,
+)
+
 register_template(
     name="granite3",
     format_user=StringFormatter(
@@ -1018,6 +1116,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>", "<|eom_id|>"],
+    replace_eos=True,
 )
@@ -1037,6 +1136,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot|>", "<|eom|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
 )
@@ -1066,6 +1166,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>", "<|eom_id|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
 )
@@ -1079,6 +1180,7 @@ register_template(
     format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
     default_system="You are a helpful assistant provided by Moonshot-AI.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
@@ -1131,6 +1233,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>", "<|eom_id|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
 )
@@ -1163,6 +1266,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
 )
@@ -1233,6 +1337,24 @@ register_template(
 )
 
+# copied from qwen template
+register_template(
+    name="mimo",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    template_class=ReasoningTemplate,
+)
+
 # copied from chatml template
 register_template(
     name="minicpm_v",
@@ -1363,6 +1485,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
     template_class=Llama2Template,
 )
@@ -1374,6 +1497,7 @@ register_template(
     format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
     stop_words=["<|end|>"],
+    replace_eos=True,
 )
@@ -1384,6 +1508,7 @@ register_template(
     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
     stop_words=["<|end|>"],
+    replace_eos=True,
 )
@@ -1395,6 +1520,7 @@ register_template(
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
     format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
@@ -1425,6 +1551,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
@@ -1440,6 +1567,8 @@ register_template(
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
     stop_words=["<|im_end|>"],
+    replace_eos=True,
+    template_class=ReasoningTemplate,
 )
@@ -1451,6 +1580,7 @@ register_template(
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
 )
@@ -1468,6 +1598,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(
         name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
     ),
@@ -1486,6 +1617,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 )
@@ -1503,6 +1635,20 @@ register_template(
 )
 
+register_template(
+    name="seed_coder",
+    format_user=StringFormatter(
+        slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
+    ),
+    format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
+    default_system=(
+        "You are an AI programming assistant, utilizing the Seed-Coder model, developed by ByteDance Seed, "
+        "and you only answer questions related to computer science. For politically sensitive questions, "
+        "security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\n"
+    ),
+)
+
 # copied from llama3 template
 register_template(
     name="skywork_o1",
```
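The new ReasoningTemplate centralizes chain-of-thought handling: historical assistant turns are stripped of their thought spans, and an empty thought block is appended when the target message lacks one. A self-contained sketch of the two helpers it relies on, assuming the `<think>...</think>` convention used in the diff:

```python
import re

# Standalone sketch of the thought helpers that ReasoningTemplate uses;
# thought_words follows the <think>...</think> convention from the diff.
thought_words = ("<think>", "</think>")

def add_thought(content: str = "") -> str:
    # Prepend an empty thought block, as done when a sample has no CoT.
    return f"{thought_words[0]}\n\n{thought_words[1]}\n\n" + content

def remove_thought(content: str) -> str:
    # Strip the thought span, as done for historical turns.
    pattern = re.compile(
        f"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL
    )
    return re.sub(pattern, "", content).lstrip("\n")

msg = "<think>\nreasoning steps\n</think>\n\nfinal answer"
print(remove_thought(msg))   # final answer
print(add_thought("hello"))  # <think>\n\n</think>\n\nhello
```

Whether the empty thought tokens land in the prompt (no loss) or the response (loss computed) is controlled by the new `enable_thinking` flag.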
src/llamafactory/data/tool_utils.py

```diff
@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
         tool_text = ""
         tool_names = []
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             param_text = ""
             for name, param in tool["parameters"]["properties"].items():
                 required, enum, items = "", "", ""
@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_text = ""
-        for name, arguments in functions:
-            function_text += f"Action: {name}\nAction Input: {arguments}\n"
-
-        return function_text
+        return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
 
     @override
     @staticmethod
@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
                 name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
             )
@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
         date = datetime.now().strftime("%d %b %Y")
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
 
         return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        if len(functions) > 1:
-            raise ValueError("Llama-3 does not support parallel functions.")
-
-        return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
+        function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
+        return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
 
     @override
     @staticmethod
     def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
-            tool = json.loads(content.strip())
+            tools = json.loads(content.strip())
         except json.JSONDecodeError:
             return content
 
-        if "name" not in tool or "parameters" not in tool:
-            return content
-
-        return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content
 
 class MistralToolUtils(ToolUtils):
     r"""Mistral v0.3 tool using template."""
@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         wrapped_tools = []
         for tool in tools:
-            wrapped_tools.append({"type": "function", "function": tool})
+            wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})
 
         return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
 
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
-
-        return "[" + ", ".join(function_texts) + "]"
+        return json.dumps(
+            [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions],
+            ensure_ascii=False,
+        )
 
     @override
     @staticmethod
@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils):
         except json.JSONDecodeError:
             return content
 
-        if not isinstance(tools, list):
-            tools = [tools]
-
-        results = []
-        for tool in tools:
-            if "name" not in tool or "arguments" not in tool:
-                return content
-
-            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
-
-        return results
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content
 
 class QwenToolUtils(ToolUtils):
@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
 
         return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(
-                "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
-            )
-
-        return "\n".join(function_texts)
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
 
     @override
     @staticmethod
```
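Consistent with the commit title, these formatters and extractors now handle parallel function calls: Llama-3's single-call restriction is dropped in favor of a JSON list, and Qwen emits one `<tool_call>` block per call. A runnable sketch of the new Qwen serialization; `FunctionCall` below is a namedtuple stand-in for the project's own type and the example calls are made up:

```python
import json
from collections import namedtuple

# Stand-in for the project's FunctionCall named tuple.
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])

def qwen_function_formatter(functions: list[FunctionCall]) -> str:
    # One JSON object per call, each wrapped in its own <tool_call> block.
    texts = [
        json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
        for name, arguments in functions
    ]
    return "\n".join(f"<tool_call>\n{t}\n</tool_call>" for t in texts)

calls = [
    FunctionCall("get_weather", '{"city": "Beijing"}'),
    FunctionCall("get_time", '{"tz": "UTC"}'),
]
print(qwen_function_formatter(calls))
# <tool_call>
# {"name": "get_weather", "arguments": {"city": "Beijing"}}
# </tool_call>
# <tool_call>
# {"name": "get_time", "arguments": {"tz": "UTC"}}
# </tool_call>
```

The matching extractors accept either a single JSON object or a JSON list, so a round trip through format and extract preserves multiple calls.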
src/llamafactory/extras/constants.py

```diff
@@ -533,6 +533,17 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
         },
+        "DeepSeek-V3-671B-0324-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
+        },
+    },
+    template="deepseek3",
+)
+
+
+register_model_group(
+    models={
         "DeepSeek-R1-1.5B-Distill": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
@@ -566,7 +577,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
         },
     },
-    template="deepseek3",
+    template="deepseekr1",
 )
@@ -737,6 +748,13 @@ register_model_group(
             DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
         },
+    },
+    template="glm4",
+)
+
+
+register_model_group(
+    models={
         "GLM-Z1-9B-0414-Chat": {
             DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
@@ -746,7 +764,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
         },
     },
-    template="glm4",
+    template="glmz1",
 )
@@ -869,12 +887,13 @@ register_model_group(
 register_model_group(
     models={
-        "Granite-3.2-1B-A400M-Base": {
+        "Granite-Vision-3.2-2B": {
             DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
         },
     },
     template="granite3_vision",
+    multimodal=True,
 )
@@ -1398,6 +1417,29 @@ register_model_group(
 )
 
+register_model_group(
+    models={
+        "MiMo-7B-Base": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
+        },
+        "MiMo-7B-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
+        },
+        "MiMo-7B-Instruct-RL": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
+        },
+        "MiMo-7B-RL-ZERO": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+        },
+    },
+    template="mimo",
+)
+
 register_model_group(
     models={
         "MiniCPM-2B-SFT-Chat": {
@@ -2461,6 +2503,38 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
         },
+        "Qwen3-0.6B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+        },
+        "Qwen3-1.7B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+        },
+        "Qwen3-4B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
+        },
+        "Qwen3-8B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
+        },
+        "Qwen3-14B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
+        },
+        "Qwen3-32B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
+        },
+        "Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+        },
+        "Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+        },
     },
     template="qwen3",
 )
@@ -2484,10 +2558,22 @@ register_model_group(
 register_model_group(
     models={
+        "Qwen2.5-Omni-3B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
+        },
         "Qwen2.5-Omni-7B": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
-        }
+        },
+        "Qwen2.5-Omni-7B-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+        },
+        "Qwen2.5-Omni-7B-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
+        },
     },
     template="qwen2_omni",
     multimodal=True,
@@ -2598,15 +2684,17 @@ register_model_group(
 register_model_group(
     models={
-        "SOLAR-10.7B-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
+        "Seed-Coder-8B-Base": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
         },
-        "SOLAR-10.7B-Instruct-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
+        "Seed-Coder-8B-Instruct": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
+        },
+        "Seed-Coder-8B-Instruct-Reasoning": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
         },
     },
-    template="solar",
+    template="seed_coder",
 )
@@ -2631,6 +2719,20 @@ register_model_group(
 )
 
+register_model_group(
+    models={
+        "SOLAR-10.7B-v1.0": {
+            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
+        },
+        "SOLAR-10.7B-Instruct-v1.0": {
+            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
+        },
+    },
+    template="solar",
+)
+
 register_model_group(
     models={
         "StarCoder2-3B": {
```
src/llamafactory/extras/misc.py
View file @ 24534501
...
@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return
 
+    if "gptqmodel" in requirement or "autoawq" in requirement:
+        pip_command = f"pip install {requirement} --no-build-isolation"
+    else:
+        pip_command = f"pip install {requirement}"
+
     if mandatory:
-        hint = f"To fix: run `pip install {requirement}`."
+        hint = f"To fix: run `{pip_command}`."
     else:
-        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
 
     require_version(requirement, hint)
 
 
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.5.0")
-    check_version("accelerate>=0.34.0,<=1.6.0")
-    check_version("peft>=0.14.0,<=0.15.1")
+    check_version(
+        "transformers>=4.45.0,<=4.52.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
+    )
+    check_version("datasets>=2.16.0,<=3.6.0")
+    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
...
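The new pip_command branch exists because packages like gptqmodel and autoawq compile native extensions against the local torch and commonly fail under pip's build isolation. A standalone sketch of the hint logic, mirroring the diff with the version check itself stubbed out:

def build_hint(requirement: str, mandatory: bool = False) -> str:
    # Packages that build native extensions need --no-build-isolation.
    if "gptqmodel" in requirement or "autoawq" in requirement:
        pip_command = f"pip install {requirement} --no-build-isolation"
    else:
        pip_command = f"pip install {requirement}"

    if mandatory:
        return f"To fix: run `{pip_command}`."

    return f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."


assert "--no-build-isolation" in build_hint("gptqmodel>=2.0.0", mandatory=True)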
src/llamafactory/hparams/data_args.py
View file @ 24534501
...
@@ -99,6 +99,10 @@ class DataArguments:
         default=0.0,
         metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
     )
+    eval_on_each_dataset: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to evaluate on each dataset separately."},
+    )
     packing: Optional[bool] = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
...
@@ -111,6 +115,14 @@ class DataArguments:
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Override the default system message in the template."},
+    )
+    enable_thinking: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={
...
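The new fields are ordinary dataclass fields, so they surface as CLI or YAML options through HfArgumentParser like every other hyperparameter; a minimal sketch with a trimmed-down stand-in class (not the real DataArguments):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class MiniDataArguments:
    default_system: Optional[str] = field(
        default=None,
        metadata={"help": "Override the default system message in the template."},
    )
    enable_thinking: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
    )


(args,) = HfArgumentParser(MiniDataArguments).parse_args_into_dataclasses(
    ["--default_system", "You are a helpful assistant.", "--enable_thinking", "false"]
)
assert args.enable_thinking is False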
src/llamafactory/hparams/generating_args.py
View file @ 24534501
...
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any
 
 from transformers import GenerationConfig
...
@@ -62,10 +62,6 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    default_system: Optional[str] = field(
-        default=None,
-        metadata={"help": "Default system message to use in chat completion."},
-    )
     skip_special_tokens: bool = field(
         default=True,
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
...
src/llamafactory/hparams/model_args.py
View file @ 24534501
...
@@ -235,10 +235,6 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Whether to crop the image to patches for internvl."},
     )
-    use_audio_in_video: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use audio in video inputs."},
-    )
     video_max_pixels: int = field(
         default=256 * 256,
         metadata={"help": "The maximum number of pixels of video inputs."},
...
@@ -255,6 +251,10 @@ class ProcessorArguments:
         default=128,
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
+    use_audio_in_video: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use audio in video inputs."},
+    )
     audio_sampling_rate: int = field(
         default=16000,
         metadata={"help": "The sampling rate of audio inputs."},
...
@@ -364,6 +364,12 @@ class SGLangArguments:
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
+    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
+        default="triton",
+        metadata={
+            "help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
+        },
+    )
 
     def __post_init__(self):
         if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
...
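The __post_init__ above normalizes sglang_config when it arrives as a JSON string; the diff truncates the branch body, so here is a hedged sketch of the pattern, assuming it parses with json.loads:

import json
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class MiniSGLangArguments:
    sglang_config: Optional[Union[dict, str]] = None

    def __post_init__(self):
        # Accept either a dict or a JSON string such as '{"mem_fraction_static": 0.8}'.
        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
            self.sglang_config = json.loads(self.sglang_config)


args = MiniSGLangArguments(sglang_config='{"mem_fraction_static": 0.8}')
assert args.sglang_config == {"mem_fraction_static": 0.8}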
src/llamafactory/hparams/parser.py
View file @ 24534501
...
@@ -148,10 +148,10 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.4")
+        check_version("vllm>=0.4.3,<=0.8.6")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
-        check_version("sglang>=0.4.4")
+        check_version("sglang>=0.4.5")
         check_version("sglang", mandatory=True)
 
     if finetuning_args.use_galore:
...
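The bumped pins ("vllm>=0.4.3,<=0.8.6", "sglang>=0.4.5") are ordinary PEP 440 range specifiers; a quick sketch of checking such a range with the packaging library (illustrative only, the repository delegates this to require_version):

from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=0.4.3,<=0.8.6")
assert "0.8.6" in spec
assert "0.9.0" not in spec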
src/llamafactory/hparams/training_args.py
View file @ 24534501
...
@@ -64,6 +64,7 @@ class RayArguments:
                 raise ValueError(
                     f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
                 )
 
             import pyarrow.fs as fs
 
             if self.ray_storage_filesystem == "s3":
...
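The pyarrow.fs import feeds a filesystem object to Ray's checkpoint storage; the hunk truncates the branches, so this is a hedged sketch of what each case plausibly constructs (credentials are normally resolved from the environment):

import pyarrow.fs as fs

ray_storage_filesystem = "s3"

if ray_storage_filesystem == "s3":
    storage_filesystem = fs.S3FileSystem()
elif ray_storage_filesystem in ("gs", "gcs"):
    storage_filesystem = fs.GcsFileSystem()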
src/llamafactory/model/model_utils/attention.py
View file @ 24534501
...
@@ -29,10 +29,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
-    if getattr(config, "model_type", None) == "gemma2":
+def configure_attn_implementation(
+    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
...
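For context, the function's job is to map the user-facing flash_attn setting onto one of transformers' attn_implementation strings; a condensed sketch of that decision, assuming the standard string values (the real function also probes SDPA availability and logs warnings before falling back):

from transformers.utils import is_flash_attn_2_available


def resolve_attn_implementation(requested: str) -> str:
    # requested mirrors the AttentionFunction values: "auto", "disabled", "sdpa" or "fa2".
    if requested == "disabled":
        return "eager"
    if requested in ("auto", "fa2") and is_flash_attn_2_available():
        return "flash_attention_2"
    return "sdpa"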
src/llamafactory/model/model_utils/liger_kernel.py
View file @ 24534501
...
@@ -45,16 +45,24 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
     elif model_type == "gemma3_text":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    elif model_type == "paligemma":
-        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
+    elif model_type == "glm4":
+        from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
+    elif model_type == "granite":
+        from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
     elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
+    elif model_type == "llava":
+        from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
     elif model_type == "mistral":
         from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
     elif model_type == "mixtral":
         from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
     elif model_type == "mllama":
         from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
+    elif model_type == "olmo2":
+        from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
+    elif model_type == "paligemma":
+        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
     elif model_type == "phi3":
         from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
     elif model_type == "qwen2":
...
@@ -63,6 +71,8 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     elif model_type == "qwen2_5_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
+    elif model_type == "qwen3":
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return
...
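The if/elif ladder is equivalent to name-based dispatch; a compact sketch of the same idea, using only patch names that appear in the diff and importing lazily so unsupported models cost nothing:

import importlib

# model_type -> liger-kernel patch function, mirroring part of the ladder above.
LIGER_KERNEL_PATCHES = {
    "glm4": "apply_liger_kernel_to_glm4",
    "granite": "apply_liger_kernel_to_granite",
    "llava": "apply_liger_kernel_to_llava",
    "olmo2": "apply_liger_kernel_to_olmo2",
    "paligemma": "apply_liger_kernel_to_paligemma",
    "qwen3": "apply_liger_kernel_to_qwen3",
}


def get_liger_patch(model_type: str):
    func_name = LIGER_KERNEL_PATCHES.get(model_type)
    if func_name is None:
        return None  # caller logs "Current model does not support liger kernel."
    module = importlib.import_module("liger_kernel.transformers")
    return getattr(module, func_name)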
src/llamafactory/model/model_utils/moe.py
View file @ 24534501
...
@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.moe_aux_loss_coef:
+        return
+
     model_type = getattr(config, "model_type", None)
-    if model_args.moe_aux_loss_coef is not None:
     if model_type in [
         "dbrx",
         "granitemoe",
         "jamba",
         "jetmoe",
         "llama4",
         "mixtral",
         "olmoe",
         "phimoe",
         "qwen2_moe",
         "qwen3_moe",
     ]:
-        setattr(config, "output_router_logits", True)
+        setattr(config, "output_router_logits", is_trainable)
 
     if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
     elif model_type == "deepseek":
         setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
     elif model_type == "jetmoe":
         setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
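Why these flags matter: with output_router_logits enabled, transformers MoE models return per-layer router logits so a load-balancing auxiliary loss, scaled by router_aux_loss_coef, can be added to the language-modeling loss. A self-contained sketch of the standard Switch-Transformer-style balancing term (illustrative, not this repository's code):

import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int = 2) -> torch.Tensor:
    """Penalize uneven routing: num_experts * sum(fraction_routed * mean_router_prob)."""
    probs = F.softmax(router_logits, dim=-1)  # (num_tokens, num_experts)
    _, selected = torch.topk(probs, top_k, dim=-1)  # experts chosen per token
    expert_mask = F.one_hot(selected, num_experts).float()  # (num_tokens, top_k, num_experts)
    tokens_per_expert = expert_mask.mean(dim=(0, 1))  # fraction of assignments per expert
    router_prob_per_expert = probs.mean(dim=0)  # mean routing probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)


loss = load_balancing_loss(torch.randn(16, 8), num_experts=8)  # scaled by router_aux_loss_coef in training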
src/llamafactory/model/model_utils/quantization.py
View file @ 24534501
...
@@ -97,7 +97,7 @@ def configure_quantization(
     quant_method = quantization_config.get("quant_method", "")
 
     if quant_method == QuantizationMethod.GPTQ:
-        check_version("auto_gptq>=0.5.0", mandatory=True)
+        check_version("gptqmodel>=2.0.0", mandatory=True)
         quantization_config.pop("disable_exllama", None)  # remove deprecated args
         quantization_config["use_exllama"] = False  # disable exllama
...
@@ -111,12 +111,12 @@ def configure_quantization(
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
 
-    elif model_args.export_quantization_bit is not None:  # auto-gptq
+    elif model_args.export_quantization_bit is not None:  # gptqmodel
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
 
-        check_version("optimum>=1.17.0", mandatory=True)
-        check_version("auto_gptq>=0.5.0", mandatory=True)
+        check_version("optimum>=1.24.0", mandatory=True)
+        check_version("gptqmodel>=2.0.0", mandatory=True)
         from accelerate.utils import get_max_memory
 
         if getattr(config, "model_type", None) == "chatglm":
...
@@ -142,7 +142,8 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+        model_args.compute_dtype = torch.float16  # force fp16 for gptqmodel
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")
 
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BNB:
...
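The export path now targets GPTQModel through transformers' optimum integration; a hedged sketch of what a 4-bit quantized export looks like with the public GPTQConfig API (the model name and output directory are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_name = "Qwen/Qwen3-8B"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,  # the diff forces fp16 for gptqmodel
    quantization_config=gptq_config,
)
model.save_pretrained("qwen3-8b-gptq-int4")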