OpenDAS / LLaMA-Factory · Commits · 0722acf1

Commit 0722acf1 authored Jun 04, 2025 by chenych

    Update 0604

parent c4ba4563

Showing 20 changed files with 690 additions and 323 deletions (+690 -323)
src/llamafactory/chat/sglang_engine.py                +15   -1
src/llamafactory/chat/vllm_engine.py                   +0   -1
src/llamafactory/cli.py                                +1   -1
src/llamafactory/data/converter.py                    +21   -6
src/llamafactory/data/data_utils.py                   +28  -27
src/llamafactory/data/loader.py                       +12   -7
src/llamafactory/data/mm_plugin.py                   +101 -174
src/llamafactory/data/parser.py                        +1   -6
src/llamafactory/data/template.py                    +210  -26
src/llamafactory/data/tool_utils.py                   +26  -39
src/llamafactory/extras/constants.py                 +212  -12
src/llamafactory/extras/env.py                         +6   -0
src/llamafactory/extras/misc.py                       +13   -6
src/llamafactory/hparams/data_args.py                 +16   -0
src/llamafactory/hparams/generating_args.py            +1   -5
src/llamafactory/hparams/model_args.py                +10   -4
src/llamafactory/hparams/parser.py                     +2   -2
src/llamafactory/hparams/training_args.py              +1   -0
src/llamafactory/model/model_utils/attention.py        +2   -4
src/llamafactory/model/model_utils/liger_kernel.py    +12   -2
src/llamafactory/chat/sglang_engine.py @ 0722acf1

@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False
         launch_cmd = [
             "python3 -m sglang.launch_server",

@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:

@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)

@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
        if response.status_code != 200:
             raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
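For reference, a minimal sketch of the request flow this change sets up, assuming a server already launched by the engine and an adapter registered as lora0 via --lora-paths (the address, prompt and sampling values below are placeholders, only the lora_request key mirrors the diff):

import requests

base_url = "http://127.0.0.1:30000"  # assumed address of the launched SGLang server
json_data = {
    "text": "Hello, world!",
    "sampling_params": {"max_new_tokens": 32, "temperature": 0.7},
    "stream": False,
    "lora_request": ["lora0"],  # same adapter name registered at launch time
}
response = requests.post(f"{base_url}/generate", json=json_data)
if response.status_code != 200:
    raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
print(response.json())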
src/llamafactory/chat/vllm_engine.py @ 0722acf1

@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
src/llamafactory/cli.py @ 0722acf1

@@ -73,7 +73,7 @@ def main():
         "help": partial(print, USAGE),
     }
-    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
src/llamafactory/data/converter.py @ 0722acf1

@@ -51,12 +51,27 @@ class DatasetConverter:
         else:
             medias = medias[:]

-        if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
-            for i in range(len(medias)):
-                if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
-                    medias[i] = os.path.join(self.data_args.media_dir, medias[i])
-                else:
-                    logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
+        if self.dataset_attr.load_from in ["script", "file"]:
+            if isinstance(medias[0], str):
+                for i in range(len(medias)):
+                    media_path = os.path.join(self.data_args.media_dir, medias[i])
+                    if os.path.isfile(media_path):
+                        medias[i] = media_path
+                    else:
+                        logger.warning_rank0_once(
+                            f"Media {medias[i]} does not exist in `media_dir`. Use original path."
+                        )
+            elif isinstance(medias[0], list):
+                # for processed video frames
+                # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
+                for i in range(len(medias)):
+                    for j in range(len(medias[i])):
+                        media_path = os.path.join(self.data_args.media_dir, medias[i][j])
+                        if os.path.isfile(media_path):
+                            medias[i][j] = media_path
+                        else:
+                            logger.warning_rank0_once(
+                                f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
+                            )

         return medias
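The new elif branch means a dataset row may supply a "video" as an ordered list of pre-extracted frame paths, each resolved against media_dir when the file exists. A hedged illustration of the accepted data shape (file names are invented for the example):

sample = {
    "messages": [
        {"role": "user", "content": "<video>What happens in this clip?"},
        {"role": "assistant", "content": "A cat jumps onto the table."},
    ],
    # one "video" given as a list of frame images instead of a single video file
    "videos": [["clip0/frame_001.jpg", "clip0/frame_002.jpg", "clip0/frame_003.jpg"]],
}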
src/llamafactory/data/data_utils.py @ 0722acf1

@@ -14,7 +14,7 @@
 import json
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union

 import fsspec
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets

@@ -142,48 +142,49 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
     return dataset_module


-def setup_fs(path, anon=False):
-    """Set up a filesystem object based on the path protocol."""
+def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
+    r"""Set up a filesystem object based on the path protocol."""
     storage_options = {"anon": anon} if anon else {}
     if path.startswith("s3://"):
         fs = fsspec.filesystem("s3", **storage_options)
     elif path.startswith(("gs://", "gcs://")):
         fs = fsspec.filesystem("gcs", **storage_options)
     else:
-        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'")
+        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")

     if not fs.exists(path):
         raise ValueError(f"Path does not exist: {path}.")

     return fs


-def read_cloud_json(cloud_path):
-    """Read a JSON/JSONL file from cloud storage (S3 or GCS).
+def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
+    r"""Helper function to read JSON/JSONL files using fsspec."""
+    with fs.open(path, "r") as f:
+        if path.endswith(".jsonl"):
+            return [json.loads(line) for line in f if line.strip()]
+        else:
+            return json.load(f)
+
+
+def read_cloud_json(cloud_path: str) -> list[Any]:
+    r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).

     Args:
-        cloud_path : str
+        cloud_path: str
             Cloud path in the format:
             - 's3://bucket-name/file.json' for AWS S3
             - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
-        lines : bool, default=True
-            If True, read the file as JSON Lines format (one JSON object per line)
     """
     try:
-        # Try with anonymous access first
-        fs = setup_fs(cloud_path, anon=True)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+        fs = setup_fs(cloud_path, anon=True)  # try with anonymous access first
     except Exception:
-        # Try again with credentials
-        fs = setup_fs(cloud_path)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
-
-
-def _read_json_with_fs(fs, path, lines=True):
-    """Helper function to read JSON/JSONL files using fsspec."""
-    with fs.open(path, "r") as f:
-        if lines:
-            # Read JSONL (JSON Lines) format - one JSON object per line
-            data = [json.loads(line) for line in f if line.strip()]
-        else:
-            # Read regular JSON format
-            data = json.load(f)
+        fs = setup_fs(cloud_path)  # try again with credentials

-    return data
+    # filter out non-JSON files
+    files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
+    files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
+    if not files:
+        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
+
+    return sum([_read_json_with_fs(fs, file) for file in files], [])
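A hedged usage sketch of the reworked reader (bucket and object names below are placeholders): per the diff, read_cloud_json now also accepts a directory-like prefix and concatenates every .json/.jsonl object it finds there.

from llamafactory.data.data_utils import read_cloud_json

records = read_cloud_json("s3://my-bucket/data/train.jsonl")  # one dict per JSONL line
shards = read_cloud_json("gs://my-bucket/datasets/")          # all .json/.jsonl files under the prefix
print(len(records), len(shards))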
src/llamafactory/data/loader.py @ 0722acf1

@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:

@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)

-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)


 def _get_dataset_processor(

@@ -300,13 +300,18 @@ def get_dataset(
         raise ValueError("Turn off `streaming` when saving dataset to disk.")

     # Load and preprocess dataset
-    with training_args.main_process_first(desc="load dataset"):
+    with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
         )

-    with training_args.main_process_first(desc="pre-process dataset"):
+    with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
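For orientation, the two data arguments referenced here are presumably introduced in src/llamafactory/hparams/data_args.py, which this commit also touches but whose diff is not expanded on this page. A hedged sketch of how they might be set in a run configuration dict (values are illustrative only):

data_overrides = {
    "eval_dataset": "alpaca_en_demo,identity",  # two evaluation datasets
    "eval_on_each_dataset": True,               # keep a per-dataset dict of eval splits (return_dict=True)
    "data_shared_file_system": False,           # ranks do not share storage, so the barrier stays node-local
}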
src/llamafactory/data/mm_plugin.py @ 0722acf1

@@ -17,6 +17,7 @@
 import inspect
 import math
+import os
 import re
 from copy import deepcopy
 from dataclasses import dataclass

@@ -25,7 +26,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 import numpy as np
 import torch
-from transformers.image_utils import get_image_size, to_numpy_array
+from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
 from typing_extensions import override

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER

@@ -57,7 +58,10 @@ if is_transformers_version_greater_than("4.45.0"):
     )

-if is_transformers_version_greater_than("4.49.0"):
+if is_transformers_version_greater_than("4.52.0"):
+    from transformers.image_utils import make_flat_list_of_images
+    from transformers.video_utils import make_batched_videos
+elif is_transformers_version_greater_than("4.49.0"):
     from transformers.image_utils import make_batched_videos, make_flat_list_of_images

@@ -73,7 +77,7 @@ if TYPE_CHECKING:
         bytes: Optional[bytes]

     ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
-    VideoInput = Union[str, BinaryIO]
+    VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
     AudioInput = Union[str, BinaryIO, NDArray]

     class MMProcessor(ProcessorMixin):

@@ -131,6 +135,11 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list:
     return batch_images


+def _check_video_is_nested_images(video: "VideoInput") -> bool:
+    r"""Check if the video is nested images."""
+    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
+
+
 @dataclass
 class MMPluginMixin:
     image_token: Optional[str]

@@ -167,16 +176,45 @@ class MMPluginMixin:
             )

         if self.image_token is not None and processor is None:
-            raise ValueError("Processor was not found, please check and update your processor config.")
+            raise ValueError("Processor was not found, please check and update your model file.")

         if self.image_token is not None and image_processor is None:
-            raise ValueError("Image processor was not found, please check and update your processor config.")
+            raise ValueError("Image processor was not found, please check and update your model file.")

         if self.video_token is not None and video_processor is None:
-            raise ValueError("Video processor was not found, please check and update your processor config.")
+            raise ValueError("Video processor was not found, please check and update your model file.")

         if self.audio_token is not None and feature_extractor is None:
-            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+            raise ValueError("Audio feature extractor was not found, please check and update your model file.")
+
+    def _validate_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ):
+        r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
+        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+        for message in messages:
+            num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
+            num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
+            num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}.")
+
+        if len(audios) != num_audio_tokens:
+            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}.")

     def _preprocess_image(
         self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs

@@ -234,14 +272,20 @@ class MMPluginMixin:
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
         for video in videos:
-            container = av.open(video, "r")
-            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: list[ImageObject] = []
-            container.seek(0)
-            for frame_idx, frame in enumerate(container.decode(video_stream)):
-                if frame_idx in sample_indices:
-                    frames.append(frame.to_image())
+            if _check_video_is_nested_images(video):
+                for frame in video:
+                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+                        raise ValueError("Invalid image found in video frames.")
+
+                frames = video
+            else:
+                container = av.open(video, "r")
+                video_stream = next(stream for stream in container.streams if stream.type == "video")
+                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                container.seek(0)
+                for frame_idx, frame in enumerate(container.decode(video_stream)):
+                    if frame_idx in sample_indices:
+                        frames.append(frame.to_image())

             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
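Taken together, the widened VideoInput type and _check_video_is_nested_images let callers pass a video either as a file path or as an already-sampled list of frames. A hedged sketch of the two accepted shapes (paths are made up for the example):

videos_as_file = ["data/media/clip0.mp4"]              # decoded and sampled via av, as before
videos_as_frames = [["clip0/f1.jpg", "clip0/f2.jpg"]]  # nested images, validated and used as-is
messages = [{"role": "user", "content": "<video>Describe the clip."}]
# Either form satisfies _validate_messages, since exactly one video placeholder
# in the messages matches exactly one entry in the videos list.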
@@ -420,6 +464,7 @@ class Gemma3Plugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         boi_token: str = getattr(processor, "boi_token")

@@ -446,9 +491,6 @@ class Gemma3Plugin(BasePlugin):
                 message["content"] = content.replace("{{image}}", image_str)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override

@@ -495,14 +537,14 @@ class InternVLPlugin(BasePlugin):
         mm_inputs = {}
         image_video_patches = []
-        if len(images) != 0 and isinstance(images[0], str):
+        if len(images) != 0:
             images = self._regularize_images(
                 images,
                 image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
                 image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
             )["images"]
-        if len(videos) != 0 and isinstance(videos[0], str):
+        if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
                 image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),

@@ -566,8 +608,8 @@ class InternVLPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens, num_video_tokens = 0, 0
         image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)

@@ -579,9 +621,6 @@ class InternVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
                     f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",

@@ -590,9 +629,6 @@ class InternVLPlugin(BasePlugin):
                 num_image_tokens += 1

             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
                 end_patch_index = video_patch_indices[num_video_tokens]
                 num_patches = list(video_num_patches[current_patch_index:end_patch_index])

@@ -605,12 +641,6 @@ class InternVLPlugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages

     @override

@@ -637,10 +667,13 @@ class KimiVLPlugin(BasePlugin):
     @override
     def process_messages(self, messages, images, videos, audios, processor):
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_hws = mm_inputs.get("image_grid_hws", [])
+        else:
+            image_grid_hws = [None] * len(images)

-        image_grid_hws = mm_inputs.get("image_grid_hws", [])
         num_image_tokens = 0
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
         merge_length = math.prod(image_processor.merge_kernel_size)

@@ -648,9 +681,6 @@ class KimiVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER,

@@ -661,9 +691,6 @@ class KimiVLPlugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
@@ -679,6 +706,7 @@ class Llama4Plugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
             if "pixel_values" in mm_inputs:

@@ -701,9 +729,6 @@ class Llama4Plugin(BasePlugin):
             for local_image_index, split_part in enumerate(prompt_splits):
                 new_content.append(split_part)
                 if local_image_index < placeholder_count:
-                    if num_image_tokens >= len(images):
-                        raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                     tokens_for_this_image = processor._prompt_split_image(
                         aspect_ratios[num_image_tokens], num_patches_per_chunk
                     )

@@ -716,9 +741,6 @@ class Llama4Plugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override

@@ -751,7 +773,7 @@ class LlavaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)

@@ -768,17 +790,10 @@ class LlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                num_image_tokens += 1

             message["content"] = content.replace("{{image}}", self.image_token)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

@@ -794,6 +809,7 @@ class LlavaNextPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:

@@ -805,9 +821,6 @@ class LlavaNextPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)

@@ -821,9 +834,6 @@ class LlavaNextPlugin(BasePlugin):
             message["content"] = content.replace("{{image}}", self.image_token)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

@@ -839,7 +849,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens, num_video_tokens = 0, 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)

@@ -850,9 +860,6 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)

@@ -862,7 +869,6 @@ class LlavaNextVideoPlugin(BasePlugin):
                     image_seqlen = 1

                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                num_image_tokens += 1

             message["content"] = content.replace("{{image}}", self.image_token)

@@ -879,20 +885,10 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
-                num_video_tokens += 1

             message["content"] = content.replace("{{video}}", self.video_token)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
@@ -978,6 +974,7 @@ class MiniCPMVPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")

@@ -996,24 +993,15 @@ class MiniCPMVPlugin(BasePlugin):
         for i, message in enumerate(messages):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                 num_image_tokens += 1

             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                 content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                 num_video_tokens += 1

             while AUDIO_PLACEHOLDER in content:
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                 content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                 num_audio_tokens += 1

@@ -1065,15 +1053,6 @@ class MiniCPMVPlugin(BasePlugin):
             final_text += text_chunks[-1]
             messages[index]["content"] = final_text

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
         return messages

     @override

@@ -1157,6 +1136,7 @@ class MllamaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:

@@ -1164,9 +1144,6 @@ class MllamaPlugin(BasePlugin):
             num_image_tokens += content.count(IMAGE_PLACEHOLDER)
             message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override

@@ -1214,6 +1191,7 @@ class PaliGemmaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:

@@ -1224,9 +1202,6 @@ class PaliGemmaPlugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override

@@ -1281,7 +1256,7 @@ class PixtralPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)

@@ -1291,15 +1266,13 @@ class PixtralPlugin(BasePlugin):
             image_sizes = iter(mm_inputs["image_sizes"][0])
         else:
             image_sizes = iter(mm_inputs["image_sizes"].tolist())

         image_break_token: str = getattr(processor, "image_break_token")
         image_end_token: str = getattr(processor, "image_end_token")

         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 if self.expand_mm_tokens:
                     height, width = next(image_sizes)
                     num_height_tokens = height // processor.patch_size

@@ -1312,13 +1285,9 @@ class PixtralPlugin(BasePlugin):
                     replace_str = self.image_token

                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
-                num_image_tokens += 1

             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override
@@ -1355,9 +1324,9 @@ class Qwen2AudioPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         bos_token: str = getattr(processor, "audio_bos_token")
         eos_token: str = getattr(processor, "audio_eos_token")
-        num_audio_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs([], [], audios, processor)

@@ -1367,9 +1336,6 @@ class Qwen2AudioPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while AUDIO_PLACEHOLDER in content:
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     audio_length = audio_lengths.pop(0)
                     input_length = (audio_length - 1) // 2 + 1

@@ -1380,13 +1346,9 @@ class Qwen2AudioPlugin(BasePlugin):
                 content = content.replace(AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1)
-                num_audio_tokens += 1

             message["content"] = content

-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
         return messages

     @override

@@ -1430,24 +1392,33 @@ class Qwen2VLPlugin(BasePlugin):
     ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
         results, fps_per_video = [], []
         for video in videos:
-            container = av.open(video, "r")
-            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: list[ImageObject] = []
-            container.seek(0)
-            for frame_idx, frame in enumerate(container.decode(video_stream)):
-                if frame_idx in sample_indices:
-                    frames.append(frame.to_image())
+            if _check_video_is_nested_images(video):
+                for frame in video:
+                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+                        raise ValueError("Invalid image found in video frames.")
+
+                frames = video
+                fps_per_video.append(kwargs.get("video_fps", 2.0))
+            else:
+                container = av.open(video, "r")
+                video_stream = next(stream for stream in container.streams if stream.type == "video")
+                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                container.seek(0)
+                for frame_idx, frame in enumerate(container.decode(video_stream)):
+                    if frame_idx in sample_indices:
+                        frames.append(frame.to_image())
+
+                if video_stream.duration is None:
+                    fps_per_video.append(kwargs.get("video_fps", 2.0))
+                else:
+                    fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

             if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
                 frames.append(frames[-1])

             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
-            if video_stream.duration is None:
-                fps_per_video.append(2.0)
-            else:
-                fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

         return {"videos": results, "fps_per_video": fps_per_video}

@@ -1494,6 +1465,7 @@ class Qwen2VLPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")

@@ -1510,9 +1482,6 @@ class Qwen2VLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1

@@ -1520,9 +1489,6 @@ class Qwen2VLPlugin(BasePlugin):
                 num_image_tokens += 1

             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1

@@ -1531,12 +1497,6 @@ class Qwen2VLPlugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
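As an aside, the even-frame rule kept in the hunk above applies to both decoded videos and pre-supplied frame lists; a hedged standalone illustration (frame names invented):

frames = ["f1.jpg", "f2.jpg", "f3.jpg"]    # hypothetical sampled frames
if len(frames) % 2 != 0:                   # qwen2-vl requires an even number of frames
    frames.append(frames[-1])              # pad by duplicating the last frame
assert len(frames) % 2 == 0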
@@ -1602,6 +1562,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)

@@ -1624,9 +1585,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1

@@ -1642,11 +1600,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             )

             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                 video_pos = content.find(VIDEO_PLACEHOLDER)
                 audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
                 if audio_pos == -1 or audio_pos < video_pos:

@@ -1688,9 +1641,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     num_video_tokens += 1
             else:
                 while AUDIO_PLACEHOLDER in content:
-                    if num_audio_tokens >= len(audios):
-                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                     audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
                     content = content.replace(
                         AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1

@@ -1698,9 +1648,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     num_audio_tokens += 1

                 while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(videos):
-                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                     video_seqlen = (
                         video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                     )

@@ -1711,15 +1658,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             message["content"] = content

-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages

@@ -1735,6 +1673,7 @@ class VideoLlavaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         num_frames = 0

@@ -1762,28 +1701,16 @@ class VideoLlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                num_image_tokens += 1

             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
-                num_video_tokens += 1

             content = content.replace("{{image}}", self.image_token)
             message["content"] = content.replace("{{video}}", self.video_token)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
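Across these plugins the scattered per-placeholder checks are replaced by the single up-front _validate_messages call. A hedged sketch of the failure it now raises (counts invented, and the image placeholder is written literally as <image> here):

messages = [{"role": "user", "content": "<image><image>Compare these."}]
images = ["a.jpg"]  # only one image supplied for two placeholders
# plugin.process_messages(messages, images, [], [], processor) would raise:
#   ValueError: The number of images does not match the number of <image> tokens in [...]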
src/llamafactory/data/parser.py @ 0722acf1

@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list:
     dataset_list: list[DatasetAttr] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            if use_modelscope():
-                load_from = "ms_hub"
-            elif use_openmind():
-                load_from = "om_hub"
-            else:
-                load_from = "hf_hub"
-
+            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue
src/llamafactory/data/template.py @ 0722acf1

@@ -13,6 +13,7 @@
 # limitations under the License.

 import re
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

@@ -51,6 +52,7 @@ class Template:
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
+    enable_thinking: Optional[bool]
     mm_plugin: "BasePlugin"

     def encode_oneturn(

@@ -61,7 +63,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids

@@ -77,7 +79,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:

@@ -92,6 +94,19 @@ class Template:
         return list(stop_token_ids)

+    def add_thought(self, content: str = "") -> str:
+        r"""Add empty thought to assistant message."""
+        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
+
+    def remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
+    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Get the token ids of thought words."""
+        return tokenizer.encode(self.add_thought(), add_special_tokens=False)
+
     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
         r"""Convert elements to token ids."""
         token_ids = []

@@ -111,18 +126,12 @@ class Template:
         return token_ids

-    def _remove_thought(self, content: str) -> str:
-        r"""Remove thought from assistant message."""
-        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
-        return re.sub(pattern, "", content).lstrip("\n")
-
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.

@@ -140,18 +149,14 @@ class Template:
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 elements += self.format_system.apply(content=(system + tool_text))

-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=content, idx=str(i // 2))
+                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -162,6 +167,9 @@ class Template:
     @staticmethod
     def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
         r"""Add or replace eos token to the tokenizer."""
+        if tokenizer.eos_token == eos_token:
+            return
+
         is_added = tokenizer.eos_token_id is None
         num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

@@ -328,7 +336,6 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
-        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []

@@ -342,18 +349,14 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]

-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=system_text + content)
+                elements += self.format_user.apply(content=system_text + message["content"])
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -392,6 +395,64 @@ class Llama2Template(Template):
         return jinja_template


+@dataclass
+class ReasoningTemplate(Template):
+    r"""A template that add thought to assistant message."""
+
+    @override
+    def encode_oneturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> tuple[list[int], list[int]]:
+        messages = deepcopy(messages)
+        for i in range(1, len(messages) - 2, 2):
+            messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        if self.enable_thinking is False:  # remove all cot
+            messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
+        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+        if (
+            self.thought_words[0] not in messages[-1]["content"]
+            and self.thought_words[1] not in messages[-1]["content"]
+        ):  # add empty cot
+            if not self.enable_thinking:  # do not compute loss
+                prompt_ids += self.get_thought_word_ids(tokenizer)
+            else:  # do compute loss
+                response_ids = self.get_thought_word_ids(tokenizer) + response_ids
+
+        return prompt_ids, response_ids
+
+    @override
+    def encode_multiturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> list[tuple[list[int], list[int]]]:
+        messages = deepcopy(messages)
+        if self.enable_thinking is False:  # remove all cot
+            for i in range(1, len(messages), 2):
+                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        for i in range(0, len(messages), 2):
+            if (
+                self.thought_words[0] not in messages[i + 1]["content"]
+                and self.thought_words[1] not in messages[i + 1]["content"]
+            ):  # add empty cot
+                if not self.enable_thinking:  # do not compute loss
+                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                else:  # do compute loss
+                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
+
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+
 TEMPLATES: dict[str, "Template"] = {}
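A hedged illustration of what the new enable_thinking switch does at encoding time (the template and tokenizer objects are placeholders, and thought_words are assumed to be the usual <think>/</think> pair):

# enable_thinking=True and the target has no thought block:
#   the empty "<think>\n\n</think>\n\n" tokens are prepended to response_ids,
#   so loss is computed on the (empty) thought.
# enable_thinking=False:
#   chain-of-thought is stripped from history and the empty thought tokens are
#   appended to prompt_ids instead, so no loss is computed on them.
messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},  # no thought block in the target
]
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)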
@@ -410,6 +471,7 @@ def register_template(
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
+    enable_thinking: Optional[bool] = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
 ) -> None:

@@ -456,6 +518,7 @@ def register_template(
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
+        enable_thinking=enable_thinking,
         mm_plugin=mm_plugin,
     )

@@ -492,6 +555,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
     assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
     if len(user_slot) > len(user_slot_empty_system):

@@ -501,7 +565,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     else:  # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
         default_system = ""

-    return Template(
+    return template_class(
         format_user=StringFormatter(slots=[user_slot]),
         format_assistant=StringFormatter(slots=[assistant_slot]),
         format_system=StringFormatter(slots=[system_slot]),

@@ -515,6 +579,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
         efficient_eos=False,
         replace_eos=False,
         replace_jinja_template=False,
+        enable_thinking=True,
         mm_plugin=get_mm_plugin(name="base"),
     )

@@ -543,6 +608,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

+    if data_args.default_system is not None:
+        logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+        template.default_system = data_args.default_system
+
+    template.enable_thinking = data_args.enable_thinking
     template.fix_special_tokens(tokenizer)
     template.fix_jinja_template(tokenizer)
     return template

@@ -756,6 +826,7 @@ register_template(
         "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
     ),
     stop_words=["<|im_end|>"],
     replace_eos=True,
 )

@@ -774,6 +845,15 @@ register_template(
 )


+# copied from deepseek3 template
+register_template(
+    name="deepseekr1",
+    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=ReasoningTemplate,
+)
+
+
 register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),

@@ -838,6 +918,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
     replace_eos=True,
+    template_class=Llama2Template,
 )

@@ -853,6 +934,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
     replace_eos=True,
     mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
+    template_class=Llama2Template,
 )
@@ -872,6 +954,22 @@ register_template(
)
# copied from glm4 template
register_template
(
name
=
"glmz1"
,
format_user
=
StringFormatter
(
slots
=
[
"<|user|>
\n
{{content}}<|assistant|>"
]),
format_assistant
=
StringFormatter
(
slots
=
[
"
\n
{{content}}"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|system|>
\n
{{content}}"
]),
format_function
=
FunctionFormatter
(
slots
=
[
"{{content}}"
],
tool_format
=
"glm4"
),
format_observation
=
StringFormatter
(
slots
=
[
"<|observation|>
\n
{{content}}<|assistant|>"
]),
format_tools
=
ToolFormatter
(
tool_format
=
"glm4"
),
format_prefix
=
EmptyFormatter
(
slots
=
[
"[gMASK]<sop>"
]),
stop_words
=
[
"<|user|>"
,
"<|observation|>"
],
efficient_eos
=
True
,
template_class
=
ReasoningTemplate
,
)
register_template
(
name
=
"granite3"
,
format_user
=
StringFormatter
(
...
...
@@ -973,6 +1071,7 @@ register_template(
stop_words
=
[
"<|im_end|>"
],
thought_words
=
(
"◁think▷"
,
"◁/think▷"
),
mm_plugin
=
get_mm_plugin
(
"kimi_vl"
,
image_token
=
"<|media_pad|>"
),
template_class
=
ReasoningTemplate
,
)
...
...
@@ -1018,6 +1117,7 @@ register_template(
format_tools
=
ToolFormatter
(
tool_format
=
"llama3"
),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
stop_words
=
[
"<|eot_id|>"
,
"<|eom_id|>"
],
replace_eos
=
True
,
)
...
...
@@ -1037,6 +1137,7 @@ register_template(
format_tools
=
ToolFormatter
(
tool_format
=
"llama3"
),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
stop_words
=
[
"<|eot|>"
,
"<|eom|>"
],
replace_eos
=
True
,
mm_plugin
=
get_mm_plugin
(
name
=
"llama4"
,
image_token
=
"<|image|>"
),
)
...
...
@@ -1066,6 +1167,7 @@ register_template(
format_tools
=
ToolFormatter
(
tool_format
=
"llama3"
),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
stop_words
=
[
"<|eot_id|>"
,
"<|eom_id|>"
],
replace_eos
=
True
,
mm_plugin
=
get_mm_plugin
(
name
=
"mllama"
,
image_token
=
"<|image|>"
),
)
...
...
@@ -1079,6 +1181,7 @@ register_template(
format_system
=
StringFormatter
(
slots
=
[
"<|im_system|>system<|im_middle|>{{content}}<|im_end|>"
]),
default_system
=
"You are a helpful assistant provided by Moonshot-AI."
,
stop_words
=
[
"<|im_end|>"
],
replace_eos
=
True
,
)
...
...
@@ -1131,6 +1234,7 @@ register_template(
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>", "<|eom_id|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
...
...
@@ -1163,6 +1267,7 @@ register_template(
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
...
...
@@ -1233,6 +1338,42 @@ register_template(
)


# copied from qwen template
register_template(
    name="mimo",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    template_class=ReasoningTemplate,
)


# copied from qwen2vl
register_template(
    name="mimo_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are MiMo, an AI assistant developed by Xiaomi.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
    template_class=ReasoningTemplate,
)
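As a quick illustration (not part of the commit), here is how the mimo slots above assemble a single turn; the system and user strings are made-up examples:

# Illustrative sketch only: manually expanding the "mimo" slots shown above.
system = "You are a helpful assistant."
question = "What is 2 + 2?"
prompt = (
    f"<|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{question}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
print(prompt)  # decoding stops at the "<|im_end|>" stop word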
# copied from chatml template
register_template(
    name="minicpm_v",
...
...
@@ -1363,6 +1504,7 @@ register_template(
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
    template_class=Llama2Template,
)
...
...
@@ -1374,6 +1516,7 @@ register_template(
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    stop_words=["<|end|>"],
    replace_eos=True,
)
...
...
@@ -1384,6 +1527,7 @@ register_template(
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
    stop_words=["<|end|>"],
    replace_eos=True,
)
...
...
@@ -1395,6 +1539,7 @@ register_template(
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
    format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
    stop_words=["<|im_end|>"],
    replace_eos=True,
)
...
...
@@ -1425,6 +1570,7 @@ register_template(
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
)
...
...
@@ -1440,6 +1586,8 @@ register_template(
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    template_class=ReasoningTemplate,
)
...
...
@@ -1451,6 +1599,7 @@ register_template(
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
)
...
...
@@ -1468,6 +1617,7 @@ register_template(
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"),
...
...
@@ -1486,6 +1636,7 @@ register_template(
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
...
...
@@ -1503,6 +1654,20 @@ register_template(
)


register_template(
    name="seed_coder",
    format_user=StringFormatter(
        slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
    ),
    format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
    default_system=(
        "You are an AI programming assistant, utilizing the Seed-Coder model, developed by ByteDance Seed, "
        "and you only answer questions related to computer science. For politically sensitive questions, "
        "security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\n"
    ),
)
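For reference, a minimal sketch (not from the commit) of how the seed_coder slots expand for one turn; "<bos>" and "<eos>" stand in for the tokenizer's real special tokens:

# Illustrative sketch only; the actual bos/eos strings come from the model's tokenizer.
BOS, EOS = "<bos>", "<eos>"
system = "You are an AI programming assistant..."
question = "Write a function that reverses a string."
prompt = f"{BOS}system\n{system}{EOS}" + f"{BOS}user\n{question}{EOS}{BOS}assistant\n"
print(prompt)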
# copied from llama3 template
register_template(
    name="skywork_o1",
...
...
@@ -1538,6 +1703,25 @@ register_template(
)


register_template(
    name="smollm",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
)


register_template(
    name="smollm2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
    default_system="You are a helpful AI assistant named SmolLM, trained by Hugging Face.",
)


register_template(
    name="solar",
    format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
...
...
src/llamafactory/data/tool_utils.py
View file @
0722acf1
...
...
@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
        tool_text = ""
        tool_names = []
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            param_text = ""
            for name, param in tool["parameters"]["properties"].items():
                required, enum, items = "", "", ""
...
...
@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_text = ""
-        for name, arguments in functions:
-            function_text += f"Action: {name}\nAction Input: {arguments}\n"
-        return function_text
+        return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
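The rewrite collapses the loop into a single join; the only observable difference is that the trailing newline after the last entry is gone. A standalone sketch (FunctionCall is assumed to be a (name, arguments) pair, as the unpacking above suggests):

# Illustrative sketch only, outside the project.
functions = [("get_weather", '{"city": "Boston"}'), ("get_time", '{"timezone": "UTC"}')]
print("\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions]))
# Action: get_weather
# Action Input: {"city": "Boston"}
# Action: get_time
# Action Input: {"timezone": "UTC"}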
    @override
    @staticmethod
...
...
@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
                name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
            )
...
...
@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
        date = datetime.now().strftime("%d %b %Y")
        tool_text = ""
        for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
            tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"

        return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
...
...
@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
-        if len(functions) > 1:
-            raise ValueError("Llama-3 does not support parallel functions.")
-        return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
+        function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
+        return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        try:
-            tool = json.loads(content.strip())
+            tools = json.loads(content.strip())
        except json.JSONDecodeError:
            return content

-        if "name" not in tool or "parameters" not in tool:
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
            return content
-        return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
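Net effect: parallel calls no longer raise a ValueError; a single call still serializes to one JSON object, several calls become a JSON list, and the extractor accepts either shape. A self-contained sketch of the round trip (FunctionCall assumed to be a (name, arguments) tuple):

import json
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])

calls = [FunctionCall("add", '{"a": 1, "b": 2}'), FunctionCall("sub", '{"a": 5, "b": 3}')]
objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in calls]
text = json.dumps(objects[0] if len(objects) == 1 else objects, ensure_ascii=False)
print(text)  # a JSON list here, because there are two calls

# Extractor side: wrap a single object into a list, then rebuild the (name, arguments) pairs.
parsed = json.loads(text)
parsed = [parsed] if not isinstance(parsed, list) else parsed
print([FunctionCall(t["name"], json.dumps(t["parameters"], ensure_ascii=False)) for t in parsed])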
class MistralToolUtils(ToolUtils):
    r"""Mistral v0.3 tool using template."""
...
...
@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        wrapped_tools = []
        for tool in tools:
-            wrapped_tools.append({"type": "function", "function": tool})
+            wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})

        return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
-        return "[" + ", ".join(function_texts) + "]"
+        return json.dumps(
+            [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
+        )
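The Mistral formatter now builds the call list with json.dumps instead of hand-concatenated fragments, so the arguments are re-parsed and re-serialized as proper JSON. A standalone sketch of the output:

import json

functions = [("get_weather", '{"city": "Paris"}')]  # (name, arguments) pairs, arguments as JSON strings
print(json.dumps([{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False))
# [{"name": "get_weather", "arguments": {"city": "Paris"}}]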
    @override
    @staticmethod
...
...
@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils):
        except json.JSONDecodeError:
            return content

-        if not isinstance(tools, list):
-            tools = [tools]
-        results = []
-        for tool in tools:
-            if "name" not in tool or "arguments" not in tool:
-                return content
-            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
-        return results
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content
class QwenToolUtils(ToolUtils):
...
...
@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)

        return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
...
...
@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append("<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>")
-        return "\n".join(function_texts)
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
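Each call now gets its own <tool_call> block, with the payload produced by json.dumps rather than raw string interpolation. A standalone sketch of the output:

import json

functions = [("search", '{"query": "llama"}'), ("open_url", '{"url": "https://example.com"}')]
texts = [json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False) for name, arguments in functions]
print("\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in texts]))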
    @override
    @staticmethod
...
...
src/llamafactory/extras/constants.py
View file @
0722acf1
...
...
@@ -513,7 +513,7 @@ register_model_group(
register_model_group(
    models={
-        "DeepSeek-V2-236B-Chat-0628": {
+        "DeepSeek-V2-236B-0628-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat-0628",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat-0628",
        },
...
...
@@ -521,7 +521,7 @@ register_model_group(
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5",
        },
-        "DeepSeek-V2.5-236B-Chat-1210": {
+        "DeepSeek-V2.5-236B-1210-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
        },
...
...
@@ -533,6 +533,17 @@ register_model_group(
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
        },
        "DeepSeek-V3-671B-0324-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
        },
    },
    template="deepseek3",
)


register_model_group(
    models={
        "DeepSeek-R1-1.5B-Distill": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
...
...
@@ -545,6 +556,10 @@ register_model_group(
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        },
        "DeepSeek-R1-8B-0528-Distill": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
        },
        "DeepSeek-R1-14B-Distill": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
...
...
@@ -565,8 +580,12 @@ register_model_group(
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
        },
        "DeepSeek-R1-671B-0528-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528",
        },
    },
-    template="deepseek3",
+    template="deepseekr1",
)
...
...
@@ -673,6 +692,10 @@ register_model_group(
            DownloadSource.DEFAULT: "google/gemma-3-1b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
        },
        "MedGemma-27B-Instruct": {
            DownloadSource.DEFAULT: "google/medgemma-27b-text-it",
            DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
        },
    },
    template="gemma",
)
...
...
@@ -704,6 +727,14 @@ register_model_group(
            DownloadSource.DEFAULT: "google/gemma-3-27b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
        },
        "MedGemma-4B": {
            DownloadSource.DEFAULT: "google/medgemma-4b-pt",
            DownloadSource.MODELSCOPE: "google/medgemma-4b-pt",
        },
        "MedGemma-4B-Instruct": {
            DownloadSource.DEFAULT: "google/medgemma-4b-it",
            DownloadSource.MODELSCOPE: "google/medgemma-4b-it",
        },
    },
    template="gemma3",
    multimodal=True,
...
...
@@ -737,6 +768,13 @@ register_model_group(
            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
        },
    },
    template="glm4",
)


register_model_group(
    models={
        "GLM-Z1-9B-0414-Chat": {
            DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
...
...
@@ -746,7 +784,7 @@ register_model_group(
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
        },
    },
-    template="glm4",
+    template="glmz1",
)
...
...
@@ -869,12 +907,13 @@ register_model_group(
register_model_group(
    models={
-        "Granite-3.2-1B-A400M-Base": {
+        "Granite-Vision-3.2-2B": {
            DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
        },
    },
    template="granite3_vision",
    multimodal=True,
)
...
...
@@ -1398,6 +1437,45 @@ register_model_group(
)


register_model_group(
    models={
        "MiMo-7B-Base": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
        },
        "MiMo-7B-Instruct": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
        },
        "MiMo-7B-Instruct-RL": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
        },
        "MiMo-7B-RL-ZERO": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
        },
    },
    template="mimo",
)


register_model_group(
    models={
        "MiMo-7B-VL-Instruct": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
        },
        "MiMo-7B-VL-RL": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
        },
    },
    template="mimo_vl",
    multimodal=True,
)


register_model_group(
    models={
        "MiniCPM-2B-SFT-Chat": {
...
...
@@ -2461,6 +2539,38 @@ register_model_group(
            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
        },
        "Qwen3-0.6B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
        },
        "Qwen3-1.7B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
        },
        "Qwen3-4B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
        },
        "Qwen3-8B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
        },
        "Qwen3-14B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
        },
        "Qwen3-32B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
        },
        "Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
        },
        "Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
        },
    },
    template="qwen3",
)
...
...
@@ -2484,10 +2594,22 @@ register_model_group(
register_model_group(
    models={
        "Qwen2.5-Omni-3B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
        },
        "Qwen2.5-Omni-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
-        }
+        },
        "Qwen2.5-Omni-7B-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
        },
        "Qwen2.5-Omni-7B-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
        },
    },
    template="qwen2_omni",
    multimodal=True,
...
...
@@ -2598,15 +2720,17 @@ register_model_group(
register_model_group(
    models={
-        "SOLAR-10.7B-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
+        "Seed-Coder-8B-Base": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
        },
-        "SOLAR-10.7B-Instruct-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
+        "Seed-Coder-8B-Instruct": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
+        },
+        "Seed-Coder-8B-Instruct-Reasoning": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
        },
    },
-    template="solar",
+    template="seed_coder",
)
...
...
@@ -2631,6 +2755,82 @@ register_model_group(
)


register_model_group(
    models={
        "SmolLM-135M": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M",
        },
        "SmolLM-360M": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M",
        },
        "SmolLM-1.7B": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B",
        },
        "SmolLM-135M-Instruct": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M-Instruct",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M-Instruct",
        },
        "SmolLM-360M-Instruct": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M-Instruct",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M-Instruct",
        },
        "SmolLM-1.7B-Instruct": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B-Instruct",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B-Instruct",
        },
    },
    template="smollm",
)


register_model_group(
    models={
        "SmolLM2-135M": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M",
        },
        "SmolLM2-360M": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M",
        },
        "SmolLM2-1.7B": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B",
        },
        "SmolLM2-135M-Instruct": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M-Instruct",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M-Instruct",
        },
        "SmolLM2-360M-Instruct": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M-Instruct",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M-Instruct",
        },
        "SmolLM2-1.7B-Instruct": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
        },
    },
    template="smollm2",
)


register_model_group(
    models={
        "SOLAR-10.7B-v1.0": {
            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
        },
        "SOLAR-10.7B-Instruct-v1.0": {
            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
        },
    },
    template="solar",
)


register_model_group(
    models={
        "StarCoder2-3B": {
...
...
src/llamafactory/extras/env.py
View file @
0722acf1
...
...
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform

import accelerate
...
...
@@ -83,4 +84,9 @@ def print_env() -> None:
    except Exception:
        pass

    if os.path.exists("data"):
        info["Default data directory"] = "detected"
    else:
        info["Default data directory"] = "not detected"

    print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
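A minimal sketch (made-up keys and values) of the report format this print call produces:

info = {"llamafactory version": "0.9.x", "Default data directory": "detected"}
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
# - llamafactory version: 0.9.x
# - Default data directory: detected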
src/llamafactory/extras/misc.py
View file @
0722acf1
...
...
@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    if "gptmodel" in requirement or "autoawq" in requirement:
        pip_command = f"pip install {requirement} --no-build-isolation"
    else:
        pip_command = f"pip install {requirement}"

    if mandatory:
-        hint = f"To fix: run `pip install {requirement}`."
+        hint = f"To fix: run `{pip_command}`."
    else:
-        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

    require_version(requirement, hint)


def check_dependencies() -> None:
    r"""Check the version of the required packages."""
-    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.5.0")
-    check_version("accelerate>=0.34.0,<=1.6.0")
-    check_version("peft>=0.14.0,<=0.15.1")
+    check_version("transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0")
+    check_version("datasets>=2.16.0,<=3.6.0")
+    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("peft>=0.14.0,<=0.15.2")
    check_version("trl>=0.8.6,<=0.9.6")
    if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
        logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
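The hint now reuses a pre-built pip command so that packages which need --no-build-isolation (gptmodel, autoawq) get the right instruction. A standalone sketch of the branch, using the hypothetical helper name build_pip_hint:

def build_pip_hint(requirement: str, mandatory: bool) -> str:
    # Mirrors the logic above, for illustration only.
    if "gptmodel" in requirement or "autoawq" in requirement:
        pip_command = f"pip install {requirement} --no-build-isolation"
    else:
        pip_command = f"pip install {requirement}"
    if mandatory:
        return f"To fix: run `{pip_command}`."
    return f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

print(build_pip_hint("autoawq", mandatory=True))  # To fix: run `pip install autoawq --no-build-isolation`.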
...
...
src/llamafactory/hparams/data_args.py
View file @
0722acf1
...
...
@@ -99,6 +99,10 @@ class DataArguments:
        default=0.0,
        metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
    )
    eval_on_each_dataset: bool = field(
        default=False,
        metadata={"help": "Whether or not to evaluate on each dataset separately."},
    )
    packing: Optional[bool] = field(
        default=None,
        metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
...
...
@@ -111,6 +115,14 @@ class DataArguments:
        default=None,
        metadata={"help": "Tool format to use for constructing function calling examples."},
    )
    default_system: Optional[str] = field(
        default=None,
        metadata={"help": "Override the default system message in the template."},
    )
    enable_thinking: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
    )
    tokenized_path: Optional[str] = field(
        default=None,
        metadata={
...
...
@@ -121,6 +133,10 @@ class DataArguments:
            )
        },
    )
    data_shared_file_system: bool = field(
        default=False,
        metadata={"help": "Whether or not to use a shared file system for the datasets."},
    )

    def __post_init__(self):
        def split_arg(arg):
...
...
src/llamafactory/hparams/generating_args.py
View file @
0722acf1
...
...
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any

from transformers import GenerationConfig
...
...
@@ -62,10 +62,6 @@ class GeneratingArguments:
        default=1.0,
        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
    )
    default_system: Optional[str] = field(
        default=None,
        metadata={"help": "Default system message to use in chat completion."},
    )
    skip_special_tokens: bool = field(
        default=True,
        metadata={"help": "Whether or not to remove special tokens in the decoding."},
...
...
src/llamafactory/hparams/model_args.py
View file @
0722acf1
...
...
@@ -235,10 +235,6 @@ class ProcessorArguments:
        default=False,
        metadata={"help": "Whether to crop the image to patches for internvl."},
    )
    use_audio_in_video: bool = field(
        default=False,
        metadata={"help": "Whether or not to use audio in video inputs."},
    )
    video_max_pixels: int = field(
        default=256 * 256,
        metadata={"help": "The maximum number of pixels of video inputs."},
...
...
@@ -255,6 +251,10 @@ class ProcessorArguments:
        default=128,
        metadata={"help": "The maximum number of sampled frames for video inputs."},
    )
    use_audio_in_video: bool = field(
        default=False,
        metadata={"help": "Whether or not to use audio in video inputs."},
    )
    audio_sampling_rate: int = field(
        default=16000,
        metadata={"help": "The sampling rate of audio inputs."},
...
...
@@ -364,6 +364,12 @@ class SGLangArguments:
        default=None,
        metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
    )
    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
        default="triton",
        metadata={
            "help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
        },
    )

    def __post_init__(self):
        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
...
...
src/llamafactory/hparams/parser.py
View file @
0722acf1
...
...
@@ -148,10 +148,10 @@ def _check_extra_dependencies(
        check_version("mixture-of-depth>=1.1.6", mandatory=True)

    if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.4")
+        check_version("vllm>=0.4.3,<=0.8.6")
        check_version("vllm", mandatory=True)
    elif model_args.infer_backend == EngineName.SGLANG:
-        check_version("sglang>=0.4.4")
+        check_version("sglang>=0.4.5")
        check_version("sglang", mandatory=True)

    if finetuning_args.use_galore:
...
...
src/llamafactory/hparams/training_args.py
View file @
0722acf1
...
...
@@ -64,6 +64,7 @@ class RayArguments:
            raise ValueError(
                f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
            )

        import pyarrow.fs as fs

        if self.ray_storage_filesystem == "s3":
...
...
src/llamafactory/model/model_utils/attention.py
View file @
0722acf1
...
...
@@ -29,10 +29,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)


-def configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
-) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
            if is_flash_attn_2_available():
                if model_args.flash_attn != AttentionFunction.FA2:
...
...
src/llamafactory/model/model_utils/liger_kernel.py
View file @
0722acf1
...
...
@@ -45,16 +45,24 @@ def apply_liger_kernel(
        from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
    elif model_type == "gemma3_text":
        from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    elif model_type == "paligemma":
-        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
+    elif model_type == "glm4":
+        from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
    elif model_type == "granite":
        from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
    elif model_type == "llama":
        from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
    elif model_type == "llava":
        from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
    elif model_type == "mistral":
        from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
    elif model_type == "mixtral":
        from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
    elif model_type == "mllama":
        from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
    elif model_type == "olmo2":
        from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
+    elif model_type == "paligemma":
+        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
    elif model_type == "phi3":
        from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
    elif model_type == "qwen2":
...
...
@@ -63,6 +71,8 @@ def apply_liger_kernel(
        from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
    elif model_type == "qwen2_5_vl":
        from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
    elif model_type == "qwen3":
        from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
    else:
        logger.warning_rank0("Current model does not support liger kernel.")
        return
...
...