OpenDAS / LLaMA-Factory / Commits

Commit 0722acf1, authored Jun 04, 2025 by chenych
Parent: c4ba4563

    Update 0604
Showing 20 changed files with 690 additions and 323 deletions (+690 / -323):

    src/llamafactory/chat/sglang_engine.py                +15    -1
    src/llamafactory/chat/vllm_engine.py                   +0    -1
    src/llamafactory/cli.py                                +1    -1
    src/llamafactory/data/converter.py                    +21    -6
    src/llamafactory/data/data_utils.py                   +28   -27
    src/llamafactory/data/loader.py                       +12    -7
    src/llamafactory/data/mm_plugin.py                   +101  -174
    src/llamafactory/data/parser.py                        +1    -6
    src/llamafactory/data/template.py                    +210   -26
    src/llamafactory/data/tool_utils.py                   +26   -39
    src/llamafactory/extras/constants.py                 +212   -12
    src/llamafactory/extras/env.py                         +6    -0
    src/llamafactory/extras/misc.py                       +13    -6
    src/llamafactory/hparams/data_args.py                 +16    -0
    src/llamafactory/hparams/generating_args.py            +1    -5
    src/llamafactory/hparams/model_args.py                +10    -4
    src/llamafactory/hparams/parser.py                     +2    -2
    src/llamafactory/hparams/training_args.py              +1    -0
    src/llamafactory/model/model_utils/attention.py        +2    -4
    src/llamafactory/model/model_utils/liger_kernel.py    +12    -2
src/llamafactory/chat/sglang_engine.py  (+15 / -1)

@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False

         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
+
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
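Note: with a LoRA adapter configured, the new branch simply appends four flags before the list is joined into a shell command. A minimal sketch of the resulting command string (base model and adapter paths below are hypothetical, not taken from this commit):

    # Hypothetical values; the real flags come from model_args.
    launch_cmd = [
        "python3 -m sglang.launch_server",
        "--model-path meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical base model
        "--log-level error",
        "--max-loras-per-batch 1",
        "--lora-backend triton",                 # hypothetical sglang_lora_backend value
        "--lora-paths lora0=saves/llama3-lora",  # hypothetical adapter_name_or_path[0]
        "--disable-radix-cache",
    ]
    print(" ".join(launch_cmd))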
src/llamafactory/chat/vllm_engine.py  (+0 / -1)

@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
src/llamafactory/cli.py  (+1 / -1)

@@ -73,7 +73,7 @@ def main():
         "help": partial(print, USAGE),
     }

-    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
        # launch distributed training
        nnodes = os.getenv("NNODES", "1")
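Note: `sys.argv` always contains at least the program name, so the old `>= 1` guard was always true and a bare invocation would hit `sys.argv.pop(1)` with nothing to pop. A quick sketch of the corrected dispatch:

    import sys

    # With no subcommand, sys.argv == ["llamafactory-cli"], i.e. len(sys.argv) == 1:
    # the old ">= 1" guard still called sys.argv.pop(1) and raised IndexError,
    # while the new "> 1" guard falls back to "help".
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    print(command)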
src/llamafactory/data/converter.py  (+21 / -6)

@@ -51,12 +51,27 @@ class DatasetConverter:
         else:
             medias = medias[:]

-        if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
-            for i in range(len(medias)):
-                if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
-                    medias[i] = os.path.join(self.data_args.media_dir, medias[i])
-                else:
-                    logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
+        if self.dataset_attr.load_from in ["script", "file"]:
+            if isinstance(medias[0], str):
+                for i in range(len(medias)):
+                    media_path = os.path.join(self.data_args.media_dir, medias[i])
+                    if os.path.isfile(media_path):
+                        medias[i] = media_path
+                    else:
+                        logger.warning_rank0_once(
+                            f"Media {medias[i]} does not exist in `media_dir`. Use original path."
+                        )
+            elif isinstance(medias[0], list):  # for processed video frames
+                # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
+                for i in range(len(medias)):
+                    for j in range(len(medias[i])):
+                        media_path = os.path.join(self.data_args.media_dir, medias[i][j])
+                        if os.path.isfile(media_path):
+                            medias[i][j] = media_path
+                        else:
+                            logger.warning_rank0_once(
+                                f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
+                            )

         return medias
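Note: the new `elif` branch lets one dataset row reference a video as a list of frame images, each resolved against `media_dir` individually. A standalone sketch of that resolution step (directory and file names are hypothetical):

    import os

    media_dir = "data/media"  # hypothetical
    medias = [["frame1.jpg", "frame2.jpg"], ["frame3.jpg", "frame4.jpg"]]

    for i in range(len(medias)):
        for j in range(len(medias[i])):
            media_path = os.path.join(media_dir, medias[i][j])
            if os.path.isfile(media_path):
                medias[i][j] = media_path  # rewrite to the path under media_dir
            # otherwise the original (possibly absolute) path is kept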
src/llamafactory/data/data_utils.py  (+28 / -27)

@@ -14,7 +14,7 @@
 import json
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union

 import fsspec
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
@@ -142,48 +142,49 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModu
     return dataset_module

-def setup_fs(path, anon=False):
-    """Set up a filesystem object based on the path protocol."""
+def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
+    r"""Set up a filesystem object based on the path protocol."""
     storage_options = {"anon": anon} if anon else {}
     if path.startswith("s3://"):
         fs = fsspec.filesystem("s3", **storage_options)
     elif path.startswith(("gs://", "gcs://")):
         fs = fsspec.filesystem("gcs", **storage_options)
     else:
-        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'")
+        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")
+
+    if not fs.exists(path):
+        raise ValueError(f"Path does not exist: {path}.")

     return fs

-def read_cloud_json(cloud_path):
-    """Read a JSON/JSONL file from cloud storage (S3 or GCS).
+def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
+    r"""Helper function to read JSON/JSONL files using fsspec."""
+    with fs.open(path, "r") as f:
+        if path.endswith(".jsonl"):
+            return [json.loads(line) for line in f if line.strip()]
+        else:
+            return json.load(f)
+
+
+def read_cloud_json(cloud_path: str) -> list[Any]:
+    r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).

     Args:
-        cloud_path : str
+        cloud_path: str
             Cloud path in the format:
             - 's3://bucket-name/file.json' for AWS S3
             - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
-        lines : bool, default=True
-            If True, read the file as JSON Lines format (one JSON object per line)
     """
     try:
-        # Try with anonymous access first
-        fs = setup_fs(cloud_path, anon=True)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+        # try with anonymous access first
+        fs = setup_fs(cloud_path, anon=True)
     except Exception:
-        # Try again with credentials
-        fs = setup_fs(cloud_path)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
-
-
-def _read_json_with_fs(fs, path, lines=True):
-    """Helper function to read JSON/JSONL files using fsspec."""
-    with fs.open(path, "r") as f:
-        if lines:
-            # Read JSONL (JSON Lines) format - one JSON object per line
-            data = [json.loads(line) for line in f if line.strip()]
-        else:
-            # Read regular JSON format
-            data = json.load(f)
-
-    return data
+        # try again with credentials
+        fs = setup_fs(cloud_path)
+
+    files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
+    # filter out non-JSON files
+    files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
+    if not files:
+        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
+
+    return sum([_read_json_with_fs(fs, file) for file in files], [])
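Note: `read_cloud_json` now also accepts a directory and concatenates every JSON/JSONL file in it. One caveat worth flagging: `filter(...)` returns a lazy filter object, which is always truthy, so the `if not files:` guard can never fire as written. A standalone sketch of the directory handling (file names hypothetical, with a plain list standing in for the fsspec listing):

    files = ["datasets/part1.jsonl", "datasets/readme.txt", "datasets/part2.json"]
    json_files = [f for f in files if f.endswith(".json") or f.endswith(".jsonl")]
    # a list (unlike a filter object) makes the emptiness check meaningful:
    if not json_files:
        raise ValueError("No JSON/JSONL files found.")

    records = sum([[{"file": f}] for f in json_files], [])  # stand-in for _read_json_with_fs
    print(json_files)  # ['datasets/part1.jsonl', 'datasets/part2.json']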
src/llamafactory/data/loader.py  (+12 / -7)

@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)

-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)

 def _get_dataset_processor(
@@ -300,13 +300,18 @@ def get_dataset(
         raise ValueError("Turn off `streaming` when saving dataset to disk.")

     # Load and preprocess dataset
-    with training_args.main_process_first(desc="load dataset"):
+    with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
-        eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
-        )
+        eval_dataset = _get_merged_dataset(
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
+        )

-    with training_args.main_process_first(desc="pre-process dataset"):
+    with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
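Note: the flag flips from "merge unless asked" to "merge by default, return a per-dataset dict on request", which is what `eval_on_each_dataset` consumes. A toy sketch of the two return shapes (plain lists stand in for Dataset objects):

    def get_merged(datasets: dict[str, list], return_dict: bool = False):
        if return_dict:
            return datasets  # {"name": dataset, ...} for per-dataset evaluation
        else:
            return [x for ds in datasets.values() for x in ds]  # merged view

    datasets = {"alpaca_en": [1, 2], "alpaca_zh": [3]}
    print(get_merged(datasets))                    # [1, 2, 3]
    print(get_merged(datasets, return_dict=True))  # {'alpaca_en': [1, 2], 'alpaca_zh': [3]}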
src/llamafactory/data/mm_plugin.py  (+101 / -174)

@@ -17,6 +17,7 @@
 import inspect
 import math
+import os
 import re
 from copy import deepcopy
 from dataclasses import dataclass
@@ -25,7 +26,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 import numpy as np
 import torch
-from transformers.image_utils import get_image_size, to_numpy_array
+from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
 from typing_extensions import override

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -57,7 +58,10 @@ if is_transformers_version_greater_than("4.45.0"):
     )

-if is_transformers_version_greater_than("4.49.0"):
+if is_transformers_version_greater_than("4.52.0"):
+    from transformers.image_utils import make_flat_list_of_images
+    from transformers.video_utils import make_batched_videos
+elif is_transformers_version_greater_than("4.49.0"):
     from transformers.image_utils import make_batched_videos, make_flat_list_of_images
@@ -73,7 +77,7 @@ if TYPE_CHECKING:
         bytes: Optional[bytes]

     ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
-    VideoInput = Union[str, BinaryIO]
+    VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
     AudioInput = Union[str, BinaryIO, NDArray]

     class MMProcessor(ProcessorMixin):
@@ -131,6 +135,11 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis
     return batch_images

+def _check_video_is_nested_images(video: "VideoInput") -> bool:
+    r"""Check if the video is nested images."""
+    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
+
+
 @dataclass
 class MMPluginMixin:
     image_token: Optional[str]
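Note: `VideoInput` now admits a pre-extracted list of frames in addition to a path or file object. A simplified standalone sketch of the dispatch (the real check also accepts BinaryIO and dict frames):

    def is_nested_images(video) -> bool:  # simplified stand-in for _check_video_is_nested_images
        return isinstance(video, list) and all(isinstance(f, (str, dict)) for f in video)

    print(is_nested_images("videos/demo.mp4"))                   # False -> decode with av
    print(is_nested_images(["frames/f1.jpg", "frames/f2.jpg"]))  # True  -> use frames as-is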
@@ -167,16 +176,45 @@ class MMPluginMixin:
         )
         if self.image_token is not None and processor is None:
-            raise ValueError("Processor was not found, please check and update your processor config.")
+            raise ValueError("Processor was not found, please check and update your model file.")

         if self.image_token is not None and image_processor is None:
-            raise ValueError("Image processor was not found, please check and update your processor config.")
+            raise ValueError("Image processor was not found, please check and update your model file.")

         if self.video_token is not None and video_processor is None:
-            raise ValueError("Video processor was not found, please check and update your processor config.")
+            raise ValueError("Video processor was not found, please check and update your model file.")

         if self.audio_token is not None and feature_extractor is None:
-            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+            raise ValueError("Audio feature extractor was not found, please check and update your model file.")

+    def _validate_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ):
+        r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
+        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+        for message in messages:
+            num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
+            num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
+            num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(
+                f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
+            )
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(
+                f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
+            )
+
+        if len(audios) != num_audio_tokens:
+            raise ValueError(
+                f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
+            )
+
     def _preprocess_image(
         self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
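Note: centralizing the counting in `_validate_messages` lets every plugin fail fast before any placeholder expansion, instead of checking inside each replacement loop. A standalone sketch of the invariant, assuming the default `<image>` placeholder:

    messages = [{"role": "user", "content": "<image> Describe this picture."}]
    images = ["img.png"]  # hypothetical input

    num_image_tokens = sum(m["content"].count("<image>") for m in messages)
    if len(images) != num_image_tokens:
        raise ValueError("The number of images does not match the number of <image> tokens.")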
@@ -234,14 +272,20 @@ class MMPluginMixin:
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
         for video in videos:
-            container = av.open(video, "r")
-            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: list[ImageObject] = []
-            container.seek(0)
-            for frame_idx, frame in enumerate(container.decode(video_stream)):
-                if frame_idx in sample_indices:
-                    frames.append(frame.to_image())
+            if _check_video_is_nested_images(video):
+                for frame in video:
+                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+                        raise ValueError("Invalid image found in video frames.")
+
+                frames = video
+            else:
+                container = av.open(video, "r")
+                video_stream = next(stream for stream in container.streams if stream.type == "video")
+                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                container.seek(0)
+                for frame_idx, frame in enumerate(container.decode(video_stream)):
+                    if frame_idx in sample_indices:
+                        frames.append(frame.to_image())

             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
@@ -420,6 +464,7 @@ class Gemma3Plugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         boi_token: str = getattr(processor, "boi_token")
@@ -446,9 +491,6 @@ class Gemma3Plugin(BasePlugin):
             message["content"] = content.replace("{{image}}", image_str)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override
@@ -495,14 +537,14 @@ class InternVLPlugin(BasePlugin):
         mm_inputs = {}
         image_video_patches = []

-        if len(images) != 0 and isinstance(images[0], str):
+        if len(images) != 0:
             images = self._regularize_images(
                 images,
                 image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
                 image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
             )["images"]

-        if len(videos) != 0 and isinstance(videos[0], str):
+        if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
                 image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
@@ -566,8 +608,8 @@ class InternVLPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens, num_video_tokens = 0, 0
         image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -579,9 +621,6 @@ class InternVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
                     f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
@@ -590,9 +629,6 @@ class InternVLPlugin(BasePlugin):
                 num_image_tokens += 1

             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
                 end_patch_index = video_patch_indices[num_video_tokens]
                 num_patches = list(video_num_patches[current_patch_index:end_patch_index])
@@ -605,12 +641,6 @@ class InternVLPlugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages

     @override
@@ -637,10 +667,13 @@ class KimiVLPlugin(BasePlugin):
     @override
     def process_messages(self, messages, images, videos, audios, processor):
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-
-        image_grid_hws = mm_inputs.get("image_grid_hws", [])
+            image_grid_hws = mm_inputs.get("image_grid_hws", [])
+        else:
+            image_grid_hws = [None] * len(images)
+
         num_image_tokens = 0
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
         merge_length = math.prod(image_processor.merge_kernel_size)
@@ -648,9 +681,6 @@ class KimiVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
@@ -661,9 +691,6 @@ class KimiVLPlugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
@@ -679,6 +706,7 @@ class Llama4Plugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
             if "pixel_values" in mm_inputs:
@@ -701,9 +729,6 @@ class Llama4Plugin(BasePlugin):
             for local_image_index, split_part in enumerate(prompt_splits):
                 new_content.append(split_part)
                 if local_image_index < placeholder_count:
-                    if num_image_tokens >= len(images):
-                        raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                     tokens_for_this_image = processor._prompt_split_image(
                         aspect_ratios[num_image_tokens], num_patches_per_chunk
                     )
@@ -716,9 +741,6 @@ class Llama4Plugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override
@@ -751,7 +773,7 @@ class LlavaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -768,17 +790,10 @@ class LlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                num_image_tokens += 1

             message["content"] = content.replace("{{image}}", self.image_token)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
@@ -794,6 +809,7 @@ class LlavaNextPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
@@ -805,9 +821,6 @@ class LlavaNextPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -821,9 +834,6 @@ class LlavaNextPlugin(BasePlugin):
             message["content"] = content.replace("{{image}}", self.image_token)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
@@ -839,7 +849,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens, num_video_tokens = 0, 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -850,9 +860,6 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -862,7 +869,6 @@ class LlavaNextVideoPlugin(BasePlugin):
                     image_seqlen = 1

                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                num_image_tokens += 1

             message["content"] = content.replace("{{image}}", self.image_token)
@@ -879,20 +885,10 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
-                num_video_tokens += 1

             message["content"] = content.replace("{{video}}", self.video_token)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
@@ -978,6 +974,7 @@ class MiniCPMVPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -996,24 +993,15 @@ class MiniCPMVPlugin(BasePlugin):
         for i, message in enumerate(messages):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                 num_image_tokens += 1

             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                 content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                 num_video_tokens += 1

             while AUDIO_PLACEHOLDER in content:
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                 content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                 num_audio_tokens += 1
@@ -1065,15 +1053,6 @@ class MiniCPMVPlugin(BasePlugin):
             final_text += text_chunks[-1]
             messages[index]["content"] = final_text

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
         return messages

     @override
@@ -1157,6 +1136,7 @@ class MllamaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -1164,9 +1144,6 @@ class MllamaPlugin(BasePlugin):
             num_image_tokens += content.count(IMAGE_PLACEHOLDER)
             message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override
@@ -1214,6 +1191,7 @@ class PaliGemmaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -1224,9 +1202,6 @@ class PaliGemmaPlugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override
@@ -1281,7 +1256,7 @@ class PixtralPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -1291,15 +1266,13 @@ class PixtralPlugin(BasePlugin):
                 image_sizes = iter(mm_inputs["image_sizes"][0])
             else:
                 image_sizes = iter(mm_inputs["image_sizes"].tolist())

         image_break_token: str = getattr(processor, "image_break_token")
         image_end_token: str = getattr(processor, "image_end_token")

         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     height, width = next(image_sizes)
                     num_height_tokens = height // processor.patch_size
@@ -1312,13 +1285,9 @@ class PixtralPlugin(BasePlugin):
                     replace_str = self.image_token

                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
-                num_image_tokens += 1

             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages

     @override
@@ -1355,9 +1324,9 @@ class Qwen2AudioPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         bos_token: str = getattr(processor, "audio_bos_token")
         eos_token: str = getattr(processor, "audio_eos_token")
-        num_audio_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs([], [], audios, processor)
@@ -1367,9 +1336,6 @@ class Qwen2AudioPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while AUDIO_PLACEHOLDER in content:
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     audio_length = audio_lengths.pop(0)
                     input_length = (audio_length - 1) // 2 + 1
@@ -1380,13 +1346,9 @@ class Qwen2AudioPlugin(BasePlugin):
                 content = content.replace(
                     AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
                 )
-                num_audio_tokens += 1

             message["content"] = content

-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
         return messages

     @override
@@ -1430,24 +1392,33 @@ class Qwen2VLPlugin(BasePlugin):
     ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
         results, fps_per_video = [], []
         for video in videos:
-            container = av.open(video, "r")
-            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: list[ImageObject] = []
-            container.seek(0)
-            for frame_idx, frame in enumerate(container.decode(video_stream)):
-                if frame_idx in sample_indices:
-                    frames.append(frame.to_image())
+            if _check_video_is_nested_images(video):
+                for frame in video:
+                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+                        raise ValueError("Invalid image found in video frames.")
+
+                frames = video
+                fps_per_video.append(kwargs.get("video_fps", 2.0))
+            else:
+                container = av.open(video, "r")
+                video_stream = next(stream for stream in container.streams if stream.type == "video")
+                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                container.seek(0)
+                for frame_idx, frame in enumerate(container.decode(video_stream)):
+                    if frame_idx in sample_indices:
+                        frames.append(frame.to_image())
+
+                if video_stream.duration is None:
+                    fps_per_video.append(kwargs.get("video_fps", 2.0))
+                else:
+                    fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

             if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
                 frames.append(frames[-1])

             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
-            if video_stream.duration is None:
-                fps_per_video.append(2.0)
-            else:
-                fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

         return {"videos": results, "fps_per_video": fps_per_video}
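Note: when frames are supplied directly there is no container to infer a frame rate from, hence the `video_fps` fallback; for decoded videos the effective rate is sampled frames divided by stream duration. Worked numbers (hypothetical):

    sample_count = 48        # len(sample_indices)
    duration_seconds = 24.0  # video_stream.duration * video_stream.time_base, in seconds
    fps = sample_count / duration_seconds
    print(fps)  # 2.0 frames per second recorded in fps_per_video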
@@ -1494,6 +1465,7 @@ class Qwen2VLPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -1510,9 +1482,6 @@ class Qwen2VLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
@@ -1520,9 +1489,6 @@ class Qwen2VLPlugin(BasePlugin):
                 num_image_tokens += 1

             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
@@ -1531,12 +1497,6 @@ class Qwen2VLPlugin(BasePlugin):
             message["content"] = content

-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
@@ -1602,6 +1562,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
...
@@ -1602,6 +1562,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
processor
:
Optional
[
"MMProcessor"
],
processor
:
Optional
[
"MMProcessor"
],
)
->
list
[
dict
[
str
,
str
]]:
)
->
list
[
dict
[
str
,
str
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
self
.
_validate_messages
(
messages
,
images
,
videos
,
audios
)
num_image_tokens
,
num_video_tokens
,
num_audio_tokens
=
0
,
0
,
0
num_image_tokens
,
num_video_tokens
,
num_audio_tokens
=
0
,
0
,
0
messages
=
deepcopy
(
messages
)
messages
=
deepcopy
(
messages
)
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
,
None
)
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
,
None
)
...
@@ -1624,9 +1585,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
...
@@ -1624,9 +1585,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
for
message
in
messages
:
for
message
in
messages
:
content
=
message
[
"content"
]
content
=
message
[
"content"
]
while
IMAGE_PLACEHOLDER
in
content
:
while
IMAGE_PLACEHOLDER
in
content
:
if
num_image_tokens
>=
len
(
images
):
raise
ValueError
(
f
"`len(images)` is less than the number of
{
IMAGE_PLACEHOLDER
}
tokens."
)
image_seqlen
=
image_grid_thw
[
num_image_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
image_seqlen
=
image_grid_thw
[
num_image_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
f
"<|vision_bos|>
{
self
.
image_token
*
image_seqlen
}
<|vision_eos|>"
,
1
IMAGE_PLACEHOLDER
,
f
"<|vision_bos|>
{
self
.
image_token
*
image_seqlen
}
<|vision_eos|>"
,
1
...
@@ -1642,11 +1600,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
...
@@ -1642,11 +1600,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
)
)
while
VIDEO_PLACEHOLDER
in
content
:
while
VIDEO_PLACEHOLDER
in
content
:
if
num_video_tokens
>=
len
(
videos
):
raise
ValueError
(
f
"`len(videos)` is less than the number of
{
VIDEO_PLACEHOLDER
}
tokens."
)
if
num_audio_tokens
>=
len
(
audios
):
raise
ValueError
(
f
"`len(audios)` is less than the number of
{
AUDIO_PLACEHOLDER
}
tokens."
)
video_pos
=
content
.
find
(
VIDEO_PLACEHOLDER
)
video_pos
=
content
.
find
(
VIDEO_PLACEHOLDER
)
audio_pos
=
content
.
find
(
AUDIO_PLACEHOLDER
,
video_pos
)
audio_pos
=
content
.
find
(
AUDIO_PLACEHOLDER
,
video_pos
)
if
audio_pos
==
-
1
or
audio_pos
<
video_pos
:
if
audio_pos
==
-
1
or
audio_pos
<
video_pos
:
...
@@ -1688,9 +1641,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
...
@@ -1688,9 +1641,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
num_video_tokens
+=
1
num_video_tokens
+=
1
else
:
else
:
while
AUDIO_PLACEHOLDER
in
content
:
while
AUDIO_PLACEHOLDER
in
content
:
if
num_audio_tokens
>=
len
(
audios
):
raise
ValueError
(
f
"`len(audios)` is less than the number of
{
AUDIO_PLACEHOLDER
}
tokens."
)
audio_seqlen
=
audio_lengths
[
num_audio_tokens
]
if
self
.
expand_mm_tokens
else
1
audio_seqlen
=
audio_lengths
[
num_audio_tokens
]
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
content
=
content
.
replace
(
AUDIO_PLACEHOLDER
,
f
"<|audio_bos|>
{
self
.
audio_token
*
audio_seqlen
}
<|audio_eos|>"
,
1
AUDIO_PLACEHOLDER
,
f
"<|audio_bos|>
{
self
.
audio_token
*
audio_seqlen
}
<|audio_eos|>"
,
1
...
@@ -1698,9 +1648,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
...
@@ -1698,9 +1648,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
num_audio_tokens
+=
1
num_audio_tokens
+=
1
while
VIDEO_PLACEHOLDER
in
content
:
while
VIDEO_PLACEHOLDER
in
content
:
if
num_video_tokens
>=
len
(
videos
):
raise
ValueError
(
f
"`len(videos)` is less than the number of
{
VIDEO_PLACEHOLDER
}
tokens."
)
video_seqlen
=
(
video_seqlen
=
(
video_grid_thw
[
num_video_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
video_grid_thw
[
num_video_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
)
)
...
@@ -1711,15 +1658,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
...
@@ -1711,15 +1658,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
message
[
"content"
]
=
content
message
[
"content"
]
=
content
if
len
(
audios
)
!=
num_audio_tokens
:
raise
ValueError
(
f
"The number of audios does not match the number of
{
AUDIO_PLACEHOLDER
}
tokens."
)
if
len
(
images
)
!=
num_image_tokens
:
raise
ValueError
(
f
"The number of images does not match the number of
{
IMAGE_PLACEHOLDER
}
tokens."
)
if
len
(
videos
)
!=
num_video_tokens
:
raise
ValueError
(
f
"The number of videos does not match the number of
{
VIDEO_PLACEHOLDER
}
tokens."
)
return
messages
return
messages
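The placeholder expansion above comes down to one formula, `grid_thw.prod() // merge_length`. A self-contained sketch of the arithmetic, with made-up grid values rather than anything taken from this diff:

import math

# hypothetical video: 4 sampled frames of 28x28 patches with a 2x2 spatial merge
video_grid_thw = (4, 28, 28)  # (temporal, height, width) in patches
merge_length = 2 ** 2         # merge_size squared

video_seqlen = math.prod(video_grid_thw) // merge_length
print(video_seqlen)  # 784 video tokens, wrapped in <|vision_start|>...<|vision_end|>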
...
@@ -1735,6 +1673,7 @@ class VideoLlavaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         num_frames = 0
...
@@ -1762,28 +1701,16 @@ class VideoLlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                 num_video_tokens += 1
 
             content = content.replace("{{image}}", self.image_token)
             message["content"] = content.replace("{{video}}", self.video_token)
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
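The per-loop bounds checks and trailing count checks removed from these plugins are subsumed by the single `_validate_messages` call added to each `process_messages`. A rough standalone re-creation of what such a validator has to do; the function body here is illustrative, not copied from the repo:

IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, AUDIO_PLACEHOLDER = "<image>", "<video>", "<audio>"

def validate_messages(messages, images, videos, audios):
    # count placeholders over all turns and compare with the supplied media lists
    text = "".join(message["content"] for message in messages)
    for name, token, inputs in [
        ("images", IMAGE_PLACEHOLDER, images),
        ("videos", VIDEO_PLACEHOLDER, videos),
        ("audios", AUDIO_PLACEHOLDER, audios),
    ]:
        if text.count(token) != len(inputs):
            raise ValueError(f"The number of {name} does not match the number of {token} tokens.")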
...
src/llamafactory/data/parser.py
View file @ 0722acf1
...
@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
     dataset_list: list[DatasetAttr] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            if use_modelscope():
-                load_from = "ms_hub"
-            elif use_openmind():
-                load_from = "om_hub"
-            else:
-                load_from = "hf_hub"
+            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
            dataset_attr = DatasetAttr(load_from, dataset_name=name)
            dataset_list.append(dataset_attr)
            continue
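The collapsed conditional keeps the same precedence as the old if/elif chain, since Python's ternary operator nests to the right. A quick sanity check with stub predicates standing in for the real use_modelscope()/use_openmind() environment probes:

def pick_hub(use_modelscope: bool, use_openmind: bool) -> str:
    return "ms_hub" if use_modelscope else "om_hub" if use_openmind else "hf_hub"

assert pick_hub(True, True) == "ms_hub"    # ModelScope wins when both are enabled
assert pick_hub(False, True) == "om_hub"
assert pick_hub(False, False) == "hf_hub"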
...
src/llamafactory/data/template.py
View file @ 0722acf1
...
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
...
@@ -51,6 +52,7 @@ class Template:
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
+    enable_thinking: Optional[bool]
     mm_plugin: "BasePlugin"
 
     def encode_oneturn(
...
@@ -61,7 +63,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
...
@@ -77,7 +79,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 
     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
...
@@ -92,6 +94,19 @@ class Template:
         return list(stop_token_ids)
 
+    def add_thought(self, content: str = "") -> str:
+        r"""Add empty thought to assistant message."""
+        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
+
+    def remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
+    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Get the token ids of thought words."""
+        return tokenizer.encode(self.add_thought(), add_special_tokens=False)
+
     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
         r"""Convert elements to token ids."""
         token_ids = []
...
@@ -111,18 +126,12 @@ class Template:
         return token_ids
 
-    def _remove_thought(self, content: str) -> str:
-        r"""Remove thought from assistant message."""
-        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
-        return re.sub(pattern, "", content).lstrip("\n")
-
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.
...
@@ -140,18 +149,14 @@ class Template:
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 elements += self.format_system.apply(content=(system + tool_text))
 
-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=content, idx=str(i // 2))
+                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
...
@@ -162,6 +167,9 @@ class Template:
     @staticmethod
     def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
         r"""Add or replace eos token to the tokenizer."""
+        if tokenizer.eos_token == eos_token:
+            return
+
         is_added = tokenizer.eos_token_id is None
         num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
...
@@ -328,7 +336,6 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
-        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
...
@@ -342,18 +349,14 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]
 
-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=system_text + content)
+                elements += self.format_user.apply(content=system_text + message["content"])
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
...
@@ -392,6 +395,64 @@ class Llama2Template(Template):
         return jinja_template
 
+@dataclass
+class ReasoningTemplate(Template):
+    r"""A template that add thought to assistant message."""
+
+    @override
+    def encode_oneturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> tuple[list[int], list[int]]:
+        messages = deepcopy(messages)
+        for i in range(1, len(messages) - 2, 2):
+            messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        if self.enable_thinking is False:  # remove all cot
+            messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
+        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+        if (
+            self.thought_words[0] not in messages[-1]["content"]
+            and self.thought_words[1] not in messages[-1]["content"]
+        ):  # add empty cot
+            if not self.enable_thinking:  # do not compute loss
+                prompt_ids += self.get_thought_word_ids(tokenizer)
+            else:  # do compute loss
+                response_ids = self.get_thought_word_ids(tokenizer) + response_ids
+
+        return prompt_ids, response_ids
+
+    @override
+    def encode_multiturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> list[tuple[list[int], list[int]]]:
+        messages = deepcopy(messages)
+        if self.enable_thinking is False:  # remove all cot
+            for i in range(1, len(messages), 2):
+                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        for i in range(0, len(messages), 2):
+            if (
+                self.thought_words[0] not in messages[i + 1]["content"]
+                and self.thought_words[1] not in messages[i + 1]["content"]
+            ):  # add empty cot
+                if not self.enable_thinking:  # do not compute loss
+                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                else:  # do compute loss
+                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
+
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
 TEMPLATES: dict[str, "Template"] = {}
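Two of the behaviors introduced above can be checked in isolation. First, `remove_thought` strips a thought block with a non-greedy DOTALL regex; a standalone re-creation using the default thought words (templates such as kimi_vl override them):

import re

thought_words = ("<think>", "</think>")
pattern = re.compile(f"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
message = "<think>\nreason step by step...\n</think>\n\nThe answer is 42."
print(re.sub(pattern, "", message).lstrip("\n"))  # -> "The answer is 42."

Second, ReasoningTemplate decides which side of the prompt/response split receives the empty-thought token ids, which is what controls whether they attract training loss. A toy illustration with stand-in ids, not a real tokenizer:

thought_ids = [101, 102]  # pretend these are the tokenized "<think>\n\n</think>\n\n"
prompt_ids, response_ids = [1, 2, 3], [7, 8, 9]
enable_thinking = False
if not enable_thinking:  # appended to the prompt: excluded from the loss
    prompt_ids = prompt_ids + thought_ids
else:  # prepended to the response: included in the loss
    response_ids = thought_ids + response_ids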
...
@@ -410,6 +471,7 @@ def register_template(
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
+    enable_thinking: Optional[bool] = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
 ) -> None:
...
@@ -456,6 +518,7 @@ def register_template(
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
+        enable_thinking=enable_thinking,
         mm_plugin=mm_plugin,
     )
...
@@ -492,6 +555,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
     assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
     if len(user_slot) > len(user_slot_empty_system):
...
@@ -501,7 +565,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     else:  # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
         default_system = ""
 
-    return Template(
+    return template_class(
         format_user=StringFormatter(slots=[user_slot]),
         format_assistant=StringFormatter(slots=[assistant_slot]),
         format_system=StringFormatter(slots=[system_slot]),
...
@@ -515,6 +579,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
         efficient_eos=False,
         replace_eos=False,
         replace_jinja_template=False,
+        enable_thinking=True,
         mm_plugin=get_mm_plugin(name="base"),
     )
...
@@ -543,6 +608,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 
+    if data_args.default_system is not None:
+        logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+        template.default_system = data_args.default_system
+
+    template.enable_thinking = data_args.enable_thinking
     template.fix_special_tokens(tokenizer)
     template.fix_jinja_template(tokenizer)
     return template
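Both hooks are fed by data arguments (this commit also extends hparams/data_args.py). A hedged usage sketch; the import path and field names are assumed from the repo layout rather than shown in this hunk:

from llamafactory.hparams import DataArguments  # path assumed

data_args = DataArguments(template="qwen3", default_system="You are a terse assistant.", enable_thinking=False)
# get_template_and_fix_tokenizer(tokenizer, data_args) would then log the custom system
# message and propagate enable_thinking into the template.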
...
@@ -756,6 +826,7 @@ register_template(
         "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
     ),
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
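Many templates in this commit gain `replace_eos=True`; combined with the early-return guard added to `_add_or_replace_eos_token` above, the end effect on a tokenizer is roughly the following sketch (the checkpoint name is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
if tokenizer.eos_token != "<|im_end|>":  # the new guard makes this a no-op when it already matches
    tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})
print(tokenizer.eos_token)  # <|im_end|>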
...
@@ -774,6 +845,15 @@ register_template(
 )
 
+# copied from deepseek3 template
+register_template(
+    name="deepseekr1",
+    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=ReasoningTemplate,
+)
+
 register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
...
@@ -838,6 +918,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    replace_eos=True,
     template_class=Llama2Template,
 )
...
@@ -853,6 +934,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
     template_class=Llama2Template,
 )
...
@@ -872,6 +954,22 @@ register_template(
 )
 
+# copied from glm4 template
+register_template(
+    name="glmz1",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    template_class=ReasoningTemplate,
+)
+
 register_template(
     name="granite3",
     format_user=StringFormatter(
...
@@ -973,6 +1071,7 @@ register_template(
     stop_words=["<|im_end|>"],
     thought_words=("◁think▷", "◁/think▷"),
     mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
+    template_class=ReasoningTemplate,
 )
...
@@ -1018,6 +1117,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>", "<|eom_id|>"],
+    replace_eos=True,
 )
...
@@ -1037,6 +1137,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot|>", "<|eom|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
 )
...
@@ -1066,6 +1167,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>", "<|eom_id|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
 )
...
@@ -1079,6 +1181,7 @@ register_template(
     format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
     default_system="You are a helpful assistant provided by Moonshot-AI.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
...
@@ -1131,6 +1234,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>", "<|eom_id|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
 )
...
@@ -1163,6 +1267,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
 )
...
@@ -1233,6 +1338,42 @@ register_template(
 )
 
+# copied from qwen template
+register_template(
+    name="mimo",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    template_class=ReasoningTemplate,
+)
+
+# copied from qwen2vl
+register_template(
+    name="mimo_vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    default_system="You are MiMo, an AI assistant developed by Xiaomi.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+    template_class=ReasoningTemplate,
+)
+
 # copied from chatml template
 register_template(
     name="minicpm_v",
...
@@ -1363,6 +1504,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
     template_class=Llama2Template,
 )
...
@@ -1374,6 +1516,7 @@ register_template(
     format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
     stop_words=["<|end|>"],
+    replace_eos=True,
 )
...
@@ -1384,6 +1527,7 @@ register_template(
     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
     stop_words=["<|end|>"],
+    replace_eos=True,
 )
...
@@ -1395,6 +1539,7 @@ register_template(
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
     format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
...
@@ -1425,6 +1570,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
...
@@ -1440,6 +1586,8 @@ register_template(
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
     stop_words=["<|im_end|>"],
+    replace_eos=True,
+    template_class=ReasoningTemplate,
 )
...
@@ -1451,6 +1599,7 @@ register_template(
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
 )
...
@@ -1468,6 +1617,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(
         name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
     ),
...
@@ -1486,6 +1636,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 )
...
@@ -1503,6 +1654,20 @@ register_template(
 )
 
+register_template(
+    name="seed_coder",
+    format_user=StringFormatter(
+        slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
+    ),
+    format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
+    default_system=(
+        "You are an AI programming assistant, utilizing the Seed-Coder model, developed by ByteDance Seed, "
+        "and you only answer questions related to computer science. For politically sensitive questions, "
+        "security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\n"
+    ),
+)
+
 # copied from llama3 template
 register_template(
     name="skywork_o1",
...
@@ -1538,6 +1703,25 @@ register_template(
 )
 
+register_template(
+    name="smollm",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    stop_words=["<|im_end|>"],
+)
+
+register_template(
+    name="smollm2",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    stop_words=["<|im_end|>"],
+    default_system="You are a helpful AI assistant named SmolLM, trained by Hugging Face.",
+)
+
 register_template(
     name="solar",
     format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
...
src/llamafactory/data/tool_utils.py
View file @ 0722acf1
...
@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
         tool_text = ""
         tool_names = []
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             param_text = ""
             for name, param in tool["parameters"]["properties"].items():
                 required, enum, items = "", "", ""
...
@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_text = ""
-        for name, arguments in functions:
-            function_text += f"Action: {name}\nAction Input: {arguments}\n"
-        return function_text
+        return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
 
     @override
     @staticmethod
...
@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
                 name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
             )
...
@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
         date = datetime.now().strftime("%d %b %Y")
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
 
         return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
...
@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        if len(functions) > 1:
-            raise ValueError("Llama-3 does not support parallel functions.")
-
-        return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
+        function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
+        return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
 
     @override
     @staticmethod
     def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
-            tool = json.loads(content.strip())
+            tools = json.loads(content.strip())
         except json.JSONDecodeError:
             return content
 
-        if "name" not in tool or "parameters" not in tool:
-            return content
-
-        return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content
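A quick round-trip through the updated Llama-3 helpers, which now serialize parallel calls as a JSON list instead of rejecting them. FunctionCall is hand-rolled here as a stand-in for the repo's tuple type:

import json
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])

calls = [FunctionCall("get_weather", '{"city": "Paris"}'), FunctionCall("get_time", '{"tz": "CET"}')]
objects = [{"name": n, "parameters": json.loads(a)} for n, a in calls]
text = json.dumps(objects[0] if len(objects) == 1 else objects, ensure_ascii=False)
print(text)  # a JSON list, so both calls survive formatting

parsed = json.loads(text)
parsed = [parsed] if not isinstance(parsed, list) else parsed
print([FunctionCall(t["name"], json.dumps(t["parameters"])) for t in parsed])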
 
 class MistralToolUtils(ToolUtils):
     r"""Mistral v0.3 tool using template."""
...
@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         wrapped_tools = []
         for tool in tools:
-            wrapped_tools.append({"type": "function", "function": tool})
+            wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})
 
         return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
 
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
-
-        return "[" + ", ".join(function_texts) + "]"
+        return json.dumps(
+            [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
+        )
 
     @override
     @staticmethod
...
@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils):
         except json.JSONDecodeError:
             return content
 
-        if not isinstance(tools, list):
-            tools = [tools]
-
-        results = []
-        for tool in tools:
-            if "name" not in tool or "arguments" not in tool:
-                return content
-
-            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
-
-        return results
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content
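The Mistral extractor now tolerates both a bare call object and a list of calls, the same json.loads/json.dumps normalization the Qwen helpers receive in the hunks below. A standalone check of the list-coercion idiom:

import json

content = '{"name": "add", "arguments": {"a": 1, "b": 2}}'  # single object, no list wrapper
tools = json.loads(content.strip())
tools = [tools] if not isinstance(tools, list) else tools
print([(t["name"], json.dumps(t["arguments"], ensure_ascii=False)) for t in tools])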
 class QwenToolUtils(ToolUtils):
...
@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
 
         return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
...
@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(
-                "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
-            )
-
-        return "\n".join(function_texts)
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
 
     @override
     @staticmethod
...
src/llamafactory/extras/constants.py
View file @ 0722acf1
...
@@ -513,7 +513,7 @@ register_model_group(
 register_model_group(
     models={
-        "DeepSeek-V2-236B-Chat-0628": {
+        "DeepSeek-V2-236B-0628-Chat": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat-0628",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat-0628",
         },
...
@@ -521,7 +521,7 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5",
         },
-        "DeepSeek-V2.5-236B-Chat-1210": {
+        "DeepSeek-V2.5-236B-1210-Chat": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
         },
...
@@ -533,6 +533,17 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
         },
+        "DeepSeek-V3-671B-0324-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
+        },
+    },
+    template="deepseek3",
+)
+
+
+register_model_group(
+    models={
         "DeepSeek-R1-1.5B-Distill": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
...
@@ -545,6 +556,10 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         },
+        "DeepSeek-R1-8B-0528-Distill": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
+        },
         "DeepSeek-R1-14B-Distill": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
...
@@ -565,8 +580,12 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
         },
+        "DeepSeek-R1-671B-0528-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528",
+        },
     },
-    template="deepseek3",
+    template="deepseekr1",
 )
...
@@ -673,6 +692,10 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-3-1b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
         },
+        "MedGemma-27B-Instruct": {
+            DownloadSource.DEFAULT: "google/medgemma-27b-text-it",
+            DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
+        },
     },
     template="gemma",
 )
...
@@ -704,6 +727,14 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-3-27b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
         },
+        "MedGemma-4B": {
+            DownloadSource.DEFAULT: "google/medgemma-4b-pt",
+            DownloadSource.MODELSCOPE: "google/medgemma-4b-pt",
+        },
+        "MedGemma-4B-Instruct": {
+            DownloadSource.DEFAULT: "google/medgemma-4b-it",
+            DownloadSource.MODELSCOPE: "google/medgemma-4b-it",
+        },
     },
     template="gemma3",
     multimodal=True,
...
@@ -737,6 +768,13 @@ register_model_group(
             DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
         },
+    },
+    template="glm4",
+)
+
+
+register_model_group(
+    models={
         "GLM-Z1-9B-0414-Chat": {
             DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
...
@@ -746,7 +784,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
         },
     },
-    template="glm4",
+    template="glmz1",
 )
...
@@ -869,12 +907,13 @@ register_model_group(
 register_model_group(
     models={
-        "Granite-3.2-1B-A400M-Base": {
+        "Granite-Vision-3.2-2B": {
             DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
         },
     },
     template="granite3_vision",
+    multimodal=True,
 )
...
@@ -1398,6 +1437,45 @@ register_model_group(
 )
 
+register_model_group(
+    models={
+        "MiMo-7B-Base": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
+        },
+        "MiMo-7B-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
+        },
+        "MiMo-7B-Instruct-RL": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
+        },
+        "MiMo-7B-RL-ZERO": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+        },
+    },
+    template="mimo",
+)
+
+
+register_model_group(
+    models={
+        "MiMo-7B-VL-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
+        },
+        "MiMo-7B-VL-RL": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
+        },
+    },
+    template="mimo_vl",
+    multimodal=True,
+)
+
 register_model_group(
     models={
         "MiniCPM-2B-SFT-Chat": {
...
@@ -2461,6 +2539,38 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
         },
+        "Qwen3-0.6B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+        },
+        "Qwen3-1.7B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+        },
+        "Qwen3-4B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
+        },
+        "Qwen3-8B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
+        },
+        "Qwen3-14B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
+        },
+        "Qwen3-32B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
+        },
+        "Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+        },
+        "Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+        },
     },
     template="qwen3",
 )
...
@@ -2484,10 +2594,22 @@ register_model_group(
 register_model_group(
     models={
+        "Qwen2.5-Omni-3B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
+        },
         "Qwen2.5-Omni-7B": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
-        }
+        },
+        "Qwen2.5-Omni-7B-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+        },
+        "Qwen2.5-Omni-7B-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
+        },
     },
     template="qwen2_omni",
     multimodal=True,
...
@@ -2598,15 +2720,17 @@ register_model_group(
 register_model_group(
     models={
-        "SOLAR-10.7B-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
-        },
-        "SOLAR-10.7B-Instruct-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
-        },
+        "Seed-Coder-8B-Base": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
+        },
+        "Seed-Coder-8B-Instruct": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
+        },
+        "Seed-Coder-8B-Instruct-Reasoning": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
+        },
     },
-    template="solar",
+    template="seed_coder",
 )
...
@@ -2631,6 +2755,82 @@ register_model_group(
 )
 
+register_model_group(
+    models={
+        "SmolLM-135M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M",
+        },
+        "SmolLM-360M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M",
+        },
+        "SmolLM-1.7B": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B",
+        },
+        "SmolLM-135M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M-Instruct",
+        },
+        "SmolLM-360M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M-Instruct",
+        },
+        "SmolLM-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B-Instruct",
+        },
+    },
+    template="smollm",
+)
+
+
+register_model_group(
+    models={
+        "SmolLM2-135M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M",
+        },
+        "SmolLM2-360M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M",
+        },
+        "SmolLM2-1.7B": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B",
+        },
+        "SmolLM2-135M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M-Instruct",
+        },
+        "SmolLM2-360M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M-Instruct",
+        },
+        "SmolLM2-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+        },
+    },
+    template="smollm2",
+)
+
+
+register_model_group(
+    models={
+        "SOLAR-10.7B-v1.0": {
+            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
+        },
+        "SOLAR-10.7B-Instruct-v1.0": {
+            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
+        },
+    },
+    template="solar",
+)
+
 register_model_group(
     models={
         "StarCoder2-3B": {
...
...
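The hunks above are pure data registrations: each `register_model_group` call maps a display name to per-hub repository ids and a chat template. As a rough sketch of what such a call amounts to (the real helper is defined earlier in constants.py and also tracks multimodal support; the enum values and function body here are simplified assumptions, not the actual implementation):

    from collections import OrderedDict
    from enum import Enum, unique

    @unique
    class DownloadSource(str, Enum):
        DEFAULT = "hf"      # Hugging Face Hub
        MODELSCOPE = "ms"   # ModelScope mirror

    SUPPORTED_MODELS = OrderedDict()
    DEFAULT_TEMPLATE = OrderedDict()

    def register_model_group(models, template=None, multimodal=False):
        # Record each model's per-source repository id and remember which
        # chat template its family uses.
        for name, path in models.items():
            SUPPORTED_MODELS[name] = path
            if template is not None:
                DEFAULT_TEMPLATE[name] = template

Under that reading, resolving a download id reduces to a dict lookup such as `SUPPORTED_MODELS["SmolLM2-1.7B"][DownloadSource.MODELSCOPE]`, which yields "HuggingFaceTB/SmolLM2-1.7B" from the entries above.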
src/llamafactory/extras/env.py
...
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import platform

 import accelerate
...
@@ -83,4 +84,9 @@ def print_env() -> None:
     except Exception:
         pass

+    if os.path.exists("data"):
+        info["Default data directory"] = "detected"
+    else:
+        info["Default data directory"] = "not detected"
+
     print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
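The new lines report whether a `data` directory exists relative to the current working directory. For reference, the join-based formatting in the final line renders the collected info as a bulleted list; a toy illustration in plain Python (the dict contents here are made up):

    info = {"Default data directory": "detected", "Platform": "Linux"}
    print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
    # prints:
    # - Default data directory: detected
    # - Platform: Linux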
src/llamafactory/extras/misc.py
...
@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return

+    if "gptmodel" in requirement or "autoawq" in requirement:
+        pip_command = f"pip install {requirement} --no-build-isolation"
+    else:
+        pip_command = f"pip install {requirement}"
+
     if mandatory:
-        hint = f"To fix: run `pip install {requirement}`."
+        hint = f"To fix: run `{pip_command}`."
     else:
-        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

     require_version(requirement, hint)


 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.5.0")
-    check_version("accelerate>=0.34.0,<=1.6.0")
-    check_version("peft>=0.14.0,<=0.15.1")
+    check_version(
+        "transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
+    )
+    check_version("datasets>=2.16.0,<=3.6.0")
+    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
...
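The new `pip_command` branch makes the fix hint pass `--no-build-isolation` for packages matched by the "gptmodel" and "autoawq" substrings, presumably because those wheels must be compiled against the locally installed torch. A minimal standalone sketch of the added logic, lifted out of `check_version` for illustration (the helper name `build_fix_hint` is hypothetical):

    def build_fix_hint(requirement: str, mandatory: bool = False) -> str:
        # Packages built against the local torch need --no-build-isolation.
        if "gptmodel" in requirement or "autoawq" in requirement:
            pip_command = f"pip install {requirement} --no-build-isolation"
        else:
            pip_command = f"pip install {requirement}"

        if mandatory:
            return f"To fix: run `{pip_command}`."
        return f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

    print(build_fix_hint("autoawq", mandatory=True))
    # To fix: run `pip install autoawq --no-build-isolation`.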
src/llamafactory/hparams/data_args.py
...
@@ -99,6 +99,10 @@ class DataArguments:
         default=0.0,
         metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
     )
+    eval_on_each_dataset: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to evaluate on each dataset separately."},
+    )
     packing: Optional[bool] = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
...
@@ -111,6 +115,14 @@ class DataArguments:
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Override the default system message in the template."},
+    )
+    enable_thinking: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={
...
@@ -121,6 +133,10 @@ class DataArguments:
             )
         },
     )
+    data_shared_file_system: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use a shared file system for the datasets."},
+    )

     def __post_init__(self):
         def split_arg(arg):
...
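Read together with the generating_args.py hunk below, which deletes `default_system` from `GeneratingArguments`, the system-message override now lives on the data side. A hypothetical programmatic use of the new fields (keyword construction is assumed to work since the fields shown all carry defaults; in practice these are usually set via a YAML config or CLI flags):

    from llamafactory.hparams.data_args import DataArguments

    data_args = DataArguments(
        eval_on_each_dataset=True,    # report eval metrics per dataset, not pooled
        default_system="You are a helpful assistant.",  # overrides the template system message
        enable_thinking=False,        # suppress thinking mode for reasoning models
        data_shared_file_system=True, # dataset cache sits on a filesystem shared by all nodes
    )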
src/llamafactory/hparams/generating_args.py
...
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any

 from transformers import GenerationConfig
...
@@ -62,10 +62,6 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    default_system: Optional[str] = field(
-        default=None,
-        metadata={"help": "Default system message to use in chat completion."},
-    )
     skip_special_tokens: bool = field(
         default=True,
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
...
src/llamafactory/hparams/model_args.py
...
@@ -235,10 +235,6 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Whether to crop the image to patches for internvl."},
     )
-    use_audio_in_video: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use audio in video inputs."},
-    )
     video_max_pixels: int = field(
         default=256 * 256,
         metadata={"help": "The maximum number of pixels of video inputs."},
...
@@ -255,6 +251,10 @@ class ProcessorArguments:
         default=128,
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
+    use_audio_in_video: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use audio in video inputs."},
+    )
     audio_sampling_rate: int = field(
         default=16000,
         metadata={"help": "The sampling rate of audio inputs."},
...
@@ -364,6 +364,12 @@ class SGLangArguments:
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
+    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
+        default="triton",
+        metadata={
+            "help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
+        },
+    )

     def __post_init__(self):
         if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
...
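The first two hunks simply move `use_audio_in_video` below the video fields; the substantive addition is the LoRA GEMM backend switch for SGLang, which the engine forwards to the server's `--lora-backend` flag when adapters are loaded. A hypothetical direct use (field name from the diff; keyword construction assumes the remaining SGLang fields keep their defaults):

    from llamafactory.hparams.model_args import SGLangArguments

    # "triton" is the default and the one the help text recommends for stability.
    sglang_args = SGLangArguments(sglang_lora_backend="flashinfer")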
src/llamafactory/hparams/parser.py
...
@@ -148,10 +148,10 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.4")
+        check_version("vllm>=0.4.3,<=0.8.6")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
-        check_version("sglang>=0.4.4")
+        check_version("sglang>=0.4.5")
         check_version("sglang", mandatory=True)

     if finetuning_args.use_galore:
...
src/llamafactory/hparams/training_args.py
...
@@ -64,6 +64,7 @@ class RayArguments:
                 raise ValueError(
                     f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
                 )

             import pyarrow.fs as fs

             if self.ray_storage_filesystem == "s3":
...
src/llamafactory/model/model_utils/attention.py
...
@@ -29,10 +29,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
-) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
...
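With `is_trainable` dropped from the signature, the gemma2 FlashAttention-2 guard now applies during both training and inference, and call sites shrink accordingly. A before/after sketch (`config` and `model_args` stand in for the caller's objects; the actual caller lives in the model patcher):

    configure_attn_implementation(config, model_args, is_trainable)  # before
    configure_attn_implementation(config, model_args)                # after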
src/llamafactory/model/model_utils/liger_kernel.py
...
@@ -45,16 +45,24 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
     elif model_type == "gemma3_text":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    elif model_type == "paligemma":
-        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
+    elif model_type == "glm4":
+        from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
+    elif model_type == "granite":
+        from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
     elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
+    elif model_type == "llava":
+        from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
     elif model_type == "mistral":
         from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
     elif model_type == "mixtral":
         from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
     elif model_type == "mllama":
         from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
+    elif model_type == "olmo2":
+        from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
+    elif model_type == "paligemma":
+        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
     elif model_type == "phi3":
         from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
     elif model_type == "qwen2":
...
@@ -63,6 +71,8 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     elif model_type == "qwen2_5_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
+    elif model_type == "qwen3":
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return
...
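The growing if/elif chain is how this module picks the right Liger patch function for a given `model_type`; the hunks above extend it with glm4, granite, llava, olmo2, and qwen3 (and move paligemma to keep the chain alphabetical). For comparison, the same mapping expressed as a table-driven lookup, restricted to the kernels added in this commit (a sketch of an alternative design, not how the module is written):

    import importlib

    _LIGER_PATCHERS = {
        "glm4": "apply_liger_kernel_to_glm4",
        "granite": "apply_liger_kernel_to_granite",
        "llava": "apply_liger_kernel_to_llava",
        "olmo2": "apply_liger_kernel_to_olmo2",
        "paligemma": "apply_liger_kernel_to_paligemma",
        "qwen3": "apply_liger_kernel_to_qwen3",
    }

    def resolve_liger_patcher(model_type: str):
        # Returns the Liger patch function for model_type, or None when unsupported.
        name = _LIGER_PATCHERS.get(model_type)
        if name is None:
            return None
        module = importlib.import_module("liger_kernel.transformers")
        return getattr(module, name)

The if/elif form keeps the import lazy per branch, which matters here because `liger_kernel` is an optional dependency; the table form trades that for easier extension.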