OpenDAS / LLaMA-Factory

Commit 317a82e2 authored Mar 07, 2025 by chenych
Add QWQ-32B
parent 37b0ad9f

Showing 20 changed files with 1754 additions and 1221 deletions
src/llamafactory/data/formatter.py                  +16   -13
src/llamafactory/data/loader.py                     +96   -23
src/llamafactory/data/mm_plugin.py                  +482  -153
src/llamafactory/data/parser.py                     +24   -32
src/llamafactory/data/preprocess.py                 +0    -111
src/llamafactory/data/processor/__init__.py         +17   -0
src/llamafactory/data/processor/feedback.py         +129  -0
src/llamafactory/data/processor/pairwise.py         +118  -0
src/llamafactory/data/processor/pretrain.py         +57   -0
src/llamafactory/data/processor/processor_utils.py  +37   -2
src/llamafactory/data/processor/supervised.py       +200  -0
src/llamafactory/data/processor/unsupervised.py     +91   -0
src/llamafactory/data/processors/__init__.py        +0    -0
src/llamafactory/data/processors/feedback.py        +0    -130
src/llamafactory/data/processors/mm_utils.py        +0    -27
src/llamafactory/data/processors/pairwise.py        +0    -119
src/llamafactory/data/processors/pretrain.py        +0    -59
src/llamafactory/data/processors/supervised.py      +0    -219
src/llamafactory/data/processors/unsupervised.py    +0    -104
src/llamafactory/data/template.py                   +487  -229
src/llamafactory/data/formatter.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -86,19 +86,25 @@ class StringFormatter(Formatter):
         elif isinstance(slot, (dict, set)):
             elements.append(slot)
         else:
-            raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+            raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")

         return elements


 @dataclass
-class FunctionFormatter(Formatter):
+class FunctionFormatter(StringFormatter):
     def __post_init__(self):
+        super().__post_init__()
         self.tool_utils = get_tool_utils(self.tool_format)

     @override
     def apply(self, **kwargs) -> SLOTS:
-        content = kwargs.pop("content")
+        content: str = kwargs.pop("content")
+        regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
+        thought = re.search(regex, content)
+        if thought:
+            content = content.replace(thought.group(0), "")
+
         functions: List["FunctionCall"] = []
         try:
             tool_calls = json.loads(content)
 ...
@@ -111,16 +117,13 @@ class FunctionFormatter(Formatter):
                 )
         except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string

-        elements = []
-        for slot in self.slots:
-            if slot == "{{content}}":
-                elements += self.tool_utils.function_formatter(functions)
-            else:
-                elements.append(slot)
-
-        return elements
+        function_str = self.tool_utils.function_formatter(functions)
+        if thought:
+            function_str = thought.group(1) + function_str
+
+        return super().apply(content=function_str)


 @dataclass
 ...
@@ -135,7 +138,7 @@ class ToolFormatter(Formatter):
             tools = json.loads(content)
             return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
         except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}")  # flat string
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

     @override
     def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
 ...
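Note: the reworked `FunctionFormatter.apply` strips a `<think>...</think>` block from the assistant content before JSON-decoding the tool calls, then re-attaches the captured thought in front of the formatted function string. A minimal standalone sketch of just that regex step, independent of the formatter classes:

    import re

    # Same pattern as in the diff: DOTALL lets the thought span multiple lines.
    THINK_REGEX = re.compile(r"<think>(.*)</think>", re.DOTALL)

    def split_thought(content: str):
        """Return (thought, remainder); thought is None when no tag is present."""
        match = THINK_REGEX.search(content)
        if match is None:
            return None, content
        return match.group(1), content.replace(match.group(0), "")

    thought, payload = split_thought('<think>pick a tool</think>{"name": "search", "arguments": {}}')
    assert thought == "pick a tool"  # apply() prepends this to the function string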
src/llamafactory/data/loader.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -22,10 +22,17 @@ from datasets import DatasetDict, load_dataset, load_from_disk
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
-from .aligner import align_dataset
+from .converter import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
-from .preprocess import get_preprocess_and_print_func
+from .processor import (
+    FeedbackDatasetProcessor,
+    PackedSupervisedDatasetProcessor,
+    PairwiseDatasetProcessor,
+    PretrainDatasetProcessor,
+    SupervisedDatasetProcessor,
+    UnsupervisedDatasetProcessor,
+)


 if TYPE_CHECKING:
 ...
@@ -35,6 +42,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, ModelArguments
     from .data_utils import DatasetModule
     from .parser import DatasetAttr
+    from .processor import DatasetProcessor
     from .template import Template
 ...
@@ -156,21 +164,67 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-) -> Optional[Union["Dataset", "IterableDataset"]]:
+    merge: bool = True,
+) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
     r"""
-    Gets the merged datasets in the standard format.
+    Returns the merged datasets in the standard format.
     """
     if dataset_names is None:
         return None

-    datasets = []
-    for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
+    datasets = {}
+    for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
         if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
             raise ValueError("The dataset is not applicable in the current training stage.")

-        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
+        datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)

-    return merge_dataset(datasets, data_args, seed=training_args.seed)
+    if merge:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
+    else:
+        return datasets
+
+
+def _get_dataset_processor(
+    data_args: "DataArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    do_generate: bool = False,
+) -> "DatasetProcessor":
+    r"""
+    Returns the corresponding dataset processor.
+    """
+    if stage == "pt":
+        dataset_processor_class = PretrainDatasetProcessor
+    elif stage == "sft" and not do_generate:
+        if data_args.packing:
+            if data_args.neat_packing:  # hack datasets to have int32 attention mask
+                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
+
+                def __init__(self, data, **kwargs):
+                    return TypedSequence.__init__(
+                        self,
+                        data,
+                        type=kwargs.pop("type", None),
+                        try_type=kwargs.pop("try_type", None),
+                        optimized_int_type=kwargs.pop("optimized_int_type", None),
+                    )
+
+                OptimizedTypedSequence.__init__ = __init__
+
+            dataset_processor_class = PackedSupervisedDatasetProcessor
+        else:
+            dataset_processor_class = SupervisedDatasetProcessor
+    elif stage == "rm":
+        dataset_processor_class = PairwiseDatasetProcessor
+    elif stage == "kto":
+        dataset_processor_class = FeedbackDatasetProcessor
+    else:
+        dataset_processor_class = UnsupervisedDatasetProcessor
+
+    return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
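Note: the `neat_packing` branch works by monkey-patching `datasets.arrow_writer.OptimizedTypedSequence.__init__` to call the parent `TypedSequence.__init__` directly, so the integer-downcast optimization is skipped and the packed attention mask stays int32. A sketch of the same pattern on stand-in classes (not the real `datasets` internals):

    # Stand-in classes illustrating the patch; the real targets live in
    # datasets.arrow_writer (TypedSequence / OptimizedTypedSequence).
    class Base:
        def __init__(self, data, type=None, try_type=None, optimized_int_type=None):
            self.data, self.optimized_int_type = data, optimized_int_type

    class Optimized(Base):
        def __init__(self, data, **kwargs):
            kwargs.setdefault("optimized_int_type", "int8")  # the downcast being disabled
            super().__init__(data, **kwargs)

    def __init__(self, data, **kwargs):  # forwards straight to the parent, as the loader does
        return Base.__init__(
            self,
            data,
            type=kwargs.pop("type", None),
            try_type=kwargs.pop("try_type", None),
            optimized_int_type=kwargs.pop("optimized_int_type", None),
        )

    Optimized.__init__ = __init__
    assert Optimized([1, 2, 3]).optimized_int_type is None  # downcast disabled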


 def _get_preprocessed_dataset(
 ...
@@ -189,7 +243,7 @@ def _get_preprocessed_dataset(
     if dataset is None:
         return None

-    preprocess_func, print_function = get_preprocess_and_print_func(
+    dataset_processor = _get_dataset_processor(
         data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
     )
     column_names = list(next(iter(dataset)).keys())
 ...
@@ -202,7 +256,7 @@ def _get_preprocessed_dataset(
         )

     dataset = dataset.map(
-        preprocess_func,
+        dataset_processor.preprocess_dataset,
         batched=True,
         batch_size=data_args.preprocessing_batch_size,
         remove_columns=column_names,
 ...
@@ -212,7 +266,7 @@ def _get_preprocessed_dataset(
     if training_args.should_log:
         try:
             print("eval example:" if is_eval else "training example:")
-            print_function(next(iter(dataset)))
+            dataset_processor.print_data_example(next(iter(dataset)))
         except StopIteration:
             if stage == "pt":
                 raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
 ...
@@ -234,7 +288,7 @@ def get_dataset(
     r"""
    Gets the train dataset and optionally gets the evaluation dataset.
     """
-    # Load tokenized dataset
+    # Load tokenized dataset if path exists
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
 ...
@@ -249,7 +303,7 @@ def get_dataset(
             if "validation" in tokenized_data:
                 dataset_module["eval_dataset"] = tokenized_data["validation"]
-            else:  # Dataset
+            else:  # single dataset
                 dataset_module["train_dataset"] = tokenized_data

             if data_args.streaming:
 ...
@@ -263,15 +317,23 @@ def get_dataset(
     # Load and preprocess dataset
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
-        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+        eval_dataset = _get_merged_dataset(
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+        )

     with training_args.main_process_first(desc="pre-process dataset"):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
-        eval_dataset = _get_preprocessed_dataset(
-            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-        )
+        if isinstance(eval_dataset, dict):
+            for eval_name, eval_data in eval_dataset.items():
+                eval_dataset[eval_name] = _get_preprocessed_dataset(
+                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+                )
+        else:
+            eval_dataset = _get_preprocessed_dataset(
+                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )

         if data_args.val_size > 1e-6:
             dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
 ...
@@ -284,17 +346,20 @@ def get_dataset(
             dataset_dict["train"] = dataset

             if eval_dataset is not None:
-                if data_args.streaming:
-                    eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
-
-                dataset_dict["validation"] = eval_dataset
+                if isinstance(eval_dataset, dict):
+                    dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+                else:
+                    if data_args.streaming:
+                        eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+
+                    dataset_dict["validation"] = eval_dataset

             dataset_dict = DatasetDict(dataset_dict)

-        if data_args.tokenized_path is not None:
+        if data_args.tokenized_path is not None:  # save tokenized dataset to disk and exit
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)

-            logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+            logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
             logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

             sys.exit(0)
 ...
@@ -305,5 +370,13 @@ def get_dataset(
     if "validation" in dataset_dict:
         dataset_module["eval_dataset"] = dataset_dict["validation"]
+    else:
+        eval_dataset = {}
+        for key in dataset_dict.keys():
+            if key.startswith("validation_"):
+                eval_dataset[key[len("validation_") :]] = dataset_dict[key]
+
+        if len(eval_dataset):
+            dataset_module["eval_dataset"] = eval_dataset

     return dataset_module
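Note: with `merge=False` the eval side becomes a dict of named datasets; each split is stored under a `validation_{name}` key and recovered later by stripping the prefix, as the last hunk shows. The round trip in isolation:

    # Round trip of the validation_{name} convention used in get_dataset above.
    eval_datasets = {"alpaca": ["ex1"], "sharegpt": ["ex2"]}  # hypothetical named splits
    dataset_dict = {f"validation_{name}": data for name, data in eval_datasets.items()}

    recovered = {
        key[len("validation_") :]: data
        for key, data in dataset_dict.items()
        if key.startswith("validation_")
    }
    assert recovered == eval_datasets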
src/llamafactory/data/mm_plugin.py  (view file @ 317a82e2)

+import inspect
 import math
 import re
 from copy import deepcopy
+from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union

 import numpy as np
 import torch
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override

-from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
+from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.packages import (
+    is_librosa_available,
+    is_pillow_available,
+    is_pyav_available,
+    is_transformers_version_greater_than,
+)
+
+
+if is_librosa_available():
+    import librosa


 if is_pillow_available():
 ...
@@ -31,7 +42,9 @@ if is_transformers_version_greater_than("4.45.0"):
 if TYPE_CHECKING:
     from av.stream import Stream
+    from numpy.typing import NDArray
     from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
     from transformers.image_processing_utils import BaseImageProcessor

     class EncodedImage(TypedDict):
 ...
@@ -40,6 +53,7 @@ if TYPE_CHECKING:
     ImageInput = Union[str, bytes, EncodedImage, ImageObject]
     VideoInput = str
+    AudioInput = Union[str, NDArray]


 def _get_paligemma_token_type_ids(
 ...
@@ -59,20 +73,25 @@ def _get_paligemma_token_type_ids(
     return batch_token_type_ids


-class BasePlugin:
-    def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
-        self.image_token = image_token
-        self.video_token = video_token
-        self.expand_mm_tokens = True
+@dataclass
+class MMPluginMixin:
+    image_token: Optional[str]
+    video_token: Optional[str]
+    audio_token: Optional[str]
+    expand_mm_tokens: bool = True

     def _validate_input(
         self,
+        processor: Optional["ProcessorMixin"],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
     ) -> None:
         r"""
         Validates if this model accepts the input modalities.
         """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         if len(images) != 0 and self.image_token is None:
             raise ValueError(
                 "This model does not support image input. Please check whether the correct `template` is used."
 ...
@@ -83,31 +102,54 @@ class BasePlugin:
                 "This model does not support video input. Please check whether the correct `template` is used."
             )

-    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+        if len(audios) != 0 and self.audio_token is None:
+            raise ValueError(
+                "This model does not support audio input. Please check whether the correct `template` is used."
+            )
+
+        if self.image_token is not None and processor is None:
+            raise ValueError("Processor was not found, please check and update your processor config.")
+
+        if self.image_token is not None and image_processor is None:
+            raise ValueError("Image processor was not found, please check and update your processor config.")
+
+        if self.audio_token is not None and feature_extractor is None:
+            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+
+    def _preprocess_image(
+        self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
+    ) -> "ImageObject":
         r"""
         Pre-processes a single image.
         """
-        image_resolution: int = kwargs.get("image_resolution")
-        if (image.width * image.height) > image_resolution:
-            resize_factor = math.sqrt(image_resolution / (image.width * image.height))
+        if (image.width * image.height) > image_max_pixels:
+            resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
             width, height = int(image.width * resize_factor), int(image.height * resize_factor)
-            image = image.resize((width, height), resample=Image.NEAREST)
+            image = image.resize((width, height))
+
+        if (image.width * image.height) < image_min_pixels:
+            resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
+            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+            image = image.resize((width, height))

         if image.mode != "RGB":
             image = image.convert("RGB")

         return image

-    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
+    def _get_video_sample_indices(
+        self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
+    ) -> List[int]:
         r"""
-        Computes video sample frames according to fps.
+        Computes video sample indices according to fps.
         """
-        video_fps: float = kwargs.get("video_fps")
-        video_maxlen: int = kwargs.get("video_maxlen")
         total_frames = video_stream.frames
-        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
+        if total_frames == 0:  # infinite video
+            return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
+
+        sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)
         sample_frames = min(total_frames, video_maxlen, sample_frames)
-        return math.floor(sample_frames)
+        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

     def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
         r"""
 ...
@@ -126,7 +168,7 @@ class BasePlugin:
             image = Image.open(image["path"])

         if not isinstance(image, ImageObject):
-            raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
+            raise ValueError(f"Expect input is a list of images, but got {type(image)}.")

         results.append(self._preprocess_image(image, **kwargs))
 ...
@@ -140,9 +182,7 @@ class BasePlugin:
         for video in videos:
             container = av.open(video, "r")
             video_stream = next(stream for stream in container.streams if stream.type == "video")
-            total_frames = video_stream.frames
-            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
-            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: List["ImageObject"] = []
             container.seek(0)
             for frame_idx, frame in enumerate(container.decode(video_stream)):
 ...
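Note: `_get_video_sample_indices` merges the old frame-count and index computations: the target count is duration x fps, capped by `video_maxlen` and the real frame count, then spread evenly with `np.linspace`; a stream reporting zero frames falls back to `video_maxlen` indices. With plain numbers:

    import math
    import numpy as np

    def sample_indices(total_frames: int, duration_s: float, fps: float, maxlen: int) -> np.ndarray:
        if total_frames == 0:  # stream with unknown length, as handled above
            return np.linspace(0, maxlen - 1, maxlen).astype(np.int32)
        n = min(total_frames, maxlen, math.floor(duration_s * fps))
        return np.linspace(0, total_frames - 1, n).astype(np.int32)

    # A 10 s clip with 300 frames sampled at 2 fps yields 20 evenly spaced indices.
    print(sample_indices(total_frames=300, duration_s=10.0, fps=2.0, maxlen=128))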
@@ -154,10 +194,27 @@ class BasePlugin:
         return results

+    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
+        r"""
+        Regularizes audios to avoid error. Including reading and resampling.
+        """
+        results = []
+        for audio in audios:
+            if isinstance(audio, str):
+                audio = librosa.load(audio, sr=sampling_rate)[0]
+
+            if not isinstance(audio, np.ndarray):
+                raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
+
+            results.append(audio)
+
+        return results
+
     def _get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: "ProcessorMixin",
     ) -> Dict[str, "torch.Tensor"]:
         r"""
 ...
@@ -172,47 +229,65 @@ class BasePlugin:
         It holds num_patches == torch.prod(image_grid_thw)
         """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
         video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
-        input_dict = {"images": None}  # default key
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
+        mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
                 images,
-                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
             )
-            input_dict["images"] = images
+            mm_inputs.update(image_processor(images, return_tensors="pt"))

         if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
-                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                 video_fps=getattr(processor, "video_fps", 2.0),
-                video_maxlen=getattr(processor, "video_maxlen", 64),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
             )
-            input_dict["videos"] = videos
+            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
+                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
+            else:  # for llava_next_video
+                mm_inputs.update(video_processor(videos, return_tensors="pt"))

-        mm_inputs = {}
-        if image_processor != video_processor:
-            if input_dict.get("images") is not None:
-                mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
-            if input_dict.get("videos") is not None:
-                mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
-        elif input_dict.get("images") is not None or input_dict.get("videos") is not None:  # same processor (qwen2-vl)
-            mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
+        if len(audios) != 0:
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+            )
+            mm_inputs.update(
+                feature_extractor(
+                    audios,
+                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+                    return_attention_mask=True,
+                    padding="max_length",
+                    return_tensors="pt",
+                )
+            )
+            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

         return mm_inputs


+@dataclass
+class BasePlugin(MMPluginMixin):
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         r"""
         Pre-processes input messages before tokenization for VLMs.
         """
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         return messages
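Note: `_regularize_audios` accepts either file paths or raw waveforms; paths go through `librosa.load` at the extractor's sampling rate so every entry ends up as a mono `np.ndarray`. A standalone sketch (the file name is hypothetical):

    import numpy as np
    import librosa  # optional dependency, gated by is_librosa_available() above

    def regularize_audio(audio, sampling_rate: float = 16000) -> np.ndarray:
        if isinstance(audio, str):
            audio = librosa.load(audio, sr=sampling_rate)[0]  # resampled mono waveform
        if not isinstance(audio, np.ndarray):
            raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
        return audio

    # waveform = regularize_audio("clip.wav")  # hypothetical path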

     def process_token_ids(
 ...
@@ -221,21 +296,24 @@ class BasePlugin:
         labels: Optional[List[int]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
         r"""
         Pre-processes token ids after tokenization for VLMs.
         """
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         return input_ids, labels

     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 ...
@@ -247,13 +325,15 @@ class BasePlugin:
             videos: a list of video inputs, shape (num_videos,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
+            audlens: number of audios in each sample, shape (batch_size,)
             batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         return {}


+@dataclass
 class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
 ...
@@ -261,9 +341,10 @@ class LlavaPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
 ...
@@ -285,15 +366,18 @@ class LlavaPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
 class LlavaNextPlugin(BasePlugin):
     @override
     def process_messages(
 ...
@@ -301,16 +385,15 @@ class LlavaNextPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        if "image_sizes" in mm_inputs:
-            image_sizes = iter(mm_inputs["image_sizes"])
-
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         if "pixel_values" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"].tolist())
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

         for message in messages:
 ...
@@ -319,7 +402,7 @@ class LlavaNextPlugin(BasePlugin):
             if self.expand_mm_tokens:
                 orig_height, orig_width = next(image_sizes)
                 image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
+                if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                     image_seqlen -= 1
             else:
                 image_seqlen = 1
 ...
@@ -339,15 +422,18 @@ class LlavaNextPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
 class LlavaNextVideoPlugin(BasePlugin):
     @override
     def process_messages(
 ...
@@ -355,14 +441,15 @@ class LlavaNextVideoPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         if "pixel_values" in mm_inputs:
-            image_sizes = iter(mm_inputs["image_sizes"])
+            image_sizes = iter(mm_inputs["image_sizes"].tolist())
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

         for message in messages:
             content = message["content"]
 ...
@@ -370,7 +457,7 @@ class LlavaNextVideoPlugin(BasePlugin):
             if self.expand_mm_tokens:
                 orig_height, orig_width = next(image_sizes)
                 image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
+                if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                     image_seqlen -= 1
             else:
                 image_seqlen = 1
 ...
@@ -381,12 +468,15 @@ class LlavaNextVideoPlugin(BasePlugin):
                 message["content"] = content.replace("{{image}}", self.image_token)

         if "pixel_values_videos" in mm_inputs:
-            pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-            height, width = get_image_size(pixel_values_video[0])
-            num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
-            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
-            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
-            video_seqlen = video_seqlen if self.expand_mm_tokens else 1
+            if self.expand_mm_tokens:
+                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(pixel_values_video[0])
+                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+            else:
+                video_seqlen = 1

         for message in messages:
             content = message["content"]
             while VIDEO_PLACEHOLDER in content:
 ...
@@ -408,15 +498,18 @@ class LlavaNextVideoPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        return self._get_mm_inputs(images, videos, audios, processor)
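Note: the LLaVA-NeXT-Video token budget above is per-frame patches, `(height // patch_size) * (width // patch_size)`, quartered by the average-pooling layer and multiplied by the frame count. In numbers, for a hypothetical 336x336 frame with 14-pixel patches:

    height, width, patch_size, num_frames = 336, 336, 14, 8

    image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576
    video_seqlen = image_seqlen // 4 * num_frames  # avg pooling keeps 1/4 per frame
    print(image_seqlen, video_seqlen)  # 576 1152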


+@dataclass
 class MiniCPMVPlugin(BasePlugin):
     @override
     def process_messages(
 ...
@@ -424,26 +517,27 @@ class MiniCPMVPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        self._validate_input(processor, images, videos, audios)
+        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         mm_inputs = {}
+        audio_inputs = {}
         if len(images) != 0 and len(videos) != 0:
             raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

         if len(videos) != 0:
             max_slice_nums = 2
             use_image_id = False
-            mm_inputs = self._get_mm_inputs([], videos, processor)
+            mm_inputs = self._get_mm_inputs([], videos, [], processor)
         else:
             max_slice_nums = image_processor.max_slice_nums
             use_image_id = image_processor.use_image_id

-        for message in messages:
+        for i, message in enumerate(messages):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
 ...
@@ -454,15 +548,24 @@ class MiniCPMVPlugin(BasePlugin):
                 content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                 num_video_tokens += 1

-            message["content"] = content.replace("{{image}}", "(<image>./</image>)")
+            while AUDIO_PLACEHOLDER in content:
+                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
+                num_audio_tokens += 1
+
+            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
+                "{{audio}}", "(<audio>./</audio>)"
+            )

         if num_image_tokens > 0:
-            mm_inputs = self._get_mm_inputs(images, [], processor)
+            mm_inputs = self._get_mm_inputs(images, [], [], processor)
+
+        if num_audio_tokens > 0:
+            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)

         if mm_inputs:
             pattern = "(<image>./</image>)"
             image_sizes = mm_inputs["image_sizes"]
-
+            idx = 0
             for index, message in enumerate(messages):
                 text = message["content"]
                 image_tags = re.findall(pattern, text)
 ...
@@ -473,9 +576,26 @@ class MiniCPMVPlugin(BasePlugin):
                         final_text
                         + text_chunks[i]
                         + image_processor.get_slice_image_placeholder(
-                            image_sizes[0][i], i, max_slice_nums, use_image_id
+                            image_sizes[0][idx], idx, max_slice_nums, use_image_id
                         )
                     )
+                    idx += 1

                 final_text += text_chunks[-1]
                 messages[index]["content"] = final_text

+        if audio_inputs:
+            pattern = "(<audio>./</audio>)"
+            idx = 0
+            for index, message in enumerate(messages):
+                text = message["content"]
+                audio_tags = re.findall(pattern, text)
+                text_chunks = text.split(pattern)
+                final_text = ""
+                for i in range(len(audio_tags)):
+                    audio_placeholder = audio_inputs["audio_phs"][0][idx]
+                    final_text = final_text + text_chunks[i] + audio_placeholder
+                    idx += 1
+
+                final_text += text_chunks[-1]
+                messages[index]["content"] = final_text
 ...
@@ -486,6 +606,9 @@ class MiniCPMVPlugin(BasePlugin):
         if len(videos) != num_video_tokens:
             raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

+        if len(audios) != num_audio_tokens:
+            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
+
         return messages
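Note: the MiniCPM-V message rewrite splits the text on the literal tag, interleaves one placeholder per slot, and re-appends the trailing chunk; the audio branch mirrors the image branch. The splice in miniature (placeholder strings here are stand-ins, not real `audio_phs` values):

    pattern = "(<audio>./</audio>)"
    text = f"transcribe {pattern} then {pattern} please"
    chunks = text.split(pattern)

    final_text = ""
    for i, ph in enumerate(["<AUD0>", "<AUD1>"]):  # stand-ins for audio_phs entries
        final_text += chunks[i] + ph
    final_text += chunks[-1]

    print(final_text)  # transcribe <AUD0> then <AUD1> please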

     @override
 ...
@@ -493,15 +616,18 @@ class MiniCPMVPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: "ProcessorMixin",
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
                 images,
-                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
             )
             if "valid_image_nums_ls" in kwargs:
                 valid_image_nums_ls = kwargs["valid_image_nums_ls"]
 ...
@@ -521,13 +647,39 @@ class MiniCPMVPlugin(BasePlugin):
         if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
-                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                 video_fps=getattr(processor, "video_fps", 2.0),
-                video_maxlen=getattr(processor, "video_maxlen", 64),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
             mm_inputs.update(video_inputs)

+        if len(audios) != 0:
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+            )
+            if "valid_audio_nums_ls" in kwargs:
+                valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
+                audios_ls = []
+                idx = 0
+                for valid_audio_nums in valid_audio_nums_ls:
+                    audios_ls.append(audios[idx : idx + valid_audio_nums])
+                    idx += valid_audio_nums
+            else:
+                audios_ls = [audios]
+
+            audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
+                audios_ls,
+                chunk_input=True,
+                sampling_rate=16000,
+            )
+            audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
+            mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
+            if kwargs.get("ret_phs", False):
+                mm_inputs.update({"audio_phs": audio_phs})
+
         return mm_inputs
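Note: `valid_audio_nums_ls` regroups the flat audio list back into per-sample sublists by consuming counts in order:

    audios = ["a0", "a1", "a2", "a3", "a4"]
    valid_audio_nums_ls = [2, 3]  # sample 0 owns two audios, sample 1 owns three

    audios_ls, idx = [], 0
    for n in valid_audio_nums_ls:
        audios_ls.append(audios[idx : idx + n])
        idx += n

    print(audios_ls)  # [['a0', 'a1'], ['a2', 'a3', 'a4']]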

     @override
 ...
@@ -535,15 +687,18 @@ class MiniCPMVPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
+        # image bound
         image_bounds_list = []
         valid_image_nums_ls = []
-        for input_ids in batch_ids:
+        for i, input_ids in enumerate(batch_ids):
             input_ids_ = torch.tensor(input_ids)
             start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                 input_ids_ == processor.tokenizer.slice_start_id
 ...
@@ -552,21 +707,51 @@ class MiniCPMVPlugin(BasePlugin):
             image_start_tokens = torch.where(start_cond)[0]
             image_start_tokens += 1
             image_end_tokens = torch.where(end_cond)[0]
-            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
-            valid_image_nums_ls.append(valid_image_nums)
+            valid_image_nums_ls.append(imglens[i])
             image_bounds = torch.hstack(
                 [
-                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
-                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_start_tokens.unsqueeze(-1),
+                    image_end_tokens.unsqueeze(-1),
                 ]
             )
             image_bounds_list.append(image_bounds)

-        mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
+        mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
+        if "tgt_sizes" not in mm_inputs:
+            dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
+            mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})
+
         mm_inputs.update({"image_bound": image_bounds_list})
+
+        if len(audios) > 0:
+            # audio bound
+            audio_bounds_ls = []
+            spk_bounds_ls = []
+            valid_audio_nums_ls = []
+            for input_ids, audiolen in zip(batch_ids, audlens):
+                input_ids_ = torch.tensor(input_ids)
+                audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
+                audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
+                assert len(audio_start_idx) == len(audio_end_idx)
+                audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
+                audio_bounds_ls.append(audio_bounds)
+                valid_audio_nums_ls.append(audiolen)
+
+                spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
+                spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
+                assert len(spk_start_idx) == len(spk_end_idx)
+                spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
+                spk_bounds_ls.append(spk_bounds)
+
+            audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
+            mm_inputs.update(audio_inputs)
+            mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
+
         return mm_inputs
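Note: the bound bookkeeping locates marker tokens with `torch.where` and stacks start/end positions into an (N, 2) span table; the `+1` skips the start marker itself. A minimal reproduction with made-up ids:

    import torch

    AUDIO_START, AUDIO_END = 101, 102  # stand-ins for tokenizer.audio_start_id / audio_end_id
    input_ids = torch.tensor([7, 101, 5, 5, 102, 9, 101, 5, 102])

    start_idx = torch.where(input_ids == AUDIO_START)[0]
    end_idx = torch.where(input_ids == AUDIO_END)[0]
    assert len(start_idx) == len(end_idx)

    audio_bounds = torch.hstack([(start_idx + 1).unsqueeze(-1), end_idx.unsqueeze(-1)])
    print(audio_bounds)  # tensor([[2, 4], [7, 8]])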
+@dataclass
class MllamaPlugin(BasePlugin):
    @override
    def process_messages(
...
@@ -574,9 +759,10 @@ class MllamaPlugin(BasePlugin):
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
...
@@ -594,8 +780,9 @@ class MllamaPlugin(BasePlugin):
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
-        **kwargs,
+        imglens: List[int],
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
...
@@ -609,43 +796,56 @@ class MllamaPlugin(BasePlugin):
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        imglens: List[int] = kwargs["imglens"]
-        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
-        batch_images = []
-        for image_length in imglens:
-            batch_images.append(images[:image_length])
-            images = images[image_length:]
-
-        return image_processor(batch_images, return_tensors="pt")
+        mm_inputs = {}
+        if len(images) > 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )
+            batch_images = []
+            for image_length in imglens:
+                batch_images.append(images[:image_length])
+                images = images[image_length:]
+
+            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
+
+        return mm_inputs

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
+        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
-        num_tiles = mm_inputs.pop("num_tiles")
-        image_token_id = getattr(processor, "image_token_id")
-        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
-        cross_attention_token_mask = [
-            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
-        ]
-        mm_inputs["cross_attention_mask"] = torch.from_numpy(
-            convert_sparse_cross_attention_mask_to_dense(
-                cross_attention_token_mask,
-                num_tiles=num_tiles,
-                max_num_tiles=max_image_tiles,
-                length=max(len(input_ids) for input_ids in batch_ids),
-            )
-        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
+        if mm_inputs:
+            num_tiles = mm_inputs.pop("num_tiles")
+            image_token_id = getattr(processor, "image_token_id")
+            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+            cross_attention_token_mask = [
+                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+            ]
+            mm_inputs["cross_attention_mask"] = torch.from_numpy(
+                convert_sparse_cross_attention_mask_to_dense(
+                    cross_attention_token_mask,
+                    num_tiles=num_tiles,
+                    max_num_tiles=max_image_tiles,
+                    length=max(len(input_ids) for input_ids in batch_ids),
+                )
+            )  # shape: (batch_size, length, max_num_images, max_num_tiles)

        return mm_inputs
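
Note: the imglens loop in _get_mm_inputs above re-nests a flat image list into one sub-list per sample, since mllama's image processor only accepts List[List[ImageInput]]. A toy illustration with string stand-ins for PIL images:

images = ["img_a", "img_b", "img_c"]  # stand-ins for PIL images
imglens = [2, 0, 1]  # number of images belonging to each sample in the batch

batch_images = []
for image_length in imglens:
    batch_images.append(images[:image_length])
    images = images[image_length:]

print(batch_images)  # [['img_a', 'img_b'], [], ['img_c']]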
+@dataclass
class PaliGemmaPlugin(BasePlugin):
    @override
    def process_messages(
...
@@ -653,9 +853,10 @@ class PaliGemmaPlugin(BasePlugin):
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
...
@@ -678,10 +879,11 @@ class PaliGemmaPlugin(BasePlugin):
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        num_images = len(images)
        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
...
@@ -696,18 +898,21 @@ class PaliGemmaPlugin(BasePlugin):
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
+        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        seqlens = [len(input_ids) for input_ids in batch_ids]
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
        return mm_inputs
+@dataclass
class PixtralPlugin(BasePlugin):
    @override
    def process_messages(
...
@@ -715,9 +920,10 @@ class PixtralPlugin(BasePlugin):
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        patch_size = getattr(processor, "patch_size")
        image_token = getattr(processor, "image_token")
        image_break_token = getattr(processor, "image_break_token")
...
@@ -725,17 +931,15 @@ class PixtralPlugin(BasePlugin):
        num_image_tokens = 0
        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        image_input_sizes = mm_inputs.get("image_sizes", None)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        if "pixel_values" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"].tolist())
+
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
-                if image_input_sizes is None:
-                    raise ValueError("Cannot get image input sizes.")
-
                if self.expand_mm_tokens:
-                    image_size = image_input_sizes[0][num_image_tokens]
-                    height, width = image_size
+                    height, width = next(image_sizes)
                    num_height_tokens = height // patch_size
                    num_width_tokens = width // patch_size
                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
...
@@ -760,47 +964,105 @@ class PixtralPlugin(BasePlugin):
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
+        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        if mm_inputs.get("pixel_values"):
-            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
-
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("image_sizes", None)
        return mm_inputs
-class Qwen2vlPlugin(BasePlugin):
+@dataclass
+class Qwen2AudioPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        bos_token: str = getattr(processor, "audio_bos_token")
+        eos_token: str = getattr(processor, "audio_eos_token")
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs([], [], audios, processor)
+        if "feature_attention_mask" in mm_inputs:
+            audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
+
+        num_audio_tokens = 0
+        for message in messages:
+            content = message["content"]
+            while AUDIO_PLACEHOLDER in content:
+                if self.expand_mm_tokens:
+                    audio_length = audio_lengths.pop(0)
+                    input_length = (audio_length - 1) // 2 + 1
+                    audio_seqlen = (input_length - 2) // 2 + 1
+                else:
+                    audio_seqlen = 1
+
+                content = content.replace(
+                    AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
+                )
+                num_audio_tokens += 1
+
+            message["content"] = content
+
+        if len(audios) != num_audio_tokens:
+            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        audlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+        return self._get_mm_inputs(images, videos, audios, processor)
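
Note: the two integer divisions in process_messages above determine how many audio tokens replace each placeholder; they presumably mirror two stride-2 reductions in Qwen2-Audio's encoder. A worked example for a 3000-frame feature:

audio_length = 3000  # frames in the padded mel feature
input_length = (audio_length - 1) // 2 + 1  # -> 1500 after the first stride-2 stage
audio_seqlen = (input_length - 2) // 2 + 1  # -> 750 after the second stride-2 stage
print(input_length, audio_seqlen)  # 1500 750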
+@dataclass
+class Qwen2VLPlugin(BasePlugin):
    @override
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        image = super()._preprocess_image(image, **kwargs)
        if min(image.width, image.height) < 28:
            width, height = max(image.width, 28), max(image.height, 28)
-            image = image.resize((width, height), resample=Image.NEAREST)
+            image = image.resize((width, height))

        if image.width / image.height > 200:
            width, height = image.height * 180, image.height
-            image = image.resize((width, height), resample=Image.NEAREST)
+            image = image.resize((width, height))

        if image.height / image.width > 200:
            width, height = image.width, image.width * 180
-            image = image.resize((width, height), resample=Image.NEAREST)
+            image = image.resize((width, height))

        return image

    @override
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
-        results = []
+    def _regularize_videos(
+        self, videos: Sequence["VideoInput"], **kwargs
+    ) -> Tuple[List[List["ImageObject"]], List[float]]:
+        results, fps_per_video = [], []
        for video in videos:
            container = av.open(video, "r")
            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            total_frames = video_stream.frames
-            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
-            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
            frames: List["ImageObject"] = []
            container.seek(0)
            for frame_idx, frame in enumerate(container.decode(video_stream)):
...
@@ -812,8 +1074,43 @@ class Qwen2vlPlugin(BasePlugin):
            frames = self._regularize_images(frames, **kwargs)
            results.append(frames)
+            if video_stream.duration is None:
+                fps_per_video.append(2.0)
+            else:
+                fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

-        return results
+        return results, fps_per_video

+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        if len(videos) != 0:
+            videos, fps_per_video = self._regularize_videos(
+                videos,
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )
+            mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
+            mm_inputs["fps_per_video"] = fps_per_video
+
+        return mm_inputs

    @override
    def process_messages(
...
@@ -821,17 +1118,23 @@ class Qwen2vlPlugin(BasePlugin):
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
-        num_image_tokens, num_video_tokens = 0, 0
-        messages = deepcopy(messages)
+        self._validate_input(processor, images, videos, audios)
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        merge_length: int = getattr(image_processor, "merge_size") ** 2
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        image_grid_thw = mm_inputs.get("image_grid_thw", [])
-        video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+
+        num_image_tokens, num_video_tokens = 0, 0
+        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
...
@@ -869,15 +1172,24 @@ class Qwen2vlPlugin(BasePlugin):
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
+        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        fps_per_video = mm_inputs.pop("fps_per_video", [])
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
+            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
+
+        return mm_inputs
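
Note: second_per_grid_ts above converts each video's effective sampling rate into seconds covered by one temporal grid step. A worked example, assuming a temporal patch size of 2 (an assumed value, matching what Qwen2-VL's image processor commonly ships with):

temporal_patch_size = 2  # assumed; taken from the image processor config
fps_per_video = [2.0, 0.5]  # frames per second actually sampled for two videos
second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]
print(second_per_grid_ts)  # [1.0, 4.0] seconds per temporal grid step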
+@dataclass
class VideoLlavaPlugin(BasePlugin):
    @override
    def process_messages(
...
@@ -885,12 +1197,13 @@ class VideoLlavaPlugin(BasePlugin):
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        num_frames = 0
        has_images = "pixel_values_images" in mm_inputs
        has_videos = "pixel_values_videos" in mm_inputs
...
@@ -907,7 +1220,7 @@ class VideoLlavaPlugin(BasePlugin):
            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
            video_seqlen = image_seqlen * num_frames
-            if getattr(processor, "vision_feature_select_strategy") == "default":
+            if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                image_seqlen -= 1
        else:
            image_seqlen, video_seqlen = 1, 1
...
@@ -938,13 +1251,15 @@ class VideoLlavaPlugin(BasePlugin):
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
+        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        return self._get_mm_inputs(images, videos, audios, processor)
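
Note: the placeholder expansion in VideoLlavaPlugin.process_messages is pure arithmetic on the processed tensor shapes. A worked example with hypothetical values (224x224 frames, patch size 14, 8 frames), not the processor's actual configuration:

height = width = 224  # hypothetical input resolution
patch_size = 14       # hypothetical vision patch size
num_frames = 8
image_seqlen = (height // patch_size) * (width // patch_size) + 1  # 257 with the CLS token
video_seqlen = image_seqlen * num_frames  # 2056
image_seqlen -= 1  # the "default" feature-select strategy drops one token
print(image_seqlen, video_seqlen)  # 256 2056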
PLUGINS = {
...
@@ -956,18 +1271,32 @@ PLUGINS = {
    "mllama": MllamaPlugin,
    "paligemma": PaliGemmaPlugin,
    "pixtral": PixtralPlugin,
-    "qwen2_vl": Qwen2vlPlugin,
+    "qwen2_audio": Qwen2AudioPlugin,
+    "qwen2_vl": Qwen2VLPlugin,
    "video_llava": VideoLlavaPlugin,
}


+def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
+    r"""
+    Registers a multimodal plugin.
+    """
+    if name in PLUGINS:
+        raise ValueError(f"Multimodal plugin {name} already exists.")
+
+    PLUGINS[name] = plugin_class
+
+
def get_mm_plugin(
    name: str,
    image_token: Optional[str] = None,
    video_token: Optional[str] = None,
+    audio_token: Optional[str] = None,
) -> "BasePlugin":
-    plugin_class = PLUGINS.get(name, None)
-    if plugin_class is None:
+    r"""
+    Gets plugin for multimodal inputs.
+    """
+    if name not in PLUGINS:
        raise ValueError(f"Multimodal plugin `{name}` not found.")

-    return plugin_class(image_token, video_token)
+    return PLUGINS[name](image_token, video_token, audio_token)
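
Note: a minimal sketch of the registry API added above; MyPlugin and "my_plugin" are hypothetical names, and a real plugin would override the BasePlugin hooks instead of pass:

class MyPlugin(BasePlugin):  # hypothetical subclass for illustration only
    pass

register_mm_plugin("my_plugin", MyPlugin)
plugin = get_mm_plugin("my_plugin", image_token="<image>", audio_token="<audio>")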
src/llamafactory/data/parser.py  View file @ 317a82e2

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -44,7 +44,8 @@ class DatasetAttr:
    tools: Optional[str] = None
    images: Optional[str] = None
    videos: Optional[str] = None
-    # rlhf columns
+    audios: Optional[str] = None
+    # dpo columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
...
@@ -70,6 +71,26 @@ class DatasetAttr:
    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        setattr(self, key, obj.get(key, default))

+    def join(self, attr: Dict[str, Any]) -> None:
+        self.set_attr("formatting", attr, default="alpaca")
+        self.set_attr("ranking", attr, default=False)
+        self.set_attr("subset", attr)
+        self.set_attr("split", attr, default="train")
+        self.set_attr("folder", attr)
+        self.set_attr("num_samples", attr)
+
+        if "columns" in attr:
+            column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
+            column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
+            for column_name in column_names:
+                self.set_attr(column_name, attr["columns"])
+
+        if "tags" in attr:
+            tag_names = ["role_tag", "content_tag"]
+            tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
+            for tag in tag_names:
+                self.set_attr(tag, attr["tags"])
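
Note: join() consolidates the attribute copying that get_dataset_list used to do inline (see the next hunk). An illustrative dataset_info entry it would consume; the dataset name and column names here are invented for the example:

attr = {  # invented entry in the style of dataset_info.json
    "formatting": "sharegpt",
    "columns": {"messages": "conversations", "images": "image_list", "audios": "audio_list"},
    "tags": {"role_tag": "from", "content_tag": "value"},
}
dataset_attr = DatasetAttr("file", dataset_name="demo.json")
dataset_attr.join(attr)
print(dataset_attr.messages, dataset_attr.role_tag)  # conversations from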
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
    r"""
...
@@ -127,36 +148,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
        else:
            dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

-        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
-        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
-        dataset_attr.set_attr("subset", dataset_info[name])
-        dataset_attr.set_attr("split", dataset_info[name], default="train")
-        dataset_attr.set_attr("folder", dataset_info[name])
-        dataset_attr.set_attr("num_samples", dataset_info[name])
-
-        if "columns" in dataset_info[name]:
-            column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
-            if dataset_attr.formatting == "alpaca":
-                column_names.extend(["prompt", "query", "response", "history"])
-            else:
-                column_names.extend(["messages"])
-
-            for column_name in column_names:
-                dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
-
-        if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
-            tag_names = (
-                "role_tag",
-                "content_tag",
-                "user_tag",
-                "assistant_tag",
-                "observation_tag",
-                "function_tag",
-                "system_tag",
-            )
-            for tag in tag_names:
-                dataset_attr.set_attr(tag, dataset_info[name]["tags"])
-
+        dataset_attr.join(dataset_info[name])
        dataset_list.append(dataset_attr)

    return dataset_list
src/llamafactory/data/preprocess.py  deleted 100644 → 0  View file @ 37b0ad9f
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple

from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
    preprocess_packed_supervised_dataset,
    preprocess_supervised_dataset,
    print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ..hparams import DataArguments
    from .template import Template


def get_preprocess_and_print_func(
    data_args: "DataArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    do_generate: bool = False,
) -> Tuple[Callable, Callable]:
    if stage == "pt":
        preprocess_func = partial(
            preprocess_pretrain_dataset,
            tokenizer=tokenizer,
            data_args=data_args,
        )
        print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not do_generate:
        if data_args.packing:
            if data_args.neat_packing:  # hack datasets to have int32 attention mask
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.pop("type", None),
                        try_type=kwargs.pop("try_type", None),
                        optimized_int_type=kwargs.pop("optimized_int_type", None),
                    )

                OptimizedTypedSequence.__init__ = __init__

            preprocess_func = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )
        else:
            preprocess_func = partial(
                preprocess_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )

        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    elif stage == "rm":
        preprocess_func = partial(
            preprocess_pairwise_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    elif stage == "kto":
        preprocess_func = partial(
            preprocess_feedback_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    else:
        preprocess_func = partial(
            preprocess_unsupervised_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

    return preprocess_func, print_function
src/llamafactory/data/processor/__init__.py  0 → 100644  View file @ 317a82e2
from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
from .processor_utils import DatasetProcessor
from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
from .unsupervised import UnsupervisedDatasetProcessor


__all__ = [
    "DatasetProcessor",
    "FeedbackDatasetProcessor",
    "PairwiseDatasetProcessor",
    "PretrainDatasetProcessor",
    "PackedSupervisedDatasetProcessor",
    "SupervisedDatasetProcessor",
    "UnsupervisedDatasetProcessor",
]
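
Note: this package replaces the function-per-stage dispatch in the deleted preprocess.py with one class per stage. A minimal usage sketch, where template, tokenizer, data_args and examples stand for objects built elsewhere in LLaMA-Factory:

from llamafactory.data.processor import SupervisedDatasetProcessor

# template, tokenizer and data_args are assumed to come from the usual loader
# utilities; processor stays None for text-only models.
dataset_processor = SupervisedDatasetProcessor(
    template=template, tokenizer=tokenizer, processor=None, data_args=data_args
)
model_inputs = dataset_processor.preprocess_dataset(examples)  # the batched map function
dataset_processor.print_data_example({key: value[0] for key, value in model_inputs.items()})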
src/llamafactory/data/processor/feedback.py  0 → 100644  View file @ 317a82e2
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class FeedbackDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        kl_response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
        if response[0]["content"]:  # desired example
            kto_tag = True
            messages = prompt + [response[0]]
        else:  # undesired example
            kto_tag = False
            messages = prompt + [response[1]]

        if kl_response[0]["content"]:
            kl_messages = prompt + [kl_response[0]]
        else:
            kl_messages = prompt + [kl_response[1]]

        messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
        kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor)
        prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
        kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools)

        if self.template.efficient_eos:
            response_ids += [self.tokenizer.eos_token_id]
            kl_response_ids += [self.tokenizer.eos_token_id]

        prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )

        source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len)
        prompt_ids = prompt_ids[:source_len]
        response_ids = response_ids[:target_len]
        kl_source_len, kl_target_len = infer_seqlen(
            len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len
        )
        kl_prompt_ids = kl_prompt_ids[:kl_source_len]
        kl_response_ids = kl_response_ids[:kl_target_len]

        input_ids = prompt_ids + response_ids
        labels = [IGNORE_INDEX] * source_len + response_ids
        kl_input_ids = kl_prompt_ids + kl_response_ids
        kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
        return input_ids, labels, kl_input_ids, kl_labels, kto_tag

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
        kl_response = examples["_response"][::-1]
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                kl_response=kl_response[i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["kl_input_ids"].append(kl_input_ids)
            model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
            model_inputs["kl_labels"].append(kl_labels)
            model_inputs["kto_tags"].append(kto_tag)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
        undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
        if desirable_num == 0 or undesirable_num == 0:
            logger.warning_rank0("Your dataset only has one preference type.")

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
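
Note: preprocess_dataset estimates the KL term by reversing the batch's response list, so each prompt is paired with a (usually) unrelated completion. A toy illustration; observe that the middle element of an odd-sized batch still pairs with itself:

prompts = ["p0", "p1", "p2"]
responses = ["r0", "r1", "r2"]
kl_responses = responses[::-1]  # the flip used above
print(list(zip(prompts, kl_responses)))  # [('p0', 'r2'), ('p1', 'r1'), ('p2', 'r0')]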
src/llamafactory/data/processor/pairwise.py  0 → 100644  View file @ 317a82e2
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class PairwiseDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int], List[int], List[int]]:
        chosen_messages = self.template.mm_plugin.process_messages(
            prompt + [response[0]], images, videos, audios, self.processor
        )
        rejected_messages = self.template.mm_plugin.process_messages(
            prompt + [response[1]], images, videos, audios, self.processor
        )
        prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools)
        _, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools)

        if self.template.efficient_eos:
            chosen_ids += [self.tokenizer.eos_token_id]
            rejected_ids += [self.tokenizer.eos_token_id]

        prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        # consider the response is more important
        source_len, target_len = infer_seqlen(
            len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len
        )
        prompt_ids = prompt_ids[:source_len]
        chosen_ids = chosen_ids[:target_len]
        rejected_ids = rejected_ids[:target_len]

        chosen_input_ids = prompt_ids + chosen_ids
        chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
        rejected_input_ids = prompt_ids + rejected_ids
        rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
        return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["chosen_input_ids"].append(chosen_input_ids)
            model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
            model_inputs["chosen_labels"].append(chosen_labels)
            model_inputs["rejected_input_ids"].append(rejected_input_ids)
            model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
            model_inputs["rejected_labels"].append(rejected_labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
        valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
        print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
        print(
            "chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))
        )
        print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
        print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
        print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
        print(
            "rejected_inputs:\n{}".format(
                self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)
            )
        )
        print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
        print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
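
Note: the truncation above deliberately feeds infer_seqlen the longer of the two responses, then cuts both to the same target_len so chosen and rejected stay aligned over a shared prompt. A usage sketch of that contract (infer_seqlen comes from processor_utils in this package):

from llamafactory.data.processor.processor_utils import infer_seqlen

prompt_ids, chosen_ids, rejected_ids = list(range(10)), list(range(8)), list(range(3))
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), 12)
prompt_ids = prompt_ids[:source_len]
chosen_ids = chosen_ids[:target_len]
rejected_ids = rejected_ids[:target_len]  # both responses share target_len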
src/llamafactory/data/processor/pretrain.py  0 → 100644  View file @ 317a82e2
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List

from .processor_utils import DatasetProcessor


@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
        eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
        text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

        if not self.data_args.packing:
            if getattr(self.tokenizer, "add_bos_token", False):
                text_examples = [self.tokenizer.bos_token + example for example in text_examples]

            result = self.tokenizer(
                text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
            )
        else:
            tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False)
            concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
            total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
            block_size = self.data_args.cutoff_len
            total_length = (total_length // block_size) * block_size
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            if getattr(self.tokenizer, "add_bos_token", False):
                for i in range(len(result["input_ids"])):
                    result["input_ids"][i][0] = self.tokenizer.bos_token_id

        return result

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
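
Note: the packing branch above concatenates every tokenized text and slices the result into cutoff_len-sized blocks, silently dropping the tail. A toy version of that grouping step:

from itertools import chain

tokenized = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
block_size = 4
concatenated = list(chain(*tokenized))
total_length = (len(concatenated) // block_size) * block_size
blocks = [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]
print(blocks)  # [[1, 2, 3, 4], [5, 6, 7, 8]] -- token 9 is dropped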
src/llamafactory/data/processors/processor_utils.py → src/llamafactory/data/processor/processor_utils.py  View file @ 317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -13,7 +13,42 @@
# limitations under the License.

import bisect
-from typing import List, Sequence, Tuple
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, ProcessorMixin
+
+    from ...hparams import DataArguments
+    from ..template import Template
+
+
+@dataclass
+class DatasetProcessor(ABC):
+    r"""
+    A class for data processors.
+    """
+
+    template: "Template"
+    tokenizer: "PreTrainedTokenizer"
+    processor: Optional["ProcessorMixin"]
+    data_args: "DataArguments"
+
+    @abstractmethod
+    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+        r"""
+        Builds model inputs from the examples.
+        """
+        ...
+
+    @abstractmethod
+    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+        r"""
+        Print a data example to stdout.
+        """
+        ...


def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
...
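
Note: a minimal subclass of the new DatasetProcessor ABC; EchoProcessor is a hypothetical example, not part of the library:

from dataclasses import dataclass
from typing import Any, Dict, List

from llamafactory.data.processor import DatasetProcessor


@dataclass
class EchoProcessor(DatasetProcessor):  # hypothetical, for illustration only
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        return {"input_ids": [self.tokenizer.encode(m[0]["content"]) for m in examples["_prompt"]]}

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print(self.tokenizer.decode(example["input_ids"]))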
src/llamafactory/data/processor/supervised.py  0 → 100644  View file @ 317a82e2
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
from
...extras
import
logging
from
...extras.constants
import
IGNORE_INDEX
from
.processor_utils
import
DatasetProcessor
,
greedy_knapsack
,
infer_seqlen
if
TYPE_CHECKING
:
from
..mm_plugin
import
AudioInput
,
ImageInput
,
VideoInput
logger
=
logging
.
get_logger
(
__name__
)
@
dataclass
class
SupervisedDatasetProcessor
(
DatasetProcessor
):
def
_encode_data_example
(
self
,
prompt
:
Sequence
[
Dict
[
str
,
str
]],
response
:
Sequence
[
Dict
[
str
,
str
]],
system
:
Optional
[
str
],
tools
:
Optional
[
str
],
images
:
Sequence
[
"ImageInput"
],
videos
:
Sequence
[
"VideoInput"
],
audios
:
Sequence
[
"AudioInput"
],
)
->
Tuple
[
List
[
int
],
List
[
int
]]:
messages
=
self
.
template
.
mm_plugin
.
process_messages
(
prompt
+
response
,
images
,
videos
,
audios
,
self
.
processor
)
input_ids
,
labels
=
self
.
template
.
mm_plugin
.
process_token_ids
(
[],
[],
images
,
videos
,
audios
,
self
.
tokenizer
,
self
.
processor
)
encoded_pairs
=
self
.
template
.
encode_multiturn
(
self
.
tokenizer
,
messages
,
system
,
tools
)
total_length
=
len
(
input_ids
)
+
(
1
if
self
.
template
.
efficient_eos
else
0
)
if
self
.
data_args
.
mask_history
:
encoded_pairs
=
encoded_pairs
[::
-
1
]
# high priority for last turns
for
turn_idx
,
(
source_ids
,
target_ids
)
in
enumerate
(
encoded_pairs
):
if
total_length
>=
self
.
data_args
.
cutoff_len
:
break
source_len
,
target_len
=
infer_seqlen
(
len
(
source_ids
),
len
(
target_ids
),
self
.
data_args
.
cutoff_len
-
total_length
)
source_ids
=
source_ids
[:
source_len
]
target_ids
=
target_ids
[:
target_len
]
total_length
+=
source_len
+
target_len
if
self
.
data_args
.
train_on_prompt
:
source_label
=
source_ids
elif
self
.
template
.
efficient_eos
:
source_label
=
[
self
.
tokenizer
.
eos_token_id
]
+
[
IGNORE_INDEX
]
*
(
source_len
-
1
)
else
:
source_label
=
[
IGNORE_INDEX
]
*
source_len
if
self
.
data_args
.
mask_history
and
turn_idx
!=
0
:
# train on the last turn only
target_label
=
[
IGNORE_INDEX
]
*
target_len
else
:
target_label
=
target_ids
if
self
.
data_args
.
mask_history
:
# reversed sequences
input_ids
=
source_ids
+
target_ids
+
input_ids
labels
=
source_label
+
target_label
+
labels
else
:
input_ids
+=
source_ids
+
target_ids
labels
+=
source_label
+
target_label
if
self
.
template
.
efficient_eos
:
input_ids
+=
[
self
.
tokenizer
.
eos_token_id
]
labels
+=
[
self
.
tokenizer
.
eos_token_id
]
return
input_ids
,
labels
def
preprocess_dataset
(
self
,
examples
:
Dict
[
str
,
List
[
Any
]])
->
Dict
[
str
,
List
[
Any
]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs
=
defaultdict
(
list
)
for
i
in
range
(
len
(
examples
[
"_prompt"
])):
if
len
(
examples
[
"_prompt"
][
i
])
%
2
!=
1
or
len
(
examples
[
"_response"
][
i
])
!=
1
:
logger
.
warning_rank0
(
"Dropped invalid example: {}"
.
format
(
examples
[
"_prompt"
][
i
]
+
examples
[
"_response"
][
i
])
)
continue
input_ids
,
labels
=
self
.
_encode_data_example
(
prompt
=
examples
[
"_prompt"
][
i
],
response
=
examples
[
"_response"
][
i
],
system
=
examples
[
"_system"
][
i
],
tools
=
examples
[
"_tools"
][
i
],
images
=
examples
[
"_images"
][
i
]
or
[],
videos
=
examples
[
"_videos"
][
i
]
or
[],
audios
=
examples
[
"_audios"
][
i
]
or
[],
)
model_inputs
[
"input_ids"
].
append
(
input_ids
)
model_inputs
[
"attention_mask"
].
append
([
1
]
*
len
(
input_ids
))
model_inputs
[
"labels"
].
append
(
labels
)
model_inputs
[
"images"
].
append
(
examples
[
"_images"
][
i
])
model_inputs
[
"videos"
].
append
(
examples
[
"_videos"
][
i
])
model_inputs
[
"audios"
].
append
(
examples
[
"_audios"
][
i
])
return
model_inputs
def
print_data_example
(
self
,
example
:
Dict
[
str
,
List
[
int
]])
->
None
:
valid_labels
=
list
(
filter
(
lambda
x
:
x
!=
IGNORE_INDEX
,
example
[
"labels"
]))
print
(
"input_ids:
\n
{}"
.
format
(
example
[
"input_ids"
]))
print
(
"inputs:
\n
{}"
.
format
(
self
.
tokenizer
.
decode
(
example
[
"input_ids"
],
skip_special_tokens
=
False
)))
print
(
"label_ids:
\n
{}"
.
format
(
example
[
"labels"
]))
print
(
f
"labels:
\n
{
self
.
tokenizer
.
decode
(
valid_labels
,
skip_special_tokens
=
False
)
}
"
)
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # TODO: use `position_ids` to achieve packing
        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
        valid_num = 0
        batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
        lengths = []
        length2indexes = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            length = len(input_ids)
            if length > self.data_args.cutoff_len:
                logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
            else:
                lengths.append(length)
                length2indexes[length].append(valid_num)
                batch_input_ids.append(input_ids)
                batch_labels.append(labels)
                batch_images.append(examples["_images"][i] or [])
                batch_videos.append(examples["_videos"][i] or [])
                batch_audios.append(examples["_audios"][i] or [])
                valid_num += 1

        model_inputs = defaultdict(list)
        knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
        for knapsack in knapsacks:
            packed_input_ids, packed_attention_masks, packed_labels = [], [], []
            packed_images, packed_videos, packed_audios = [], [], []
            for i, length in enumerate(knapsack):
                index = length2indexes[length].pop()
                packed_input_ids += batch_input_ids[index]
                packed_labels += batch_labels[index]
                packed_images += batch_images[index]
                packed_videos += batch_videos[index]
                packed_audios += batch_audios[index]
                if self.data_args.neat_packing:
                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
                else:
                    packed_attention_masks += [1] * len(batch_input_ids[index])

            if len(packed_input_ids) < self.data_args.cutoff_len + 1:  # avoid flash_attn drops attn mask
                pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
                packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
                packed_labels += [IGNORE_INDEX] * pad_length
                if self.data_args.neat_packing:
                    packed_attention_masks += [0] * pad_length
                else:
                    packed_attention_masks += [1] * pad_length  # more efficient flash_attn

            if len(packed_input_ids) != self.data_args.cutoff_len + 1:
                raise ValueError("The length of packed example should be identical to the cutoff length.")

            model_inputs["input_ids"].append(packed_input_ids)
            model_inputs["attention_mask"].append(packed_attention_masks)
            model_inputs["labels"].append(packed_labels)
            model_inputs["images"].append(packed_images or None)
            model_inputs["videos"].append(packed_videos or None)
            model_inputs["audios"].append(packed_audios or None)

        return model_inputs
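
A note on `neat_packing`: the attention mask above stores a segment id per token (1, 2, ... for real examples, 0 for padding) instead of the usual all-ones mask, so a collator can later rebuild a block-diagonal mask that stops packed examples from attending to each other. A minimal sketch of that expansion in pure Python, for illustration only; the actual expansion is performed by the training collator, which is not part of this commit:

# Sketch: expand a segment-id attention mask produced by neat packing
# into a block-diagonal 2D mask. Illustrative only; the real collator
# does an equivalent expansion with torch tensors.
from typing import List


def expand_segment_mask(segment_ids: List[int]) -> List[List[int]]:
    """mask[i][j] == 1 iff tokens i and j belong to the same non-padding segment."""
    n = len(segment_ids)
    return [
        [1 if segment_ids[i] != 0 and segment_ids[i] == segment_ids[j] else 0 for j in range(n)]
        for i in range(n)
    ]


# Two packed examples of lengths 3 and 2, plus one padding token (id 0).
mask = expand_segment_mask([1, 1, 1, 2, 2, 0])
for row in mask:
    print(row)  # tokens of segment 1 only attend to segment 1, etc.
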
src/llamafactory/data/processor/unsupervised.py
0 → 100644
View file @
317a82e2
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ..data_utils import Role
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int]]:
        if len(response) == 1:
            messages = prompt + response
        else:
            messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

        messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
        input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
        if self.template.efficient_eos:
            labels += [self.tokenizer.eos_token_id]

        input_ids, _ = self.template.mm_plugin.process_token_ids(
            input_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len)
        input_ids = input_ids[:source_len]
        labels = labels[:target_len]
        return input_ids, labels
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build inputs with format `<bos> X` and labels with format `Y <eos>`
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs
    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False)))
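
Every processor above delegates truncation to `infer_seqlen` (imported from `processor_utils`, whose body is not shown in this commit). It decides how to split the `cutoff_len` budget between source and target before each side is sliced. A minimal sketch of that contract under the assumption of a simple proportional split; the shipped helper may allocate the budget differently:

# Hypothetical sketch of the infer_seqlen contract: split cutoff_len
# between source and target in proportion to their original lengths.
# The real helper in processor_utils may use a different allocation rule.
from typing import Tuple


def infer_seqlen_sketch(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    if source_len + target_len <= cutoff_len:  # nothing to truncate
        return source_len, target_len

    new_source_len = max(1, cutoff_len * source_len // (source_len + target_len))
    new_target_len = max(1, cutoff_len - new_source_len)
    return min(new_source_len, source_len), min(new_target_len, target_len)


print(infer_seqlen_sketch(80, 20, 50))  # (40, 10): both sides shrink in ratio
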
src/llamafactory/data/processors/__init__.py
deleted
100644 → 0
View file @
37b0ad9f
src/llamafactory/data/processors/feedback.py
deleted
100644 → 0
View file @
37b0ad9f
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)
def _encode_feedback_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    kl_response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
    if response[0]["content"]:  # desired example
        kto_tag = True
        messages = prompt + [response[0]]
    else:  # undesired example
        kto_tag = False
        messages = prompt + [response[1]]

    if kl_response[0]["content"]:
        kl_messages = prompt + [kl_response[0]]
    else:
        kl_messages = prompt + [kl_response[1]]

    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
    kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
    kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)

    if template.efficient_eos:
        response_ids += [tokenizer.eos_token_id]
        kl_response_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)

    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    response_ids = response_ids[:target_len]
    kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
    kl_prompt_ids = kl_prompt_ids[:kl_source_len]
    kl_response_ids = kl_response_ids[:kl_target_len]

    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * source_len + response_ids
    kl_input_ids = kl_prompt_ids + kl_response_ids
    kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
    return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_feedback_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
    kl_response = examples["_response"][::-1]
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            kl_response=kl_response[i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["kl_input_ids"].append(kl_input_ids)
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
        model_inputs["kl_labels"].append(kl_labels)
        model_inputs["kto_tags"].append(kto_tag)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])

    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
    if desirable_num == 0 or undesirable_num == 0:
        logger.warning_rank0("Your dataset only has one preference type.")

    return model_inputs
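
The KL inputs above come from reversing the batch of responses, so each prompt is paired with some other example's response. A quick demonstration of the flip:

# Demonstration of the KL-pair construction: reversing the response
# column pairs each prompt with an unrelated response from the batch.
examples = {
    "_prompt": [["p0"], ["p1"], ["p2"]],
    "_response": [["r0"], ["r1"], ["r2"]],
}
kl_response = examples["_response"][::-1]
for prompt, response in zip(examples["_prompt"], kl_response):
    print(prompt, response)
# ['p0'] ['r2']
# ['p1'] ['r1']   <- the middle element pairs with itself in odd-sized batches
# ['p2'] ['r0']
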
src/llamafactory/data/processors/mm_utils.py
deleted
100644 → 0
View file @
37b0ad9f
from typing import TYPE_CHECKING, List, Sequence

from ...extras.packages import is_pillow_available


if is_pillow_available():
    from PIL import Image


if TYPE_CHECKING:
    from numpy.typing import NDArray
    from PIL.Image import Image as ImageObject
    from transformers import ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor


def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
    # process visual inputs (currently only supports a single image)
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
    return image_processor(image, return_tensors="pt")["pixel_values"][0]  # shape (C, H, W)


def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
    # get paligemma token type ids for computing loss
    image_seq_length = getattr(processor, "image_seq_length")
    return [0] * image_seq_length + [1] * (input_len - image_seq_length)
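
The token type ids mark the image prefix (0) versus text tokens (1) so the loss can skip image positions. A self-contained demonstration with a stand-in processor object; the real call receives a `ProcessorMixin` and the `image_seq_length` value below is hypothetical:

# Demonstration of the PaliGemma token-type layout using a stand-in
# processor; the real function reads image_seq_length from a ProcessorMixin.
from types import SimpleNamespace

processor = SimpleNamespace(image_seq_length=4)  # hypothetical value
input_len = 10
token_type_ids = [0] * processor.image_seq_length + [1] * (input_len - processor.image_seq_length)
print(token_type_ids)  # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1] -> loss only on the 1s
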
src/llamafactory/data/processors/pairwise.py
deleted
100644 → 0
View file @
37b0ad9f
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)
def _encode_pairwise_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)

    if template.efficient_eos:
        chosen_ids += [tokenizer.eos_token_id]
        rejected_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
    # consider the response is more important
    source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    chosen_ids = chosen_ids[:target_len]
    rejected_ids = rejected_ids[:target_len]

    chosen_input_ids = prompt_ids + chosen_ids
    chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
    rejected_input_ids = prompt_ids + rejected_ids
    rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_pairwise_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["chosen_input_ids"].append(chosen_input_ids)
        model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
        model_inputs["chosen_labels"].append(chosen_labels)
        model_inputs["rejected_input_ids"].append(rejected_input_ids)
        model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
        model_inputs["rejected_labels"].append(rejected_labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])

    return model_inputs
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
    valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
    print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
    print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
    print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
    print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
    print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
    print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
    print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
    print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
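
Because both candidates share one prompt, truncation budgets the target side by the longer of the two responses, so chosen and rejected stay aligned token-for-token over the prompt. A small numeric illustration with a toy budgeting rule, not the real `infer_seqlen`:

# Toy illustration (not the real infer_seqlen): one source budget is
# computed against the longer response, then both responses share the
# same target budget so chosen/rejected stay aligned over the prompt.
prompt_len, chosen_len, rejected_len, cutoff_len = 60, 30, 50, 80

target_len = min(max(chosen_len, rejected_len), cutoff_len - 1)  # leave room for the prompt
source_len = min(prompt_len, cutoff_len - target_len)

print(source_len, target_len)         # 30 50
print(min(chosen_len, target_len))    # 30 tokens of the chosen response survive
print(min(rejected_len, target_len))  # 50 tokens of the rejected response survive
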
src/llamafactory/data/processors/pretrain.py
deleted
100644 → 0
View file @
37b0ad9f
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from ...hparams import DataArguments
def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

    if not data_args.packing:
        if data_args.template == "gemma":
            text_examples = [tokenizer.bos_token + example for example in text_examples]

        result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
    else:
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
        block_size = data_args.cutoff_len
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        if data_args.template == "gemma":
            for i in range(len(result["input_ids"])):
                result["input_ids"][i][0] = tokenizer.bos_token_id

    return result
def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
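
The packing branch above concatenates every tokenized document and re-slices the stream into fixed `cutoff_len` blocks, discarding the remainder. The same grouping trick in isolation, with toy token ids:

# The grouping trick from the packing branch in isolation: concatenate
# all token lists, drop the tail remainder, and slice into fixed blocks.
from itertools import chain

tokenized = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
block_size = 4

concatenated = {k: list(chain(*v)) for k, v in tokenized.items()}
total_length = (len(concatenated["input_ids"]) // block_size) * block_size
blocks = {
    k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
    for k, t in concatenated.items()
}
print(blocks["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]] -- token 9 is dropped
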
src/llamafactory/data/processors/supervised.py
deleted
100644 → 0
View file @
37b0ad9f
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)
def _encode_supervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
    train_on_prompt: bool,
    mask_history: bool,
) -> Tuple[List[int], List[int]]:
    messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
    if mask_history:
        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        if total_length >= cutoff_len:
            break

        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
        source_ids = source_ids[:source_len]
        target_ids = target_ids[:target_len]
        total_length += source_len + target_len

        if train_on_prompt:
            source_label = source_ids
        elif template.efficient_eos:
            source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
        else:
            source_label = [IGNORE_INDEX] * source_len

        if mask_history and turn_idx != 0:  # train on the last turn only
            target_label = [IGNORE_INDEX] * target_len
        else:
            target_label = target_ids

        if mask_history:  # reversed sequences
            input_ids = source_ids + target_ids + input_ids
            labels = source_label + target_label + labels
        else:
            input_ids += source_ids + target_ids
            labels += source_label + target_label

    if template.efficient_eos:
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]

    return input_ids, labels
def preprocess_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_supervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])

    return model_inputs
def preprocess_packed_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # TODO: use `position_ids` to achieve packing
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
    valid_num = 0
    batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
    lengths = []
    length2indexes = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_supervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        length = len(input_ids)
        if length > data_args.cutoff_len:
            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
        else:
            lengths.append(length)
            length2indexes[length].append(valid_num)
            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            batch_images.append(examples["_images"][i] or [])
            batch_videos.append(examples["_videos"][i] or [])
            valid_num += 1

    model_inputs = defaultdict(list)
    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
    for knapsack in knapsacks:
        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
        packed_images, packed_videos = [], []
        for i, length in enumerate(knapsack):
            index = length2indexes[length].pop()
            packed_input_ids += batch_input_ids[index]
            packed_labels += batch_labels[index]
            packed_images += batch_images[index]
            packed_videos += batch_videos[index]
            if data_args.neat_packing:
                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
            else:
                packed_attention_masks += [1] * len(batch_input_ids[index])

        if len(packed_input_ids) < data_args.cutoff_len:
            pad_length = data_args.cutoff_len - len(packed_input_ids)
            packed_input_ids += [tokenizer.pad_token_id] * pad_length
            packed_labels += [IGNORE_INDEX] * pad_length
            if data_args.neat_packing:
                packed_attention_masks += [0] * pad_length
            else:
                packed_attention_masks += [1] * pad_length  # more efficient flash_attn

        if len(packed_input_ids) != data_args.cutoff_len:
            raise ValueError("The length of packed example should be identical to the cutoff length.")

        model_inputs["input_ids"].append(packed_input_ids)
        model_inputs["attention_mask"].append(packed_attention_masks)
        model_inputs["labels"].append(packed_labels)
        model_inputs["images"].append(packed_images or None)
        model_inputs["videos"].append(packed_videos or None)

    return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
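
`greedy_knapsack` (from `processor_utils`, whose body is not part of this diff) groups example lengths into bins that each fit within the cutoff. A first-fit-decreasing sketch of what such a packer can look like; the shipped helper may differ in tie-breaking and complexity:

# Hypothetical first-fit-decreasing packer illustrating the greedy_knapsack
# contract: return groups of lengths whose sums stay within capacity.
# The helper shipped in processor_utils may use a different strategy.
from typing import List


def greedy_knapsack_sketch(lengths: List[int], capacity: int) -> List[List[int]]:
    knapsacks: List[List[int]] = []
    for length in sorted(lengths, reverse=True):  # place long examples first
        for sack in knapsacks:
            if sum(sack) + length <= capacity:
                sack.append(length)
                break
        else:
            knapsacks.append([length])  # open a new bin

    return knapsacks


print(greedy_knapsack_sketch([7, 2, 5, 4, 3], capacity=10))  # [[7, 3], [5, 4], [2]]
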
src/llamafactory/data/processors/unsupervised.py
deleted
100644 → 0
View file @
37b0ad9f
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ..data_utils import Role
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)
def _encode_unsupervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int]]:
    if len(response) == 1:
        messages = prompt + response
    else:
        messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
    if template.efficient_eos:
        labels += [tokenizer.eos_token_id]

    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
    source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
    input_ids = input_ids[:source_len]
    labels = labels[:target_len]
    return input_ids, labels
def preprocess_unsupervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build inputs with format `<bos> X` and labels with format `Y <eos>`
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_unsupervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])

    return model_inputs
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
src/llamafactory/data/template.py
View file @
317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -13,7 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
 from typing_extensions import override
...
@@ -47,6 +47,7 @@ class Template:
     format_prefix: "Formatter"
     default_system: str
     stop_words: List[str]
+    thought_words: Tuple[str, str]
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
...
@@ -67,8 +68,8 @@ class Template:
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
-        answer_ids = encoded_messages[-1]
-        return prompt_ids, answer_ids
+        response_ids = encoded_messages[-1]
+        return prompt_ids, response_ids

     def encode_multiturn(
         self,
...
@@ -99,6 +100,27 @@ class Template:
         return list(stop_token_ids)

+    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
+        r"""
+        Converts elements to token ids.
+        """
+        token_ids = []
+        for elem in elements:
+            if isinstance(elem, str):
+                if len(elem) != 0:
+                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
+            elif isinstance(elem, dict):
+                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+            elif isinstance(elem, set):
+                if "bos_token" in elem and tokenizer.bos_token_id is not None:
+                    token_ids += [tokenizer.bos_token_id]
+                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
+                    token_ids += [tokenizer.eos_token_id]
+            else:
+                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
+
+        return token_ids
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
...
@@ -109,7 +131,7 @@ class Template:
         r"""
         Encodes formatted inputs to pairs of token ids.
         Turn 0: prefix + system + query        resp
-        Turn t: sep + query                    resp
+        Turn t:          query                 resp
         """
         system = system or self.default_system
         encoded_messages = []
...
@@ -137,26 +159,179 @@ class Template:
         return encoded_messages

-    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
-        r"""
-        Converts elements to token ids.
-        """
-        token_ids = []
-        for elem in elements:
-            if isinstance(elem, str):
-                if len(elem) != 0:
-                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
-            elif isinstance(elem, dict):
-                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
-            elif isinstance(elem, set):
-                if "bos_token" in elem and tokenizer.bos_token_id is not None:
-                    token_ids += [tokenizer.bos_token_id]
-                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
-                    token_ids += [tokenizer.eos_token_id]
-            else:
-                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
-
-        return token_ids
+    @staticmethod
+    def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
+        r"""
+        Adds or replaces eos token to the tokenizer.
+        """
+        is_added = tokenizer.eos_token_id is None
+        num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
+        if is_added:
+            logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
+        else:
+            logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
+
+        if num_added_tokens > 0:
+            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+    def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
+        r"""
+        Adds eos token and pad token to the tokenizer.
+        """
+        stop_words = self.stop_words
+        if self.replace_eos:
+            if not stop_words:
+                raise ValueError("Stop words are required to replace the EOS token.")
+
+            self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
+            stop_words = stop_words[1:]
+
+        if tokenizer.eos_token_id is None:
+            self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
+
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
+
+        if stop_words:
+            num_added_tokens = tokenizer.add_special_tokens(
+                dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
+            )
+            logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
+            if num_added_tokens > 0:
+                logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+    @staticmethod
+    def _jinja_escape(content: str) -> str:
+        r"""
+        Escape single quotes in content.
+        """
+        return content.replace("'", r"\'")
+
+    @staticmethod
+    def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
+        r"""
+        Converts slots to jinja template.
+        """
+        slot_items = []
+        for slot in slots:
+            if isinstance(slot, str):
+                slot_pieces = slot.split("{{content}}")
+                if slot_pieces[0]:
+                    slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
+                if len(slot_pieces) > 1:
+                    slot_items.append(placeholder)
+                    if slot_pieces[1]:
+                        slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
+            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+                if "bos_token" in slot and tokenizer.bos_token_id is not None:
+                    slot_items.append("'" + tokenizer.bos_token + "'")
+                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+                    slot_items.append("'" + tokenizer.eos_token + "'")
+            elif isinstance(slot, dict):
+                raise ValueError("Dict is not supported.")
+
+        return " + ".join(slot_items)
+
+    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+        r"""
+        Returns the jinja template.
+        """
+        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+        system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
+        user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+        assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+        jinja_template = ""
+        if prefix:
+            jinja_template += "{{ " + prefix + " }}"
+
+        if self.default_system:
+            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+        jinja_template += (
+            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+            "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
+            "{% for message in loop_messages %}"
+            "{% set content = message['content'] %}"
+            "{% if message['role'] == 'user' %}"
+            "{{ " + user + " }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ " + assistant + " }}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+        return jinja_template
+
+    def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
+        r"""
+        Replaces the jinja template in the tokenizer.
+        """
+        if tokenizer.chat_template is None or self.replace_jinja_template:
+            try:
+                tokenizer.chat_template = self._get_jinja_template(tokenizer)
+            except ValueError as e:
+                logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
+
+    @staticmethod
+    def _convert_slots_to_ollama(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
+        r"""
+        Converts slots to ollama template.
+        """
+        slot_items = []
+        for slot in slots:
+            if isinstance(slot, str):
+                slot_pieces = slot.split("{{content}}")
+                if slot_pieces[0]:
+                    slot_items.append(slot_pieces[0])
+                if len(slot_pieces) > 1:
+                    slot_items.append("{{ " + placeholder + " }}")
+                    if slot_pieces[1]:
+                        slot_items.append(slot_pieces[1])
+            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+                if "bos_token" in slot and tokenizer.bos_token_id is not None:
+                    slot_items.append(tokenizer.bos_token)
+                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+                    slot_items.append(tokenizer.eos_token)
+            elif isinstance(slot, dict):
+                raise ValueError("Dict is not supported.")
+
+        return "".join(slot_items)
+
+    def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+        r"""
+        Returns the ollama template.
+        """
+        prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
+        system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
+        user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
+        assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
+        return (
+            f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
+            f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
+            f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
+        )
+
+    def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
+        r"""
+        Returns the ollama modelfile.
+
+        TODO: support function calling.
+        """
+        modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
+        modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
+        if self.default_system:
+            modelfile += f'SYSTEM """{self.default_system}"""\n\n'
+
+        for stop_token_id in self.get_stop_token_ids(tokenizer):
+            modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
+
+        modelfile += "PARAMETER num_ctx 4096\n"
+        return modelfile

 @dataclass
...
@@ -169,11 +344,6 @@ class Llama2Template(Template):
         system: str,
         tools: str,
     ) -> List[List[int]]:
-        r"""
-        Encodes formatted inputs to pairs of token ids.
-        Turn 0: prefix + system + query        resp
-        Turn t: sep + query                    resp
-        """
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
...
@@ -201,11 +371,41 @@ class Llama2Template(Template):
         return encoded_messages

+    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+        system_message = self._convert_slots_to_jinja(
+            self.format_system.apply(), tokenizer, placeholder="system_message"
+        )
+        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+        jinja_template = ""
+        if prefix:
+            jinja_template += "{{ " + prefix + " }}"
+
+        if self.default_system:
+            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+        jinja_template += (
+            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+            "{% for message in loop_messages %}"
+            "{% if loop.index0 == 0 and system_message is defined %}"
+            "{% set content = " + system_message + " + message['content'] %}"
+            "{% else %}{% set content = message['content'] %}{% endif %}"
+            "{% if message['role'] == 'user' %}"
+            "{{ " + user_message + " }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ " + assistant_message + " }}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+        return jinja_template

 TEMPLATES: Dict[str, "Template"] = {}

-def _register_template(
+def register_template(
     name: str,
     format_user: Optional["Formatter"] = None,
     format_assistant: Optional["Formatter"] = None,
...
@@ -216,10 +416,12 @@ def _register_template(
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
     stop_words: Optional[Sequence[str]] = None,
+    thought_words: Optional[Tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
+    template_class: Type["Template"] = Template,
 ) -> None:
     r"""
     Registers a chat template.
...
@@ -234,7 +436,7 @@ def _register_template(
     The corresponding code should be:
     ```
-    _register_template(
+    register_template(
         name="custom",
         format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
         format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
...
@@ -242,7 +444,9 @@ def _register_template(
     )
     ```
     """
+    template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template
+    if name in TEMPLATES:
+        raise ValueError(f"Template {name} already exists.")
+
     default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=default_slots)
...
@@ -259,6 +463,7 @@ def _register_template(
         format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
         stop_words=stop_words or [],
+        thought_words=thought_words or ("<think>", "</think>"),
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
...
@@ -266,97 +471,83 @@ def _register_template(
     )
-def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
-    is_added = tokenizer.eos_token_id is None
-    num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
-    if is_added:
-        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
-    else:
-        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
-
-    if num_added_tokens > 0:
-        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
-
-
-def _jinja_escape(content: str) -> str:
-    return content.replace("'", r"\'")
-
-
-def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
-    slot_items = []
-    for slot in slots:
-        if isinstance(slot, str):
-            slot_pieces = slot.split("{{content}}")
-            if slot_pieces[0]:
-                slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
-            if len(slot_pieces) > 1:
-                slot_items.append(placeholder)
-                if slot_pieces[1]:
-                    slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
-        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
-            if "bos_token" in slot and tokenizer.bos_token_id is not None:
-                slot_items.append("'" + tokenizer.bos_token + "'")
-            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
-                slot_items.append("'" + tokenizer.eos_token + "'")
-        elif isinstance(slot, dict):
-            raise ValueError("Dict is not supported.")
-
-    return " + ".join(slot_items)
-
-
-def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
-    r"""
-    Returns the jinja template.
-    """
-    jinja_template = ""
-    prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
-    if prefix:
-        jinja_template += "{{ " + prefix + " }}"
-
-    if template.default_system:
-        jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
-
-    jinja_template += (
-        "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
-        "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
-    )
-    system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
-    if not isinstance(template, Llama2Template):
-        jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
-
-    jinja_template += "{% for message in loop_messages %}"
-    jinja_template += "{% set content = message['content'] %}"
-    if isinstance(template, Llama2Template):
-        jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
-        jinja_template += "{% set content = " + system_message + " + message['content'] %}"
-        jinja_template += "{% endif %}"
-
-    jinja_template += "{% if message['role'] == 'user' %}"
-    user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
-    jinja_template += "{{ " + user_message + " }}"
-    jinja_template += "{% elif message['role'] == 'assistant' %}"
-    assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
-    jinja_template += "{{ " + assistant_message + " }}"
-    jinja_template += "{% endif %}"
-    jinja_template += "{% endfor %}"
-    return jinja_template
+def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
+    r"""
+    Extracts a chat template from the tokenizer.
+    """
+
+    def find_diff(short_str: str, long_str: str) -> str:
+        i, j = 0, 0
+        diff = ""
+        while i < len(short_str) and j < len(long_str):
+            if short_str[i] == long_str[j]:
+                i += 1
+                j += 1
+            else:
+                diff += long_str[j]
+                j += 1
+
+        return diff
+
+    prefix = tokenizer.decode(tokenizer.encode(""))
+    messages = [{"role": "system", "content": "{{content}}"}]
+    system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
+    messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
+    user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    user_slot_empty_system = user_slot_empty_system[len(prefix) :]
+    messages = [{"role": "user", "content": "{{content}}"}]
+    user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    user_slot = user_slot[len(prefix) :]
+    messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
+    assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
+    assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    if len(user_slot) > len(user_slot_empty_system):
+        default_system = find_diff(user_slot_empty_system, user_slot)
+        sole_system = system_slot.replace("{{content}}", default_system, 1)
+        user_slot = user_slot[len(sole_system) :]
+    else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
+        default_system = ""
+
+    return Template(
+        format_user=StringFormatter(slots=[user_slot]),
+        format_assistant=StringFormatter(slots=[assistant_slot]),
+        format_system=StringFormatter(slots=[system_slot]),
+        format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
+        format_observation=StringFormatter(slots=[user_slot]),
+        format_tools=ToolFormatter(tool_format="default"),
+        format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
+        default_system=default_system,
+        stop_words=[],
+        thought_words=("<think>", "</think>"),
+        efficient_eos=False,
+        replace_eos=False,
+        replace_jinja_template=False,
+        mm_plugin=get_mm_plugin(name="base"),
+    )
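
The new `parse_template` recovers the user/assistant slots by rendering probe conversations that contain the literal placeholder "{{content}}" and slicing off the shared prefix. The probing idea on its own, with a toy stand-in for `tokenizer.apply_chat_template` so the snippet runs without transformers; any Jinja-capable tokenizer would behave similarly:

# The probing idea behind parse_template in isolation: render a probe
# message containing the literal placeholder "{{content}}" and read the
# surrounding slot text back out. render() is a toy stand-in for
# tokenizer.apply_chat_template with a ChatML-like template.
def render(messages, add_generation_prompt=False):
    text = "".join(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages)
    return text + ("<|im_start|>assistant\n" if add_generation_prompt else "")


user_slot = render([{"role": "user", "content": "{{content}}"}], add_generation_prompt=True)
print(user_slot)
# <|im_start|>user
# {{content}}<|im_end|>
# <|im_start|>assistant
# -> usable directly as a StringFormatter slot
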
 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
     r"""
     Gets chat template and fixes the tokenizer.
     """
     if data_args.template is None:
-        template = TEMPLATES["empty"]  # placeholder
+        if isinstance(tokenizer.chat_template, str):
+            logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
+            template = parse_template(tokenizer)
+        else:
+            logger.warning_rank0("`template` was not specified, use `empty` template.")
+            template = TEMPLATES["empty"]  # placeholder
     else:
-        template = TEMPLATES.get(data_args.template, None)
-        if template is None:
-            raise ValueError(f"Template {data_args.template} does not exist.")
+        if data_args.template not in TEMPLATES:
+            raise ValueError(f"Template {data_args.template} does not exist.")
+
+        template = TEMPLATES[data_args.template]

     if template.mm_plugin.__class__.__name__ != "BasePlugin":
         check_version("transformers>=4.45.0")
 ...
@@ -369,39 +560,12 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
     template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

-    stop_words = template.stop_words
-    if template.replace_eos:
-        if not stop_words:
-            raise ValueError("Stop words are required to replace the EOS token.")
-
-        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
-        stop_words = stop_words[1:]
-
-    if tokenizer.eos_token_id is None:
-        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
-
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token = tokenizer.eos_token
-        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
-
-    if stop_words:
-        num_added_tokens = tokenizer.add_special_tokens(
-            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
-        )
-        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
-        if num_added_tokens > 0:
-            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
-
-    if tokenizer.chat_template is None or template.replace_jinja_template:
-        try:
-            tokenizer.chat_template = _get_jinja_template(template, tokenizer)
-        except ValueError as e:
-            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
+    template.fix_special_tokens(tokenizer)
+    template.fix_jinja_template(tokenizer)

     return template
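
With this refactor, the special-token and Jinja-template fixes move onto the Template object, so callers only need the single entry point. A minimal usage sketch (the import paths and the `DataArguments` field are assumptions based on the surrounding code, not verified against this exact revision):

# Usage sketch, not from the commit: resolving a registered template for a
# tokenizer. Assumes llamafactory.hparams.DataArguments exposes `template`.
from transformers import AutoTokenizer

from llamafactory.data.template import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

tokenizer = AutoTokenizer.from_pretrained("Qwen/QwQ-32B")
data_args = DataArguments(template="qwen")  # must be a key in TEMPLATES

template = get_template_and_fix_tokenizer(tokenizer, data_args)
print(template.stop_words)  # e.g. ["<|im_end|>"] for chatml-style templates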
-_register_template(
+register_template(
     name="alpaca",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
 ...
@@ -412,7 +576,7 @@ _register_template(
 )
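
From here on the commit renames the private `_register_template` to the public `register_template` in every block. A hedged sketch of registering a custom template under the new name (all slot values below are illustrative, not any real model's chat format):

# Illustrative sketch, not in the commit: registering a custom template with
# the newly public register_template. Slot strings here are made up.
from llamafactory.data.formatter import EmptyFormatter, StringFormatter
from llamafactory.data.template import register_template

register_template(
    name="my_template",
    format_user=StringFormatter(slots=["<user>{{content}}</user><bot>"]),
    format_assistant=StringFormatter(slots=["{{content}}</bot>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["</bot>"],
)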
-_register_template(
+register_template(
     name="aquila",
     format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}###"]),
 ...
@@ -425,7 +589,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="atom",
     format_user=StringFormatter(
         slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
 ...
@@ -434,21 +598,31 @@ _register_template(
 )

-_register_template(
+register_template(
     name="baichuan",
     format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
     efficient_eos=True,
 )

-_register_template(
+register_template(
     name="baichuan2",
     format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
     efficient_eos=True,
 )

-_register_template(
+register_template(
+    name="bailing",
+    format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),
+    format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}"]),
+    format_observation=StringFormatter(slots=["<role>OBSERVATION</role>{{content}}<role>ASSISTANT</role>"]),
+    stop_words=["<|endoftext|>"],
+    efficient_eos=True,
+)
+
+
+register_template(
     name="belle",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
 ...
@@ -456,13 +630,13 @@ _register_template(
 )

-_register_template(
+register_template(
     name="bluelm",
     format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
 )

-_register_template(
+register_template(
     name="breeze",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 ...
@@ -470,7 +644,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="chatglm2",
     format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
 ...
@@ -478,7 +652,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="chatglm3",
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
 ...
@@ -494,7 +668,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="chatml",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -507,7 +681,7 @@ _register_template(
 )

 # copied from chatml template
-_register_template(
+register_template(
     name="chatml_de",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -520,13 +694,13 @@ _register_template(
 )

-_register_template(
+register_template(
     name="codegeex2",
     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
 )

-_register_template(
+register_template(
     name="codegeex4",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
 ...
@@ -543,7 +717,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="cohere",
     format_user=StringFormatter(
         slots=[
 ...
@@ -558,7 +732,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="cpm",
     format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 ...
@@ -566,7 +740,7 @@ _register_template(
 )

 # copied from chatml template
-_register_template(
+register_template(
     name="cpm3",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -577,7 +751,7 @@ _register_template(
 )

 # copied from chatml template
-_register_template(
+register_template(
     name="dbrx",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -602,7 +776,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
 ...
@@ -610,14 +784,14 @@ _register_template(
 )

-_register_template(
+register_template(
     name="deepseek3",
     format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )

-_register_template(
+register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
     format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
 ...
@@ -631,7 +805,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="default",
     format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
 ...
@@ -639,13 +813,13 @@ _register_template(
 )

-_register_template(
+register_template(
     name="empty",
-    efficient_eos=True,
+    format_assistant=StringFormatter(slots=["{{content}}"]),
 )
-_register_template(
+register_template(
     name="exaone",
     format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
 ...
@@ -653,7 +827,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="falcon",
     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
     format_assistant=StringFormatter(slots=["{{content}}\n"]),
 ...
@@ -661,14 +835,14 @@ _register_template(
 )

-_register_template(
+register_template(
     name="fewshot",
     format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
     efficient_eos=True,
 )

-_register_template(
+register_template(
     name="gemma",
     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
 ...
@@ -679,7 +853,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="glm4",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
     format_assistant=StringFormatter(slots=["\n{{content}}"]),
 ...
@@ -693,7 +867,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="granite3",
     format_user=StringFormatter(
         slots=[
 ...
@@ -705,7 +879,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="index",
     format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
     format_system=StringFormatter(slots=["<unk>{{content}}"]),
 ...
@@ -713,54 +887,59 @@ _register_template(
 )

-_register_template(
+register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
     format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
     format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    default_system=(
+        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
+        "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
+        "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+        "- InternLM (书生·浦语) can understand and communicate fluently in the language "
+        "chosen by the user such as English and 中文."
+    ),
     stop_words=["<eoa>"],
 )

-_register_template(
+register_template(
     name="intern2",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    default_system=(
+        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
+        "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
+        "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+        "- InternLM (书生·浦语) can understand and communicate fluently in the language "
+        "chosen by the user such as English and 中文."
+    ),
     stop_words=["<|im_end|>"],
 )

+# copied from intern2 template
+register_template(
+    name="intern3",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|im_end|>"],
+)
+
+
-_register_template(
+register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
+    template_class=Llama2Template,
 )

 # copied from llama2 template
-_register_template(
+register_template(
     name="llama2_zh",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
     default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
+    template_class=Llama2Template,
 )
-_register_template(
+register_template(
     name="llama3",
     format_user=StringFormatter(
         slots=[
 ...
@@ -788,7 +967,7 @@ _register_template(
 )

 # copied from llama3 template
-_register_template(
+register_template(
     name="mllama",
     format_user=StringFormatter(
         slots=[
 ...
@@ -816,8 +995,20 @@ _register_template(
 )

+register_template(
+    name="moonlight",
+    format_user=StringFormatter(
+        slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
+    format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
+    default_system="You are a helpful assistant provided by Moonshot-AI.",
+    stop_words=["<|im_end|>"],
+)
+
+
 # copied from vicuna template
-_register_template(
+register_template(
     name="llava",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
     default_system=(
 ...
@@ -829,7 +1020,7 @@ _register_template(
 )

 # copied from vicuna template
-_register_template(
+register_template(
     name="llava_next",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
     default_system=(
 ...
@@ -841,7 +1032,7 @@ _register_template(
 )

 # copied from llama3 template
-_register_template(
+register_template(
     name="llava_next_llama3",
     format_user=StringFormatter(
         slots=[
 ...
@@ -870,21 +1061,22 @@ _register_template(
 )

 # copied from mistral template
-_register_template(
+register_template(
     name="llava_next_mistral",
     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
-    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", " {{content}}", {"eos_token"}], tool_format="mistral"),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+    template_class=Llama2Template,
 )

-# copied from chatml template
-_register_template(
+# copied from qwen template
+register_template(
     name="llava_next_qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -901,7 +1093,7 @@ _register_template(
 )

 # copied from chatml template
-_register_template(
+register_template(
     name="llava_next_yi",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -912,7 +1104,7 @@ _register_template(
 )

 # copied from vicuna template
-_register_template(
+register_template(
     name="llava_next_video",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
     default_system=(
 ...
@@ -924,21 +1116,22 @@ _register_template(
 )

 # copied from mistral template
-_register_template(
+register_template(
     name="llava_next_video_mistral",
     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
-    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", " {{content}}", {"eos_token"}], tool_format="mistral"),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+    template_class=Llama2Template,
 )
 # copied from chatml template
-_register_template(
+register_template(
     name="llava_next_video_yi",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -949,7 +1142,7 @@ _register_template(
 )

 # copied from chatml template
-_register_template(
+register_template(
     name="marco",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -965,43 +1158,83 @@ _register_template(
 )

 # copied from chatml template
-_register_template(
+register_template(
     name="minicpm_v",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     stop_words=["<|im_end|>"],
+    default_system="You are a helpful assistant.",
     mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
 )

+# copied from minicpm_v template
+register_template(
+    name="minicpm_o",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    stop_words=["<|im_end|>"],
+    default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
+)
+
+
+# mistral tokenizer v3 tekken
+register_template(
+    name="ministral",
+    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+    format_tools=ToolFormatter(tool_format="mistral"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=Llama2Template,
+)
+
+
+# mistral tokenizer v3
-_register_template(
+register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
-    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", " {{content}}", {"eos_token"}], tool_format="mistral"),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=Llama2Template,
 )
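
Note the mistral-family change above: the old two-slot prefix ["[TOOL_CALLS] ", " {{content}}"] collapses into the single slot "[TOOL_CALLS] {{content}}", so marker and content render as one string. A hedged sketch of the slot substitution, assuming StringFormatter.apply(**kwargs) replaces "{{content}}" inside string slots as formatter.py suggests:

# Sketch, not from the commit: demonstrating slot substitution with a
# StringFormatter carrying the merged slot. FunctionFormatter additionally
# serializes tool calls via its tool_format before substitution.
from llamafactory.data.formatter import StringFormatter

formatter = StringFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}])
elements = formatter.apply(content='{"name": "get_weather", "arguments": {"city": "Paris"}}')
print(elements)
# ['[TOOL_CALLS] {"name": "get_weather", "arguments": {"city": "Paris"}}', {'eos_token'}]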
+# mistral tokenizer v7 tekken (copied from ministral)
+register_template(
+    name="mistral_small",
+    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+    format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+    format_tools=ToolFormatter(tool_format="mistral"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
-_register_template(
+register_template(
     name="olmo",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
 )

-_register_template(
+register_template(
     name="openchat",
     format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )

-_register_template(
+register_template(
     name="openchat-3.6",
     format_user=StringFormatter(
         slots=[
 ...
@@ -1017,7 +1250,7 @@ _register_template(
 )

 # copied from chatml template
-_register_template(
+register_template(
     name="opencoder",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -1028,16 +1261,24 @@ _register_template(
 )

-_register_template(
+register_template(
     name="orion",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )

+# copied from gemma template
+register_template(
+    name="paligemma",
+    format_user=StringFormatter(slots=["{{content}}\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+)
+
+
 # copied from gemma template
-_register_template(
-    name="paligemma",
+register_template(
+    name="paligemma_chat",
     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
     format_observation=StringFormatter(
 ...
@@ -1048,7 +1289,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="phi",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
 ...
@@ -1057,7 +1298,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="phi_small",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
 ...
@@ -1067,7 +1308,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="phi4",
     format_user=StringFormatter(
         slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
 ...
@@ -1078,17 +1319,22 @@ _register_template(
 )

-_register_template(
+# copied from ministral template
+register_template(
     name="pixtral",
     format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+    format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+    template_class=Llama2Template,
 )

 # copied from chatml template
-_register_template(
+register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -1104,7 +1350,19 @@ _register_template(
 )

 # copied from chatml template
-_register_template(
+register_template(
+    name="qwen2_audio",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
+)
+
+
+# copied from qwen template
+register_template(
     name="qwen2_vl",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -1120,7 +1378,7 @@ _register_template(
 )
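
Because templates like qwen2_vl mirror the model's stock chatml format, the standard Hugging Face chat-template API should reproduce the slots registered above. An illustrative check (the model id and the exact printed output are examples, not taken from this commit):

# Illustrative check, not from the commit: the stock Qwen2-VL chat template
# emits the same chatml markers as the slots registered above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Describe this image."}],
    tokenize=False,
    add_generation_prompt=True,
)
print(text)
# <|im_start|>user
# Describe this image.<|im_end|>
# <|im_start|>assistant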
-_register_template(
+register_template(
     name="sailor",
     format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -1134,7 +1392,7 @@ _register_template(
 )

 # copied from llama3 template
-_register_template(
+register_template(
     name="skywork_o1",
     format_user=StringFormatter(
         slots=[
 ...
@@ -1168,7 +1426,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="solar",
     format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
     format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
 ...
@@ -1176,7 +1434,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="starchat",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
     format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
 ...
@@ -1185,14 +1443,14 @@ _register_template(
 )

-_register_template(
+register_template(
     name="telechat",
     format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
     format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
 )

-_register_template(
+register_template(
     name="telechat2",
     format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
     format_system=StringFormatter(slots=["<_system>{{content}}"]),
 ...
@@ -1202,7 +1460,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="vicuna",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
     default_system=(
 ...
@@ -1213,7 +1471,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="video_llava",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
     default_system=(
 ...
@@ -1224,7 +1482,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="xuanyuan",
     format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
     default_system=(
 ...
@@ -1235,13 +1493,13 @@ _register_template(
 )

-_register_template(
+register_template(
     name="xverse",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
 )

-_register_template(
+register_template(
     name="yayi",
     format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
     format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
 ...
@@ -1262,7 +1520,7 @@ _register_template(
 )

 # copied from chatml template
-_register_template(
+register_template(
     name="yi",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 ...
@@ -1271,7 +1529,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="yi_vl",
     format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}\n"]),
 ...
@@ -1288,7 +1546,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
     format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
 ...
@@ -1296,7 +1554,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="zephyr",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
 ...
@@ -1304,7 +1562,7 @@ _register_template(
 )

-_register_template(
+register_template(
     name="ziya",
     format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
     format_assistant=StringFormatter(slots=["{{content}}\n"]),
 ...