OpenDAS / LLaMA-Factory

Commit 0722acf1, authored Jun 04, 2025 by chenych
Parent: c4ba4563

    Update 0604

Changes: 68 files in this commit; this page shows 20 changed files with 690 additions and 323 deletions (+690, -323).
src/llamafactory/chat/sglang_engine.py                +15   -1
src/llamafactory/chat/vllm_engine.py                  +0    -1
src/llamafactory/cli.py                               +1    -1
src/llamafactory/data/converter.py                    +21   -6
src/llamafactory/data/data_utils.py                   +28   -27
src/llamafactory/data/loader.py                       +12   -7
src/llamafactory/data/mm_plugin.py                    +101  -174
src/llamafactory/data/parser.py                       +1    -6
src/llamafactory/data/template.py                     +210  -26
src/llamafactory/data/tool_utils.py                   +26   -39
src/llamafactory/extras/constants.py                  +212  -12
src/llamafactory/extras/env.py                        +6    -0
src/llamafactory/extras/misc.py                       +13   -6
src/llamafactory/hparams/data_args.py                 +16   -0
src/llamafactory/hparams/generating_args.py           +1    -5
src/llamafactory/hparams/model_args.py                +10   -4
src/llamafactory/hparams/parser.py                    +2    -2
src/llamafactory/hparams/training_args.py             +1    -0
src/llamafactory/model/model_utils/attention.py       +2    -4
src/llamafactory/model/model_utils/liger_kernel.py    +12   -2
src/llamafactory/chat/sglang_engine.py

@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False
         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
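Taken together, these hunks thread a single LoRA adapter through both server launch and generation. Below is a minimal sketch of the HTTP call the engine now issues, assuming a server already launched with `--lora-paths lora0=...`; the base URL, prompt, and sampling parameters are placeholders, not values from this commit:

    import requests

    # Hedged sketch: when a LoRA adapter was registered at launch as "lora0",
    # the generate request selects it via "lora_request" (key per the hunk above).
    base_url = "http://127.0.0.1:30000"  # assumed default SGLang port
    json_data = {
        "text": "Hello, how are you?",  # placeholder prompt
        "sampling_params": {"max_new_tokens": 64, "temperature": 0.7},
        "stream": True,
        "lora_request": ["lora0"],  # only set when self.lora_request is True
    }
    response = requests.post(f"{base_url}/generate", json=json_data, stream=True)
    if response.status_code != 200:
        raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")

    # the streaming endpoint emits server-sent events; each data line carries
    # a JSON payload with the partial completion
    for line in response.iter_lines(decode_unicode=True):
        if line and line.startswith("data:"):
            print(line)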
src/llamafactory/chat/vllm_engine.py

@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
src/llamafactory/cli.py

@@ -73,7 +73,7 @@ def main():
         "help": partial(print, USAGE),
     }

-    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
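The one-character change above is a real guard fix: `len(sys.argv) >= 1` is always true, since argv always contains at least the program name, so a bare invocation crashed on `pop(1)`. A minimal repro sketch:

    import sys

    # With bare invocation, sys.argv == ["llamafactory-cli"].
    # Old guard: len(sys.argv) >= 1 is always True -> sys.argv.pop(1) raises IndexError.
    # New guard: len(sys.argv) > 1 is False -> falls back to "help".
    sys.argv = ["llamafactory-cli"]  # simulate running with no subcommand
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    print(command)  # "help"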
src/llamafactory/data/converter.py

@@ -51,12 +51,27 @@ class DatasetConverter:
         else:
             medias = medias[:]

-        if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
-            for i in range(len(medias)):
-                if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
-                    medias[i] = os.path.join(self.data_args.media_dir, medias[i])
-                else:
-                    logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
+        if self.dataset_attr.load_from in ["script", "file"]:
+            if isinstance(medias[0], str):
+                for i in range(len(medias)):
+                    media_path = os.path.join(self.data_args.media_dir, medias[i])
+                    if os.path.isfile(media_path):
+                        medias[i] = media_path
+                    else:
+                        logger.warning_rank0_once(
+                            f"Media {medias[i]} does not exist in `media_dir`. Use original path."
+                        )
+            elif isinstance(medias[0], list):  # for processed video frames
+                # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
+                for i in range(len(medias)):
+                    for j in range(len(medias[i])):
+                        media_path = os.path.join(self.data_args.media_dir, medias[i][j])
+                        if os.path.isfile(media_path):
+                            medias[i][j] = media_path
+                        else:
+                            logger.warning_rank0_once(
+                                f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
+                            )

         return medias
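The rewrite extends media-path resolution from flat string lists to nested lists of pre-extracted video frames. A standalone sketch of the same logic, with placeholder directory and file names:

    import os

    def resolve_media_paths(medias: list, media_dir: str) -> list:
        # String entries are single media files; list entries are processed
        # video frames, e.g. [["frame1.jpg", "frame2.jpg"], ...].
        if isinstance(medias[0], str):
            for i in range(len(medias)):
                media_path = os.path.join(media_dir, medias[i])
                if os.path.isfile(media_path):
                    medias[i] = media_path  # otherwise keep the original path
        elif isinstance(medias[0], list):
            for i in range(len(medias)):
                for j in range(len(medias[i])):
                    media_path = os.path.join(media_dir, medias[i][j])
                    if os.path.isfile(media_path):
                        medias[i][j] = media_path
        return medias

    print(resolve_media_paths([["frame1.jpg", "frame2.jpg"]], "data"))  # placeholder paths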
src/llamafactory/data/data_utils.py

@@ -14,7 +14,7 @@
 import json
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union

 import fsspec
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
@@ -142,48 +142,49 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModu
     return dataset_module

-def setup_fs(path, anon=False):
-    """Set up a filesystem object based on the path protocol."""
+def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
+    r"""Set up a filesystem object based on the path protocol."""
     storage_options = {"anon": anon} if anon else {}
     if path.startswith("s3://"):
         fs = fsspec.filesystem("s3", **storage_options)
     elif path.startswith(("gs://", "gcs://")):
         fs = fsspec.filesystem("gcs", **storage_options)
     else:
-        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'")
+        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")

     if not fs.exists(path):
         raise ValueError(f"Path does not exist: {path}.")

     return fs

-def read_cloud_json(cloud_path):
-    """Read a JSON/JSONL file from cloud storage (S3 or GCS).
+def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
+    r"""Helper function to read JSON/JSONL files using fsspec."""
+    with fs.open(path, "r") as f:
+        if path.endswith(".jsonl"):
+            return [json.loads(line) for line in f if line.strip()]
+        else:
+            return json.load(f)
+
+
+def read_cloud_json(cloud_path: str) -> list[Any]:
+    r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).

     Args:
-        cloud_path : str
+        cloud_path: str
             Cloud path in the format:
             - 's3://bucket-name/file.json' for AWS S3
             - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
-        lines : bool, default=True
-            If True, read the file as JSON Lines format (one JSON object per line)
     """
     try:
-        # Try with anonymous access first
-        fs = setup_fs(cloud_path, anon=True)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+        fs = setup_fs(cloud_path, anon=True)  # try with anonymous access first
     except Exception:
-        # Try again with credentials
-        fs = setup_fs(cloud_path)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
-
-
-def _read_json_with_fs(fs, path, lines=True):
-    """Helper function to read JSON/JSONL files using fsspec."""
-    with fs.open(path, "r") as f:
-        if lines:
-            # Read JSONL (JSON Lines) format - one JSON object per line
-            data = [json.loads(line) for line in f if line.strip()]
-        else:
-            # Read regular JSON format
-            data = json.load(f)
+        fs = setup_fs(cloud_path)  # try again with credentials

-    return data
+    # filter out non-JSON files
+    files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
+    files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
+    if not files:
+        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
+
+    return sum([_read_json_with_fs(fs, file) for file in files], [])
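The refactor folds format detection into `_read_json_with_fs` (keyed on the `.jsonl` suffix instead of a `lines` flag) and teaches `read_cloud_json` to accept a directory of JSON/JSONL files, concatenating their records into one list. A hedged usage sketch with placeholder bucket paths:

    from llamafactory.data.data_utils import read_cloud_json  # assumed import path

    # single file: list of records from one JSONL file
    records = read_cloud_json("s3://my-bucket/data/train.jsonl")  # placeholder path

    # directory: every *.json / *.jsonl inside is read and the record
    # lists are summed into one flat list
    records = read_cloud_json("gs://my-bucket/data/")  # placeholder path
    print(len(records))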
src/llamafactory/data/loader.py

@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)

-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)

 def _get_dataset_processor(
@@ -300,13 +300,18 @@ def get_dataset(
         raise ValueError("Turn off `streaming` when saving dataset to disk.")

     # Load and preprocess dataset
-    with training_args.main_process_first(desc="load dataset"):
+    with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
         )

-    with training_args.main_process_first(desc="pre-process dataset"):
+    with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
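These hunks swap the `merge` flag for `return_dict`, driven by the new `eval_on_each_dataset` option, and relax the `main_process_first` barrier to rank-local unless `data_shared_file_system` is set. A toy sketch of the new branch, with plain lists standing in for `Dataset` objects:

    def get_merged(datasets: dict[str, list], return_dict: bool = False):
        # mirrors the new branch in _get_merged_dataset: keep one entry per
        # dataset for per-dataset evaluation, otherwise merge as before
        if return_dict:
            return datasets
        else:
            return sum(datasets.values(), [])  # stand-in for merge_dataset

    splits = {"alpaca_en": [1, 2], "identity": [3]}  # placeholder dataset names
    print(get_merged(splits, return_dict=True))  # {'alpaca_en': [1, 2], 'identity': [3]}
    print(get_merged(splits))                    # [1, 2, 3]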
src/llamafactory/data/mm_plugin.py (+101, -174): diff collapsed in this capture, not shown.
src/llamafactory/data/parser.py

@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
     dataset_list: list[DatasetAttr] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            if use_modelscope():
-                load_from = "ms_hub"
-            elif use_openmind():
-                load_from = "om_hub"
-            else:
-                load_from = "hf_hub"
+            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue
src/llamafactory/data/template.py (+210, -26): diff collapsed in this capture, not shown.
src/llamafactory/data/tool_utils.py

@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
         tool_text = ""
         tool_names = []
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             param_text = ""
             for name, param in tool["parameters"]["properties"].items():
                 required, enum, items = "", "", ""
@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_text = ""
-        for name, arguments in functions:
-            function_text += f"Action: {name}\nAction Input: {arguments}\n"
-
-        return function_text
+        return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])

     @override
     @staticmethod
@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
                 name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
             )
@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
         date = datetime.now().strftime("%d %b %Y")
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"

         return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        if len(functions) > 1:
-            raise ValueError("Llama-3 does not support parallel functions.")
-
-        return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
+        function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
+        return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)

     @override
     @staticmethod
     def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
-            tool = json.loads(content.strip())
+            tools = json.loads(content.strip())
         except json.JSONDecodeError:
             return content

-        if "name" not in tool or "parameters" not in tool:
-            return content
-
-        return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content

 class MistralToolUtils(ToolUtils):
     r"""Mistral v0.3 tool using template."""
@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         wrapped_tools = []
         for tool in tools:
-            wrapped_tools.append({"type": "function", "function": tool})
+            wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})

         return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"

     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
-
-        return "[" + ", ".join(function_texts) + "]"
+        return json.dumps(
+            [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions],
+            ensure_ascii=False,
+        )

     @override
     @staticmethod
@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils):
         except json.JSONDecodeError:
             return content

-        if not isinstance(tools, list):
-            tools = [tools]
-
-        results = []
-        for tool in tools:
-            if "name" not in tool or "arguments" not in tool:
-                return content
-
-            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
-
-        return results
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content

 class QwenToolUtils(ToolUtils):
@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)

         return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(
-                "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
-            )
-
-        return "\n".join(function_texts)
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])

     @override
     @staticmethod
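A recurring pattern across these hunks: formatters now emit real JSON (`json.dumps` over re-parsed argument strings) instead of string interpolation, tool dicts may arrive pre-wrapped with `"type": "function"`, and extractors accept both a single tool call and a list. A standalone sketch of the Qwen formatter, faithful to the hunk above:

    import json

    def qwen_function_formatter(functions: list[tuple[str, str]]) -> str:
        # arguments arrive as JSON strings; re-parsing them makes each
        # <tool_call> body one valid JSON object
        function_texts = [
            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
            for name, arguments in functions
        ]
        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])

    print(qwen_function_formatter([("get_weather", '{"city": "Beijing"}')]))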
src/llamafactory/extras/constants.py

@@ -513,7 +513,7 @@ register_model_group(
 register_model_group(
     models={
-        "DeepSeek-V2-236B-Chat-0628": {
+        "DeepSeek-V2-236B-0628-Chat": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat-0628",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat-0628",
         },
@@ -521,7 +521,7 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5",
         },
-        "DeepSeek-V2.5-236B-Chat-1210": {
+        "DeepSeek-V2.5-236B-1210-Chat": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
         },
@@ -533,6 +533,17 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
         },
+        "DeepSeek-V3-671B-0324-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
+        },
+    },
+    template="deepseek3",
+)
+
+
+register_model_group(
+    models={
         "DeepSeek-R1-1.5B-Distill": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
@@ -545,6 +556,10 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         },
+        "DeepSeek-R1-8B-0528-Distill": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
+        },
         "DeepSeek-R1-14B-Distill": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
@@ -565,8 +580,12 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
         },
+        "DeepSeek-R1-671B-0528-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528",
+        },
     },
-    template="deepseek3",
+    template="deepseekr1",
 )
@@ -673,6 +692,10 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-3-1b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
         },
+        "MedGemma-27B-Instruct": {
+            DownloadSource.DEFAULT: "google/medgemma-27b-text-it",
+            DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
+        },
     },
     template="gemma",
 )
@@ -704,6 +727,14 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-3-27b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
         },
+        "MedGemma-4B": {
+            DownloadSource.DEFAULT: "google/medgemma-4b-pt",
+            DownloadSource.MODELSCOPE: "google/medgemma-4b-pt",
+        },
+        "MedGemma-4B-Instruct": {
+            DownloadSource.DEFAULT: "google/medgemma-4b-it",
+            DownloadSource.MODELSCOPE: "google/medgemma-4b-it",
+        },
     },
     template="gemma3",
     multimodal=True,
@@ -737,6 +768,13 @@ register_model_group(
             DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
         },
+    },
+    template="glm4",
+)
+
+
+register_model_group(
+    models={
         "GLM-Z1-9B-0414-Chat": {
             DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
@@ -746,7 +784,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
         },
     },
-    template="glm4",
+    template="glmz1",
 )
@@ -869,12 +907,13 @@ register_model_group(
 register_model_group(
     models={
-        "Granite-3.2-1B-A400M-Base": {
+        "Granite-Vision-3.2-2B": {
             DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
         },
     },
     template="granite3_vision",
     multimodal=True,
 )
@@ -1398,6 +1437,45 @@ register_model_group(
 )
+register_model_group(
+    models={
+        "MiMo-7B-Base": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
+        },
+        "MiMo-7B-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
+        },
+        "MiMo-7B-Instruct-RL": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
+        },
+        "MiMo-7B-RL-ZERO": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+        },
+    },
+    template="mimo",
+)
+
+
+register_model_group(
+    models={
+        "MiMo-7B-VL-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
+        },
+        "MiMo-7B-VL-RL": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
+        },
+    },
+    template="mimo_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "MiniCPM-2B-SFT-Chat": {
@@ -2461,6 +2539,38 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
         },
+        "Qwen3-0.6B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+        },
+        "Qwen3-1.7B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+        },
+        "Qwen3-4B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
+        },
+        "Qwen3-8B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
+        },
+        "Qwen3-14B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
+        },
+        "Qwen3-32B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
+        },
+        "Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+        },
+        "Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+        },
     },
     template="qwen3",
 )
@@ -2484,10 +2594,22 @@ register_model_group(
 register_model_group(
     models={
+        "Qwen2.5-Omni-3B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
+        },
         "Qwen2.5-Omni-7B": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
-        }
+        },
+        "Qwen2.5-Omni-7B-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+        },
+        "Qwen2.5-Omni-7B-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
+        },
     },
     template="qwen2_omni",
     multimodal=True,
@@ -2598,15 +2720,17 @@ register_model_group(
 register_model_group(
     models={
-        "SOLAR-10.7B-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
+        "Seed-Coder-8B-Base": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
         },
-        "SOLAR-10.7B-Instruct-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
+        "Seed-Coder-8B-Instruct": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
+        },
+        "Seed-Coder-8B-Instruct-Reasoning": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
         },
     },
-    template="solar",
+    template="seed_coder",
 )
@@ -2631,6 +2755,82 @@ register_model_group(
 )
+register_model_group(
+    models={
+        "SmolLM-135M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M",
+        },
+        "SmolLM-360M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M",
+        },
+        "SmolLM-1.7B": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B",
+        },
+        "SmolLM-135M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M-Instruct",
+        },
+        "SmolLM-360M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M-Instruct",
+        },
+        "SmolLM-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B-Instruct",
+        },
+    },
+    template="smollm",
+)
+
+
+register_model_group(
+    models={
+        "SmolLM2-135M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M",
+        },
+        "SmolLM2-360M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M",
+        },
+        "SmolLM2-1.7B": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B",
+        },
+        "SmolLM2-135M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M-Instruct",
+        },
+        "SmolLM2-360M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M-Instruct",
+        },
+        "SmolLM2-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+        },
+    },
+    template="smollm2",
+)
+
+
+register_model_group(
+    models={
+        "SOLAR-10.7B-v1.0": {
+            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
+        },
+        "SOLAR-10.7B-Instruct-v1.0": {
+            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
+        },
+    },
+    template="solar",
+)
+
+
 register_model_group(
     models={
         "StarCoder2-3B": {
src/llamafactory/extras/env.py

@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import platform

 import accelerate
@@ -83,4 +84,9 @@ def print_env() -> None:
     except Exception:
         pass

+    if os.path.exists("data"):
+        info["Default data directory"] = "detected"
+    else:
+        info["Default data directory"] = "not detected"
+
     print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
src/llamafactory/extras/misc.py

@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return

+    if "gptmodel" in requirement or "autoawq" in requirement:
+        pip_command = f"pip install {requirement} --no-build-isolation"
+    else:
+        pip_command = f"pip install {requirement}"
+
     if mandatory:
-        hint = f"To fix: run `pip install {requirement}`."
+        hint = f"To fix: run `{pip_command}`."
     else:
-        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

     require_version(requirement, hint)


 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.5.0")
-    check_version("accelerate>=0.34.0,<=1.6.0")
-    check_version("peft>=0.14.0,<=0.15.1")
+    check_version("transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0")
+    check_version("datasets>=2.16.0,<=3.6.0")
+    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
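The check_version change is about build isolation: packages such as autoawq compile extensions against the locally installed torch, so the suggested fix now appends `--no-build-isolation`. A sketch of the hint construction on its own:

    def build_pip_hint(requirement: str, mandatory: bool = False) -> str:
        # mirrors the new logic in check_version
        if "gptmodel" in requirement or "autoawq" in requirement:
            pip_command = f"pip install {requirement} --no-build-isolation"
        else:
            pip_command = f"pip install {requirement}"

        if mandatory:
            return f"To fix: run `{pip_command}`."
        return f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

    print(build_pip_hint("autoawq"))  # hint includes --no-build-isolation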
src/llamafactory/hparams/data_args.py

@@ -99,6 +99,10 @@ class DataArguments:
         default=0.0,
         metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
     )
+    eval_on_each_dataset: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to evaluate on each dataset separately."},
+    )
     packing: Optional[bool] = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
@@ -111,6 +115,14 @@ class DataArguments:
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Override the default system message in the template."},
+    )
+    enable_thinking: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={
@@ -121,6 +133,10 @@ class DataArguments:
             )
         },
     )
+    data_shared_file_system: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use a shared file system for the datasets."},
+    )

     def __post_init__(self):
         def split_arg(arg):
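Note that `default_system` moves here from GeneratingArguments (see the generating_args.py hunk below), alongside the new `enable_thinking`, `eval_on_each_dataset`, and `data_shared_file_system` knobs. A hedged sketch of setting them programmatically; in practice they come from the training YAML or CLI, and the import path is assumed:

    from llamafactory.hparams import DataArguments  # assumed public import path

    args = DataArguments(
        dataset="alpaca_en",                            # placeholder dataset name
        default_system="You are a helpful assistant.",  # overrides the template's system message
        enable_thinking=False,                          # turn off reasoning-mode markup
        eval_on_each_dataset=True,                      # report eval metrics per dataset
        data_shared_file_system=False,                  # keep per-node dataset barriers
    )
    print(args.default_system)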
src/llamafactory/hparams/generating_args.py

@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any

 from transformers import GenerationConfig
@@ -62,10 +62,6 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    default_system: Optional[str] = field(
-        default=None,
-        metadata={"help": "Default system message to use in chat completion."},
-    )
     skip_special_tokens: bool = field(
         default=True,
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
src/llamafactory/hparams/model_args.py

@@ -235,10 +235,6 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Whether to crop the image to patches for internvl."},
     )
-    use_audio_in_video: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use audio in video inputs."},
-    )
     video_max_pixels: int = field(
         default=256 * 256,
         metadata={"help": "The maximum number of pixels of video inputs."},
@@ -255,6 +251,10 @@ class ProcessorArguments:
         default=128,
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
+    use_audio_in_video: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use audio in video inputs."},
+    )
     audio_sampling_rate: int = field(
         default=16000,
         metadata={"help": "The sampling rate of audio inputs."},
@@ -364,6 +364,12 @@ class SGLangArguments:
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
+    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
+        default="triton",
+        metadata={
+            "help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
+        },
+    )

     def __post_init__(self):
         if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
src/llamafactory/hparams/parser.py

@@ -148,10 +148,10 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.4")
+        check_version("vllm>=0.4.3,<=0.8.6")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
-        check_version("sglang>=0.4.4")
+        check_version("sglang>=0.4.5")
         check_version("sglang", mandatory=True)

     if finetuning_args.use_galore:
src/llamafactory/hparams/training_args.py

@@ -64,6 +64,7 @@ class RayArguments:
                 raise ValueError(
                     f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
                 )

+            import pyarrow.fs as fs

             if self.ray_storage_filesystem == "s3":
src/llamafactory/model/model_utils/attention.py

@@ -29,10 +29,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
src/llamafactory/model/model_utils/liger_kernel.py

@@ -45,16 +45,24 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
     elif model_type == "gemma3_text":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    elif model_type == "paligemma":
-        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
+    elif model_type == "glm4":
+        from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
     elif model_type == "granite":
         from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
     elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
+    elif model_type == "llava":
+        from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
     elif model_type == "mistral":
         from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
     elif model_type == "mixtral":
         from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
     elif model_type == "mllama":
         from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
+    elif model_type == "olmo2":
+        from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
+    elif model_type == "paligemma":
+        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
     elif model_type == "phi3":
         from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
     elif model_type == "qwen2":
@@ -63,6 +71,8 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     elif model_type == "qwen2_5_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
+    elif model_type == "qwen3":
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return