OpenDAS / LLaMA-Factory

Commit 317a82e2 authored Mar 07, 2025 by chenych
parent 37b0ad9f

Add QWQ-32B

Showing 20 changed files with 1754 additions and 1221 deletions
src/llamafactory/data/formatter.py                  +16   -13
src/llamafactory/data/loader.py                     +96   -23
src/llamafactory/data/mm_plugin.py                  +482  -153
src/llamafactory/data/parser.py                     +24   -32
src/llamafactory/data/preprocess.py                 +0    -111
src/llamafactory/data/processor/__init__.py         +17   -0
src/llamafactory/data/processor/feedback.py         +129  -0
src/llamafactory/data/processor/pairwise.py         +118  -0
src/llamafactory/data/processor/pretrain.py         +57   -0
src/llamafactory/data/processor/processor_utils.py  +37   -2
src/llamafactory/data/processor/supervised.py       +200  -0
src/llamafactory/data/processor/unsupervised.py     +91   -0
src/llamafactory/data/processors/__init__.py        +0    -0
src/llamafactory/data/processors/feedback.py        +0    -130
src/llamafactory/data/processors/mm_utils.py        +0    -27
src/llamafactory/data/processors/pairwise.py        +0    -119
src/llamafactory/data/processors/pretrain.py        +0    -59
src/llamafactory/data/processors/supervised.py      +0    -219
src/llamafactory/data/processors/unsupervised.py    +0    -104
src/llamafactory/data/template.py                   +487  -229
src/llamafactory/data/formatter.py

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

@@ -86,19 +86,25 @@ class StringFormatter(Formatter):
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")

         return elements


 @dataclass
-class FunctionFormatter(Formatter):
+class FunctionFormatter(StringFormatter):
     def __post_init__(self):
+        super().__post_init__()
         self.tool_utils = get_tool_utils(self.tool_format)

     @override
     def apply(self, **kwargs) -> SLOTS:
-        content = kwargs.pop("content")
+        content: str = kwargs.pop("content")
+        regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
+        thought = re.search(regex, content)
+        if thought:
+            content = content.replace(thought.group(0), "")
+
         functions: List["FunctionCall"] = []
         try:
             tool_calls = json.loads(content)
...

@@ -111,16 +117,13 @@ class FunctionFormatter(Formatter):
                 )

         except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string

-        elements = []
-        for slot in self.slots:
-            if slot == "{{content}}":
-                elements += self.tool_utils.function_formatter(functions)
-            else:
-                elements.append(slot)
+        function_str = self.tool_utils.function_formatter(functions)
+        if thought:
+            function_str = thought.group(1) + function_str

-        return elements
+        return super().apply(content=function_str)


 @dataclass
...

@@ -135,7 +138,7 @@ class ToolFormatter(Formatter):
             tools = json.loads(content)
             return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
         except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}")  # flat string
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

     @override
     def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
...
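The FunctionFormatter change above strips an optional <think>...</think> block from the model output before JSON-parsing the tool call, then re-attaches the inner reasoning to the formatted function string. A minimal, self-contained sketch of that regex behavior (the sample content is illustrative, not from the diff):

    import re

    regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
    content = '<think>pick the weather tool</think>{"name": "get_weather", "arguments": {"city": "Paris"}}'
    thought = re.search(regex, content)
    if thought:
        content = content.replace(thought.group(0), "")  # drop the whole tag pair

    print(thought.group(1))  # "pick the weather tool" -- prepended to the function string later
    print(content)           # the bare JSON tool call, now safe to json.loads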
src/llamafactory/data/loader.py

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

@@ -22,10 +22,17 @@ from datasets import DatasetDict, load_dataset, load_from_disk
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
-from .aligner import align_dataset
+from .converter import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
-from .preprocess import get_preprocess_and_print_func
+from .processor import (
+    FeedbackDatasetProcessor,
+    PackedSupervisedDatasetProcessor,
+    PairwiseDatasetProcessor,
+    PretrainDatasetProcessor,
+    SupervisedDatasetProcessor,
+    UnsupervisedDatasetProcessor,
+)


 if TYPE_CHECKING:
...

@@ -35,6 +42,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, ModelArguments
     from .data_utils import DatasetModule
     from .parser import DatasetAttr
+    from .processor import DatasetProcessor
     from .template import Template
...

@@ -156,21 +164,67 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-) -> Optional[Union["Dataset", "IterableDataset"]]:
+    merge: bool = True,
+) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
     r"""
-    Gets the merged datasets in the standard format.
+    Returns the merged datasets in the standard format.
     """
     if dataset_names is None:
         return None

-    datasets = []
-    for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
+    datasets = {}
+    for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
         if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
             raise ValueError("The dataset is not applicable in the current training stage.")

-        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
+        datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)

-    return merge_dataset(datasets, data_args, seed=training_args.seed)
+    if merge:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
+    else:
+        return datasets
+
+
+def _get_dataset_processor(
+    data_args: "DataArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    do_generate: bool = False,
+) -> "DatasetProcessor":
+    r"""
+    Returns the corresponding dataset processor.
+    """
+    if stage == "pt":
+        dataset_processor_class = PretrainDatasetProcessor
+    elif stage == "sft" and not do_generate:
+        if data_args.packing:
+            if data_args.neat_packing:  # hack datasets to have int32 attention mask
+                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
+
+                def __init__(self, data, **kwargs):
+                    return TypedSequence.__init__(
+                        self,
+                        data,
+                        type=kwargs.pop("type", None),
+                        try_type=kwargs.pop("try_type", None),
+                        optimized_int_type=kwargs.pop("optimized_int_type", None),
+                    )
+
+                OptimizedTypedSequence.__init__ = __init__
+
+            dataset_processor_class = PackedSupervisedDatasetProcessor
+        else:
+            dataset_processor_class = SupervisedDatasetProcessor
+    elif stage == "rm":
+        dataset_processor_class = PairwiseDatasetProcessor
+    elif stage == "kto":
+        dataset_processor_class = FeedbackDatasetProcessor
+    else:
+        dataset_processor_class = UnsupervisedDatasetProcessor
+
+    return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
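The `neat_packing` branch above monkey-patches `OptimizedTypedSequence.__init__` so that `datasets` keeps the attention mask's declared int32 type instead of narrowing it during Arrow serialization. A minimal sketch of the same patching pattern, using stand-in classes rather than the real `datasets` internals:

    # Stand-ins for datasets.arrow_writer.TypedSequence / OptimizedTypedSequence.
    class TypedSequenceStub:
        def __init__(self, data, type=None):
            self.data, self.type = data, type  # the declared type is kept as-is

    class OptimizedTypedSequenceStub(TypedSequenceStub):
        def __init__(self, data, **kwargs):
            super().__init__(data, type="narrowed-int")  # the behavior being bypassed

    def _patched_init(self, data, **kwargs):
        # route around the subclass and call the base initializer directly,
        # mirroring the __init__ assignment in the diff above
        return TypedSequenceStub.__init__(self, data, type=kwargs.pop("type", None))

    OptimizedTypedSequenceStub.__init__ = _patched_init
    seq = OptimizedTypedSequenceStub([1, 2, 3], type="int32")
    print(seq.type)  # "int32" -- the declared type survives the patch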
 def _get_preprocessed_dataset(
...

@@ -189,7 +243,7 @@ def _get_preprocessed_dataset(
     if dataset is None:
         return None

-    preprocess_func, print_function = get_preprocess_and_print_func(
+    dataset_processor = _get_dataset_processor(
         data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
     )
     column_names = list(next(iter(dataset)).keys())
...

@@ -202,7 +256,7 @@ def _get_preprocessed_dataset(
     )

     dataset = dataset.map(
-        preprocess_func,
+        dataset_processor.preprocess_dataset,
         batched=True,
         batch_size=data_args.preprocessing_batch_size,
         remove_columns=column_names,
...

@@ -212,7 +266,7 @@ def _get_preprocessed_dataset(
     if training_args.should_log:
         try:
             print("eval example:" if is_eval else "training example:")
-            print_function(next(iter(dataset)))
+            dataset_processor.print_data_example(next(iter(dataset)))
         except StopIteration:
             if stage == "pt":
                 raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
...

@@ -234,7 +288,7 @@ def get_dataset(
     r"""
    Gets the train dataset and optionally gets the evaluation dataset.
     """
-    # Load tokenized dataset
+    # Load tokenized dataset if path exists
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
...

@@ -249,7 +303,7 @@ def get_dataset(
             if "validation" in tokenized_data:
                 dataset_module["eval_dataset"] = tokenized_data["validation"]

-            else:  # Dataset
+            else:  # single dataset
                 dataset_module["train_dataset"] = tokenized_data

             if data_args.streaming:
...

@@ -263,15 +317,23 @@ def get_dataset(
     # Load and preprocess dataset
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
-        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+        eval_dataset = _get_merged_dataset(
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+        )

     with training_args.main_process_first(desc="pre-process dataset"):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
-        eval_dataset = _get_preprocessed_dataset(
-            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-        )
+        if isinstance(eval_dataset, dict):
+            for eval_name, eval_data in eval_dataset.items():
+                eval_dataset[eval_name] = _get_preprocessed_dataset(
+                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+                )
+        else:
+            eval_dataset = _get_preprocessed_dataset(
+                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )

         if data_args.val_size > 1e-6:
             dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
...

@@ -284,17 +346,20 @@ def get_dataset(
             dataset_dict["train"] = dataset

         if eval_dataset is not None:
-            if data_args.streaming:
-                eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+            if isinstance(eval_dataset, dict):
+                dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+            else:
+                if data_args.streaming:
+                    eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)

-            dataset_dict["validation"] = eval_dataset
+                dataset_dict["validation"] = eval_dataset

         dataset_dict = DatasetDict(dataset_dict)

-    if data_args.tokenized_path is not None:
+    if data_args.tokenized_path is not None:  # save tokenized dataset to disk and exit
         if training_args.should_save:
             dataset_dict.save_to_disk(data_args.tokenized_path)
-            logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+            logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
             logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

         sys.exit(0)
...

@@ -305,5 +370,13 @@ def get_dataset(
     if "validation" in dataset_dict:
         dataset_module["eval_dataset"] = dataset_dict["validation"]

+    else:
+        eval_dataset = {}
+        for key in dataset_dict.keys():
+            if key.startswith("validation_"):
+                eval_dataset[key[len("validation_") :]] = dataset_dict[key]
+
+        if len(eval_dataset):
+            dataset_module["eval_dataset"] = eval_dataset
+
     return dataset_module
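The new `get_dataset` logic above stores named evaluation splits under `validation_{name}` keys inside the `DatasetDict`, then unpacks them back into a dict keyed by dataset name. A small illustration of that round-trip (dataset names and values are made up):

    dataset_dict = {"train": "train_ds", "validation_alpaca": "eval_ds1", "validation_identity": "eval_ds2"}

    eval_dataset = {}
    for key in dataset_dict.keys():
        if key.startswith("validation_"):
            eval_dataset[key[len("validation_") :]] = dataset_dict[key]

    print(eval_dataset)  # {'alpaca': 'eval_ds1', 'identity': 'eval_ds2'}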
src/llamafactory/data/mm_plugin.py

 import inspect
 import math
 import re
 from copy import deepcopy
+from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union

 import numpy as np
 import torch
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override

-from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
+from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.packages import (
+    is_librosa_available,
+    is_pillow_available,
+    is_pyav_available,
+    is_transformers_version_greater_than,
+)


+if is_librosa_available():
+    import librosa
+
+
 if is_pillow_available():
...

@@ -31,7 +42,9 @@ if is_transformers_version_greater_than("4.45.0"):
 if TYPE_CHECKING:
     from av.stream import Stream
+    from numpy.typing import NDArray
     from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
     from transformers.image_processing_utils import BaseImageProcessor

     class EncodedImage(TypedDict):
...

@@ -40,6 +53,7 @@ if TYPE_CHECKING:
     ImageInput = Union[str, bytes, EncodedImage, ImageObject]
     VideoInput = str
+    AudioInput = Union[str, NDArray]


 def _get_paligemma_token_type_ids(
...
return
batch_token_type_ids
class
BasePlugin
:
def
__init__
(
self
,
image_token
:
Optional
[
str
],
video_token
:
Optional
[
str
])
->
None
:
self
.
image_token
=
image_token
self
.
video_token
=
video_token
self
.
expand_mm_tokens
=
True
@
dataclass
class
MMPluginMixin
:
image_token
:
Optional
[
str
]
video_token
:
Optional
[
str
]
audio_token
:
Optional
[
str
]
expand_mm_tokens
:
bool
=
True
def
_validate_input
(
self
,
processor
:
Optional
[
"ProcessorMixin"
],
images
:
Sequence
[
"ImageInput"
],
videos
:
Sequence
[
"VideoInput"
],
audios
:
Sequence
[
"AudioInput"
],
)
->
None
:
r
"""
Validates if this model accepts the input modalities.
"""
image_processor
:
"BaseImageProcessor"
=
getattr
(
processor
,
"image_processor"
,
None
)
feature_extractor
:
"SequenceFeatureExtractor"
=
getattr
(
processor
,
"feature_extractor"
,
None
)
if
len
(
images
)
!=
0
and
self
.
image_token
is
None
:
raise
ValueError
(
"This model does not support image input. Please check whether the correct `template` is used."
...
...
@@ -83,31 +102,54 @@ class BasePlugin:
"This model does not support video input. Please check whether the correct `template` is used."
)
def
_preprocess_image
(
self
,
image
:
"ImageObject"
,
**
kwargs
)
->
"ImageObject"
:
if
len
(
audios
)
!=
0
and
self
.
audio_token
is
None
:
raise
ValueError
(
"This model does not support audio input. Please check whether the correct `template` is used."
)
if
self
.
image_token
is
not
None
and
processor
is
None
:
raise
ValueError
(
"Processor was not found, please check and update your processor config."
)
if
self
.
image_token
is
not
None
and
image_processor
is
None
:
raise
ValueError
(
"Image processor was not found, please check and update your processor config."
)
if
self
.
audio_token
is
not
None
and
feature_extractor
is
None
:
raise
ValueError
(
"Audio feature extractor was not found, please check and update your processor config."
)
def
_preprocess_image
(
self
,
image
:
"ImageObject"
,
image_max_pixels
:
int
,
image_min_pixels
:
int
,
**
kwargs
)
->
"ImageObject"
:
r
"""
Pre-processes a single image.
"""
image_resolution
:
int
=
kwargs
.
get
(
"image_resolution"
)
if
(
image
.
width
*
image
.
height
)
>
image_resolution
:
resize_factor
=
math
.
sqrt
(
image_resolution
/
(
image
.
width
*
image
.
height
))
if
(
image
.
width
*
image
.
height
)
>
image_max_pixels
:
resize_factor
=
math
.
sqrt
(
image_max_pixels
/
(
image
.
width
*
image
.
height
))
width
,
height
=
int
(
image
.
width
*
resize_factor
),
int
(
image
.
height
*
resize_factor
)
image
=
image
.
resize
((
width
,
height
),
resample
=
Image
.
NEAREST
)
image
=
image
.
resize
((
width
,
height
))
if
(
image
.
width
*
image
.
height
)
<
image_min_pixels
:
resize_factor
=
math
.
sqrt
(
image_min_pixels
/
(
image
.
width
*
image
.
height
))
width
,
height
=
int
(
image
.
width
*
resize_factor
),
int
(
image
.
height
*
resize_factor
)
image
=
image
.
resize
((
width
,
height
))
if
image
.
mode
!=
"RGB"
:
image
=
image
.
convert
(
"RGB"
)
return
image
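The `image_max_pixels` / `image_min_pixels` pair above replaces the single `image_resolution` knob: the image is scaled by `sqrt(budget / area)`, which lands the area on the budget while preserving aspect ratio. A quick numeric check (values illustrative):

    import math

    width, height = 1024, 1024
    image_max_pixels = 768 * 768  # 589824
    if width * height > image_max_pixels:
        resize_factor = math.sqrt(image_max_pixels / (width * height))  # sqrt(0.5625) = 0.75
        width, height = int(width * resize_factor), int(height * resize_factor)

    print(width, height)  # 768 768 -- exactly on the pixel budget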
-    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
+    def _get_video_sample_indices(
+        self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
+    ) -> List[int]:
         r"""
-        Computes video sample frames according to fps.
+        Computes video sample indices according to fps.
         """
-        video_fps: float = kwargs.get("video_fps")
-        video_maxlen: int = kwargs.get("video_maxlen")
         total_frames = video_stream.frames
-        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
+        if total_frames == 0:  # infinite video
+            return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
+
+        sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)
         sample_frames = min(total_frames, video_maxlen, sample_frames)
-        return math.floor(sample_frames)
+        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
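The new `_get_video_sample_indices` returns evenly spaced frame indices rather than a frame count. A worked example under assumed stream properties (a 10-second clip decoded at 30 fps):

    import math
    import numpy as np

    total_frames, duration_s = 300, 10.0  # assumed stream properties
    video_fps, video_maxlen = 2.0, 128    # the defaults used in this file
    sample_frames = min(total_frames, video_maxlen, math.floor(duration_s * video_fps))  # 20
    indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
    print(indices[:5])  # [ 0 15 31 47 62] -- evenly spaced across the clip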
     def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
         r"""
...

@@ -126,7 +168,7 @@ class BasePlugin:
                 image = Image.open(image["path"])

             if not isinstance(image, ImageObject):
-                raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
+                raise ValueError(f"Expect input is a list of images, but got {type(image)}.")

             results.append(self._preprocess_image(image, **kwargs))
...

@@ -140,9 +182,7 @@ class BasePlugin:
         for video in videos:
             container = av.open(video, "r")
             video_stream = next(stream for stream in container.streams if stream.type == "video")
-            total_frames = video_stream.frames
-            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
-            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: List["ImageObject"] = []
             container.seek(0)
             for frame_idx, frame in enumerate(container.decode(video_stream)):
...

@@ -154,10 +194,27 @@ class BasePlugin:
         return results

+    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
+        r"""
+        Regularizes audios to avoid error. Including reading and resampling.
+        """
+        results = []
+        for audio in audios:
+            if isinstance(audio, str):
+                audio = librosa.load(audio, sr=sampling_rate)[0]
+
+            if not isinstance(audio, np.ndarray):
+                raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
+
+            results.append(audio)
+
+        return results
+
     def _get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: "ProcessorMixin",
     ) -> Dict[str, "torch.Tensor"]:
         r"""
...

@@ -172,47 +229,65 @@ class BasePlugin:
            It holds num_patches == torch.prod(image_grid_thw)
         """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
         video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
-        input_dict = {"images": None}  # default key
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
+        mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
                 images,
-                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
             )
-            input_dict["images"] = images
+            mm_inputs.update(image_processor(images, return_tensors="pt"))

         if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
-                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                 video_fps=getattr(processor, "video_fps", 2.0),
-                video_maxlen=getattr(processor, "video_maxlen", 64),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
             )
-            input_dict["videos"] = videos
-
-        mm_inputs = {}
-        if image_processor != video_processor:
-            if input_dict.get("images") is not None:
-                mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
-            if input_dict.get("videos") is not None:
-                mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
-        elif input_dict.get("images") is not None or input_dict.get("videos") is not None:  # same processor (qwen2-vl)
-            mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
+            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
+                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
+            else:  # for llava_next_video
+                mm_inputs.update(video_processor(videos, return_tensors="pt"))
+
+        if len(audios) != 0:
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+            )
+            mm_inputs.update(
+                feature_extractor(
+                    audios,
+                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+                    return_attention_mask=True,
+                    padding="max_length",
+                    return_tensors="pt",
+                )
+            )
+            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

         return mm_inputs
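The video branch above dispatches on the processor's call signature: processors whose `preprocess` accepts a `videos` keyword (qwen2_vl, video_llava) are called with it, others (llava_next_video) positionally. The check is plain `inspect` introspection; a self-contained illustration with stand-in functions:

    import inspect

    def preprocess_a(images=None, videos=None):  # accepts a `videos` keyword
        return {"pixel_values_videos": videos}

    def preprocess_b(images):                    # no `videos` parameter
        return {"pixel_values": images}

    for fn in (preprocess_a, preprocess_b):
        has_videos = "videos" in inspect.signature(fn).parameters
        print(fn.__name__, has_videos)  # preprocess_a True, preprocess_b False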
+@dataclass
+class BasePlugin(MMPluginMixin):
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         r"""
        Pre-processes input messages before tokenization for VLMs.
         """
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         return messages

     def process_token_ids(
...

@@ -221,21 +296,24 @@ class BasePlugin:
         labels: Optional[List[int]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
         r"""
        Pre-processes token ids after tokenization for VLMs.
         """
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         return input_ids, labels

     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
...

@@ -247,13 +325,15 @@ class BasePlugin:
            videos: a list of video inputs, shape (num_videos,)
            imglens: number of images in each sample, shape (batch_size,)
            vidlens: number of videos in each sample, shape (batch_size,)
+            audlens: number of audios in each sample, shape (batch_size,)
            batch_ids: token ids of input samples, shape (batch_size, seq_len)
            processor: a processor for pre-processing images and videos
         """
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         return {}


+@dataclass
 class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
...

@@ -261,9 +341,10 @@ class LlavaPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
...

@@ -285,15 +366,18 @@ class LlavaPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
 class LlavaNextPlugin(BasePlugin):
     @override
     def process_messages(
...

@@ -301,16 +385,15 @@ class LlavaNextPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        if "image_sizes" in mm_inputs:
-            image_sizes = iter(mm_inputs["image_sizes"])
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        if "pixel_values" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"].tolist())
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

         for message in messages:
...

@@ -319,7 +402,7 @@ class LlavaNextPlugin(BasePlugin):
             if self.expand_mm_tokens:
                 orig_height, orig_width = next(image_sizes)
                 image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
+                if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                     image_seqlen -= 1
             else:
                 image_seqlen = 1
...

@@ -339,15 +422,18 @@ class LlavaNextPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
 class LlavaNextVideoPlugin(BasePlugin):
     @override
     def process_messages(
...

@@ -355,14 +441,15 @@ class LlavaNextVideoPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         if "pixel_values" in mm_inputs:
-            image_sizes = iter(mm_inputs["image_sizes"])
+            image_sizes = iter(mm_inputs["image_sizes"].tolist())
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
             for message in messages:
                 content = message["content"]
...

@@ -370,7 +457,7 @@ class LlavaNextVideoPlugin(BasePlugin):
             if self.expand_mm_tokens:
                 orig_height, orig_width = next(image_sizes)
                 image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
+                if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                     image_seqlen -= 1
             else:
                 image_seqlen = 1
...

@@ -381,12 +468,15 @@ class LlavaNextVideoPlugin(BasePlugin):
                 message["content"] = content.replace("{{image}}", self.image_token)

         if "pixel_values_videos" in mm_inputs:
-            pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-            height, width = get_image_size(pixel_values_video[0])
-            num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
-            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
-            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
-            video_seqlen = video_seqlen if self.expand_mm_tokens else 1
+            if self.expand_mm_tokens:
+                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(pixel_values_video[0])
+                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+            else:
+                video_seqlen = 1
+
             for message in messages:
                 content = message["content"]
                 while VIDEO_PLACEHOLDER in content:
...

@@ -408,15 +498,18 @@ class LlavaNextVideoPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        return self._get_mm_inputs(images, videos, audios, processor)
+@dataclass
 class MiniCPMVPlugin(BasePlugin):
     @override
     def process_messages(
...

@@ -424,26 +517,27 @@ class MiniCPMVPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        self._validate_input(processor, images, videos, audios)
+        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         mm_inputs = {}
+        audio_inputs = {}
         if len(images) != 0 and len(videos) != 0:
             raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

         if len(videos) != 0:
             max_slice_nums = 2
             use_image_id = False
-            mm_inputs = self._get_mm_inputs([], videos, processor)
+            mm_inputs = self._get_mm_inputs([], videos, [], processor)
         else:
             max_slice_nums = image_processor.max_slice_nums
             use_image_id = image_processor.use_image_id

-        for message in messages:
+        for i, message in enumerate(messages):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
...

@@ -454,15 +548,24 @@ class MiniCPMVPlugin(BasePlugin):
                 content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                 num_video_tokens += 1

-            message["content"] = content.replace("{{image}}", "(<image>./</image>)")
+            while AUDIO_PLACEHOLDER in content:
+                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
+                num_audio_tokens += 1
+
+            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
+                "{{audio}}", "(<audio>./</audio>)"
+            )

         if num_image_tokens > 0:
-            mm_inputs = self._get_mm_inputs(images, [], processor)
+            mm_inputs = self._get_mm_inputs(images, [], [], processor)
+
+        if num_audio_tokens > 0:
+            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)

         if mm_inputs:
             pattern = "(<image>./</image>)"
             image_sizes = mm_inputs["image_sizes"]
+            idx = 0
             for index, message in enumerate(messages):
                 text = message["content"]
                 image_tags = re.findall(pattern, text)
...

@@ -473,9 +576,26 @@ class MiniCPMVPlugin(BasePlugin):
                     final_text
                     + text_chunks[i]
                     + image_processor.get_slice_image_placeholder(
-                        image_sizes[0][i], i, max_slice_nums, use_image_id
+                        image_sizes[0][idx], idx, max_slice_nums, use_image_id
                     )
                 )
+                idx += 1

             final_text += text_chunks[-1]
             messages[index]["content"] = final_text

+        if audio_inputs:
+            pattern = "(<audio>./</audio>)"
+            idx = 0
+            for index, message in enumerate(messages):
+                text = message["content"]
+                audio_tags = re.findall(pattern, text)
+                text_chunks = text.split(pattern)
+                final_text = ""
+                for i in range(len(audio_tags)):
+                    audio_placeholder = audio_inputs["audio_phs"][0][idx]
+                    final_text = final_text + text_chunks[i] + audio_placeholder
+                    idx += 1
+
+                final_text += text_chunks[-1]
+                messages[index]["content"] = final_text
...

@@ -486,6 +606,9 @@ class MiniCPMVPlugin(BasePlugin):
         if len(videos) != num_video_tokens:
             raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

+        if len(audios) != num_audio_tokens:
+            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
+
         return messages

     @override
...

@@ -493,15 +616,18 @@ class MiniCPMVPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: "ProcessorMixin",
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
                 images,
-                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
             )
             if "valid_image_nums_ls" in kwargs:
                 valid_image_nums_ls = kwargs["valid_image_nums_ls"]
...

@@ -521,13 +647,39 @@ class MiniCPMVPlugin(BasePlugin):
         if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
-                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                 video_fps=getattr(processor, "video_fps", 2.0),
-                video_maxlen=getattr(processor, "video_maxlen", 64),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
             mm_inputs.update(video_inputs)

+        if len(audios) != 0:
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+            )
+            if "valid_audio_nums_ls" in kwargs:
+                valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
+                audios_ls = []
+                idx = 0
+                for valid_audio_nums in valid_audio_nums_ls:
+                    audios_ls.append(audios[idx : idx + valid_audio_nums])
+                    idx += valid_audio_nums
+            else:
+                audios_ls = [audios]
+
+            audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
+                audios_ls,
+                chunk_input=True,
+                sampling_rate=16000,
+            )
+            audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
+            mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
+            if kwargs.get("ret_phs", False):
+                mm_inputs.update({"audio_phs": audio_phs})
+
         return mm_inputs

     @override
...

@@ -535,15 +687,18 @@ class MiniCPMVPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         # image bound
         image_bounds_list = []
         valid_image_nums_ls = []
-        for input_ids in batch_ids:
+        for i, input_ids in enumerate(batch_ids):
             input_ids_ = torch.tensor(input_ids)
             start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                 input_ids_ == processor.tokenizer.slice_start_id
...

@@ -552,21 +707,51 @@ class MiniCPMVPlugin(BasePlugin):
             image_start_tokens = torch.where(start_cond)[0]
             image_start_tokens += 1
             image_end_tokens = torch.where(end_cond)[0]
-            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
-            valid_image_nums_ls.append(valid_image_nums)
+            valid_image_nums_ls.append(imglens[i])
             image_bounds = torch.hstack(
                 [
-                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
-                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_start_tokens.unsqueeze(-1),
+                    image_end_tokens.unsqueeze(-1),
                 ]
             )
             image_bounds_list.append(image_bounds)

-        mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
+        mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
+        if "tgt_sizes" not in mm_inputs:
+            dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
+            mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})
+
         mm_inputs.update({"image_bound": image_bounds_list})

+        if len(audios) > 0:
+            # audio bound
+            audio_bounds_ls = []
+            spk_bounds_ls = []
+            valid_audio_nums_ls = []
+            for input_ids, audiolen in zip(batch_ids, audlens):
+                input_ids_ = torch.tensor(input_ids)
+                audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
+                audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
+                assert len(audio_start_idx) == len(audio_end_idx)
+                audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
+                audio_bounds_ls.append(audio_bounds)
+                valid_audio_nums_ls.append(audiolen)
+
+                spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
+                spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
+                assert len(spk_start_idx) == len(spk_end_idx)
+                spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
+                spk_bounds_ls.append(spk_bounds)
+
+            audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
+            mm_inputs.update(audio_inputs)
+            mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
+
         return mm_inputs
+@dataclass
 class MllamaPlugin(BasePlugin):
     @override
     def process_messages(
...

@@ -574,9 +759,10 @@ class MllamaPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
...

@@ -594,8 +780,9 @@ class MllamaPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: "ProcessorMixin",
-        **kwargs,
+        imglens: List[int],
     ) -> Dict[str, "torch.Tensor"]:
         r"""
        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
...

@@ -609,43 +796,56 @@ class MllamaPlugin(BasePlugin):
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
         """
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        imglens: List[int] = kwargs["imglens"]
-        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
-        batch_images = []
-        for image_length in imglens:
-            batch_images.append(images[:image_length])
-            images = images[image_length:]
+        mm_inputs = {}
+        if len(images) > 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )
+            batch_images = []
+            for image_length in imglens:
+                batch_images.append(images[:image_length])
+                images = images[image_length:]
+
+            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))

-        return image_processor(batch_images, return_tensors="pt")
+        return mm_inputs

     @override
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
-        num_tiles = mm_inputs.pop("num_tiles")
-        image_token_id = getattr(processor, "image_token_id")
-        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
-        cross_attention_token_mask = [
-            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
-        ]
-        mm_inputs["cross_attention_mask"] = torch.from_numpy(
-            convert_sparse_cross_attention_mask_to_dense(
-                cross_attention_token_mask,
-                num_tiles=num_tiles,
-                max_num_tiles=max_image_tiles,
-                length=max(len(input_ids) for input_ids in batch_ids),
-            )
-        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
+        if mm_inputs:
+            num_tiles = mm_inputs.pop("num_tiles")
+            image_token_id = getattr(processor, "image_token_id")
+            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+            cross_attention_token_mask = [
+                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+            ]
+            mm_inputs["cross_attention_mask"] = torch.from_numpy(
+                convert_sparse_cross_attention_mask_to_dense(
+                    cross_attention_token_mask,
+                    num_tiles=num_tiles,
+                    max_num_tiles=max_image_tiles,
+                    length=max(len(input_ids) for input_ids in batch_ids),
+                )
+            )  # shape: (batch_size, length, max_num_images, max_num_tiles)

         return mm_inputs
+@dataclass
 class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
...

@@ -653,9 +853,10 @@ class PaliGemmaPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
...

@@ -678,10 +879,11 @@ class PaliGemmaPlugin(BasePlugin):
         labels: Optional[List[int]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         num_images = len(images)
         image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
...

@@ -696,18 +898,21 @@ class PaliGemmaPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         seqlens = [len(input_ids) for input_ids in batch_ids]
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs


+@dataclass
 class PixtralPlugin(BasePlugin):
     @override
     def process_messages(
...

@@ -715,9 +920,10 @@ class PixtralPlugin(BasePlugin):
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
         patch_size = getattr(processor, "patch_size")
         image_token = getattr(processor, "image_token")
         image_break_token = getattr(processor, "image_break_token")
...

@@ -725,17 +931,15 @@ class PixtralPlugin(BasePlugin):
         num_image_tokens = 0
         messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        image_input_sizes = mm_inputs.get("image_sizes", None)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        if "pixel_values" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"].tolist())
+
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if image_input_sizes is None:
-                    raise ValueError("Cannot get image input sizes.")
-
                 if self.expand_mm_tokens:
-                    image_size = image_input_sizes[0][num_image_tokens]
-                    height, width = image_size
+                    height, width = next(image_sizes)
                     num_height_tokens = height // patch_size
                     num_width_tokens = width // patch_size
                     replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
...
@@ -760,47 +964,105 @@ class PixtralPlugin(BasePlugin):
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
+        audlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        if mm_inputs.get("pixel_values"):
-            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
-
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs.pop("image_sizes", None)
         return mm_inputs


-class Qwen2vlPlugin(BasePlugin):
+@dataclass
+class Qwen2AudioPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        bos_token: str = getattr(processor, "audio_bos_token")
+        eos_token: str = getattr(processor, "audio_eos_token")
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs([], [], audios, processor)
+        if "feature_attention_mask" in mm_inputs:
+            audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
+
+        num_audio_tokens = 0
+        for message in messages:
+            content = message["content"]
+            while AUDIO_PLACEHOLDER in content:
+                if self.expand_mm_tokens:
+                    audio_length = audio_lengths.pop(0)
+                    input_length = (audio_length - 1) // 2 + 1
+                    audio_seqlen = (input_length - 2) // 2 + 1
+                else:
+                    audio_seqlen = 1
+
+                content = content.replace(
+                    AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
+                )
+                num_audio_tokens += 1
+
+            message["content"] = content
+
+        if len(audios) != num_audio_tokens:
+            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
+
+        return messages
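The two integer divisions above shrink the raw feature length in two stride-2 stages before counting placeholder tokens (reading them as the audio encoder's downsampling is an interpretation, not stated in the diff). A worked check with an illustrative 3000-frame feature mask:

    audio_length = 3000                         # sum of feature_attention_mask, illustrative
    input_length = (audio_length - 1) // 2 + 1  # 1500
    audio_seqlen = (input_length - 2) // 2 + 1  # 750 placeholder tokens
    print(audio_seqlen)  # 750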
@
override
def
get_mm_inputs
(
self
,
images
:
Sequence
[
"ImageInput"
],
videos
:
Sequence
[
"VideoInput"
],
audios
:
Sequence
[
"AudioInput"
],
imglens
:
Sequence
[
int
],
vidlens
:
Sequence
[
int
],
audlens
:
Sequence
[
int
],
batch_ids
:
Sequence
[
List
[
int
]],
processor
:
Optional
[
"ProcessorMixin"
],
)
->
Dict
[
str
,
Union
[
List
[
int
],
"torch.Tensor"
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
return
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
@
dataclass
class
Qwen2VLPlugin
(
BasePlugin
):
@
override
def
_preprocess_image
(
self
,
image
:
"ImageObject"
,
**
kwargs
)
->
"ImageObject"
:
image
=
super
().
_preprocess_image
(
image
,
**
kwargs
)
if
min
(
image
.
width
,
image
.
height
)
<
28
:
width
,
height
=
max
(
image
.
width
,
28
),
max
(
image
.
height
,
28
)
image
=
image
.
resize
((
width
,
height
)
,
resample
=
Image
.
NEAREST
)
image
=
image
.
resize
((
width
,
height
))
if
image
.
width
/
image
.
height
>
200
:
width
,
height
=
image
.
height
*
180
,
image
.
height
image
=
image
.
resize
((
width
,
height
)
,
resample
=
Image
.
NEAREST
)
image
=
image
.
resize
((
width
,
height
))
if
image
.
height
/
image
.
width
>
200
:
width
,
height
=
image
.
width
,
image
.
width
*
180
image
=
image
.
resize
((
width
,
height
)
,
resample
=
Image
.
NEAREST
)
image
=
image
.
resize
((
width
,
height
))
return
image
@
override
def
_regularize_videos
(
self
,
videos
:
Sequence
[
"VideoInput"
],
**
kwargs
)
->
List
[
List
[
"ImageObject"
]]:
results
=
[]
def
_regularize_videos
(
self
,
videos
:
Sequence
[
"VideoInput"
],
**
kwargs
)
->
Tuple
[
List
[
List
[
"ImageObject"
]],
List
[
float
]]:
results
,
fps_per_video
=
[],
[]
for
video
in
videos
:
container
=
av
.
open
(
video
,
"r"
)
video_stream
=
next
(
stream
for
stream
in
container
.
streams
if
stream
.
type
==
"video"
)
total_frames
=
video_stream
.
frames
sample_frames
=
self
.
_get_video_sample_frames
(
video_stream
,
**
kwargs
)
sample_indices
=
np
.
linspace
(
0
,
total_frames
-
1
,
sample_frames
).
astype
(
np
.
int32
)
sample_indices
=
self
.
_get_video_sample_indices
(
video_stream
,
**
kwargs
)
frames
:
List
[
"ImageObject"
]
=
[]
container
.
seek
(
0
)
for
frame_idx
,
frame
in
enumerate
(
container
.
decode
(
video_stream
)):
...
...
@@ -812,8 +1074,43 @@ class Qwen2vlPlugin(BasePlugin):
frames
=
self
.
_regularize_images
(
frames
,
**
kwargs
)
results
.
append
(
frames
)
if
video_stream
.
duration
is
None
:
fps_per_video
.
append
(
2.0
)
else
:
fps_per_video
.
append
(
len
(
sample_indices
)
/
float
(
video_stream
.
duration
*
video_stream
.
time_base
))
return
results
return
results
,
fps_per_video
@
override
def
_get_mm_inputs
(
self
,
images
:
Sequence
[
"ImageInput"
],
videos
:
Sequence
[
"VideoInput"
],
audios
:
Sequence
[
"AudioInput"
],
processor
:
"ProcessorMixin"
,
)
->
Dict
[
str
,
"torch.Tensor"
]:
image_processor
:
"BaseImageProcessor"
=
getattr
(
processor
,
"image_processor"
,
None
)
mm_inputs
=
{}
if
len
(
images
)
!=
0
:
images
=
self
.
_regularize_images
(
images
,
image_max_pixels
=
getattr
(
processor
,
"image_max_pixels"
,
768
*
768
),
image_min_pixels
=
getattr
(
processor
,
"image_min_pixels"
,
32
*
32
),
)
mm_inputs
.
update
(
image_processor
(
images
,
return_tensors
=
"pt"
))
if
len
(
videos
)
!=
0
:
videos
,
fps_per_video
=
self
.
_regularize_videos
(
videos
,
image_max_pixels
=
getattr
(
processor
,
"video_max_pixels"
,
256
*
256
),
image_min_pixels
=
getattr
(
processor
,
"video_min_pixels"
,
16
*
16
),
video_fps
=
getattr
(
processor
,
"video_fps"
,
2.0
),
video_maxlen
=
getattr
(
processor
,
"video_maxlen"
,
128
),
)
mm_inputs
.
update
(
image_processor
(
images
=
None
,
videos
=
videos
,
return_tensors
=
"pt"
))
mm_inputs
[
"fps_per_video"
]
=
fps_per_video
return
mm_inputs
@
override
def
process_messages
(
...
...
@@ -821,17 +1118,23 @@ class Qwen2vlPlugin(BasePlugin):
messages
:
Sequence
[
Dict
[
str
,
str
]],
images
:
Sequence
[
"ImageInput"
],
videos
:
Sequence
[
"VideoInput"
],
audios
:
Sequence
[
"AudioInput"
],
processor
:
Optional
[
"ProcessorMixin"
],
)
->
List
[
Dict
[
str
,
str
]]:
self
.
_validate_input
(
images
,
videos
)
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
num_image_tokens
,
num_video_tokens
=
0
,
0
messages
=
deepcopy
(
messages
)
image_processor
:
"BaseImageProcessor"
=
getattr
(
processor
,
"image_processor"
)
merge_length
:
int
=
getattr
(
image_processor
,
"merge_size"
)
**
2
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
processor
)
image_grid_thw
=
mm_inputs
.
get
(
"image_grid_thw"
,
[])
video_grid_thw
=
mm_inputs
.
get
(
"video_grid_thw"
,
[])
if
self
.
expand_mm_tokens
:
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
image_grid_thw
=
mm_inputs
.
get
(
"image_grid_thw"
,
[])
video_grid_thw
=
mm_inputs
.
get
(
"video_grid_thw"
,
[])
else
:
image_grid_thw
=
[
None
]
*
len
(
images
)
video_grid_thw
=
[
None
]
*
len
(
videos
)
num_image_tokens
,
num_video_tokens
=
0
,
0
messages
=
deepcopy
(
messages
)
for
message
in
messages
:
content
=
message
[
"content"
]
while
IMAGE_PLACEHOLDER
in
content
:
...
...
@@ -869,15 +1172,24 @@ class Qwen2vlPlugin(BasePlugin):
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        fps_per_video = mm_inputs.pop("fps_per_video", [])
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
+            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
+
+        return mm_inputs
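`second_per_grid_ts` is simply the wall-clock time covered by one temporal grid step. A quick numeric sketch of the two formulas above (the sample values are illustrative only):

temporal_patch_size = 2                       # frames merged per temporal grid step
sampled_frames, duration_s = 16, 8.0
fps = sampled_frames / duration_s             # one fps_per_video entry: 2.0
second_per_grid_ts = temporal_patch_size / fps
print(second_per_grid_ts)                     # 1.0 second of video per temporal grid step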
@dataclass
class VideoLlavaPlugin(BasePlugin):
    @override
    def process_messages(
...
...
@@ -885,12 +1197,13 @@ class VideoLlavaPlugin(BasePlugin):
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        num_frames = 0
        has_images = "pixel_values_images" in mm_inputs
        has_videos = "pixel_values_videos" in mm_inputs
...
...
@@ -907,7 +1220,7 @@ class VideoLlavaPlugin(BasePlugin):
            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
            video_seqlen = image_seqlen * num_frames
-            if getattr(processor, "vision_feature_select_strategy") == "default":
+            if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                image_seqlen -= 1
        else:
            image_seqlen, video_seqlen = 1, 1
...
...
@@ -938,13 +1251,15 @@ class VideoLlavaPlugin(BasePlugin):
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        self._validate_input(processor, images, videos, audios)
+        return self._get_mm_inputs(images, videos, audios, processor)


PLUGINS = {
...
...
@@ -956,18 +1271,32 @@ PLUGINS = {
"mllama"
:
MllamaPlugin
,
"paligemma"
:
PaliGemmaPlugin
,
"pixtral"
:
PixtralPlugin
,
"qwen2_vl"
:
Qwen2vlPlugin
,
"qwen2_audio"
:
Qwen2AudioPlugin
,
"qwen2_vl"
:
Qwen2VLPlugin
,
"video_llava"
:
VideoLlavaPlugin
,
}
def
register_mm_plugin
(
name
:
str
,
plugin_class
:
Type
[
"BasePlugin"
])
->
None
:
r
"""
Registers a multimodal plugin.
"""
if
name
in
PLUGINS
:
raise
ValueError
(
f
"Multimodal plugin
{
name
}
already exists."
)
PLUGINS
[
name
]
=
plugin_class
def
get_mm_plugin
(
name
:
str
,
image_token
:
Optional
[
str
]
=
None
,
video_token
:
Optional
[
str
]
=
None
,
audio_token
:
Optional
[
str
]
=
None
,
)
->
"BasePlugin"
:
plugin_class
=
PLUGINS
.
get
(
name
,
None
)
if
plugin_class
is
None
:
r
"""
Gets plugin for multimodal inputs.
"""
if
name
not
in
PLUGINS
:
raise
ValueError
(
f
"Multimodal plugin `
{
name
}
` not found."
)
return
plugin_class
(
image_token
,
video_token
)
return
PLUGINS
[
name
]
(
image_token
,
video_token
,
audio_token
)
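Typical use of the registry, as a minimal sketch. The `qwen2_vl` key is registered above; the placeholder token strings are illustrative assumptions and would normally come from the chat template definition:

from llamafactory.data.mm_plugin import get_mm_plugin, register_mm_plugin

# fetch a registered plugin with the template's special tokens (token strings assumed)
plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")

# registering an existing name raises ValueError, so custom plugins need fresh names:
# register_mm_plugin("my_plugin", MyPlugin)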
src/llamafactory/data/parser.py View file @ 317a82e2

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -44,7 +44,8 @@ class DatasetAttr:
    tools: Optional[str] = None
    images: Optional[str] = None
    videos: Optional[str] = None
-    # rlhf columns
+    audios: Optional[str] = None
+    # dpo columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
...
...
@@ -70,6 +71,26 @@ class DatasetAttr:
    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        setattr(self, key, obj.get(key, default))

    def join(self, attr: Dict[str, Any]) -> None:
        self.set_attr("formatting", attr, default="alpaca")
        self.set_attr("ranking", attr, default=False)
        self.set_attr("subset", attr)
        self.set_attr("split", attr, default="train")
        self.set_attr("folder", attr)
        self.set_attr("num_samples", attr)
        if "columns" in attr:
            column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
            column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
            for column_name in column_names:
                self.set_attr(column_name, attr["columns"])

        if "tags" in attr:
            tag_names = ["role_tag", "content_tag"]
            tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
            for tag in tag_names:
                self.set_attr(tag, attr["tags"])


def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
    r"""
...
...
@@ -127,36 +148,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
        else:
            dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

-        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
-        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
-        dataset_attr.set_attr("subset", dataset_info[name])
-        dataset_attr.set_attr("split", dataset_info[name], default="train")
-        dataset_attr.set_attr("folder", dataset_info[name])
-        dataset_attr.set_attr("num_samples", dataset_info[name])
-        if "columns" in dataset_info[name]:
-            column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
-            if dataset_attr.formatting == "alpaca":
-                column_names.extend(["prompt", "query", "response", "history"])
-            else:
-                column_names.extend(["messages"])
-            for column_name in column_names:
-                dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
-        if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
-            tag_names = (
-                "role_tag",
-                "content_tag",
-                "user_tag",
-                "assistant_tag",
-                "observation_tag",
-                "function_tag",
-                "system_tag",
-            )
-            for tag in tag_names:
-                dataset_attr.set_attr(tag, dataset_info[name]["tags"])
+        dataset_attr.join(dataset_info[name])
        dataset_list.append(dataset_attr)

    return dataset_list
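As a rough illustration of the new `join` path, assuming a minimal `dataset_info.json` entry (the file and column names below are made up for the example):

info = {
    "file_name": "alpaca_en_demo.json",
    "formatting": "alpaca",
    "columns": {"prompt": "instruction", "query": "input", "response": "output"},
}
dataset_attr = DatasetAttr("file", dataset_name=info["file_name"])
dataset_attr.join(info)  # replaces the long inline set_attr sequence removed above
print(dataset_attr.formatting, dataset_attr.prompt)  # alpaca instruction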
src/llamafactory/data/preprocess.py deleted 100644 → 0 View file @ 37b0ad9f

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple

from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
    preprocess_packed_supervised_dataset,
    preprocess_supervised_dataset,
    print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ..hparams import DataArguments
    from .template import Template


def get_preprocess_and_print_func(
    data_args: "DataArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    do_generate: bool = False,
) -> Tuple[Callable, Callable]:
    if stage == "pt":
        preprocess_func = partial(
            preprocess_pretrain_dataset,
            tokenizer=tokenizer,
            data_args=data_args,
        )
        print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not do_generate:
        if data_args.packing:
            if data_args.neat_packing:  # hack datasets to have int32 attention mask
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.pop("type", None),
                        try_type=kwargs.pop("try_type", None),
                        optimized_int_type=kwargs.pop("optimized_int_type", None),
                    )

                OptimizedTypedSequence.__init__ = __init__

            preprocess_func = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )
        else:
            preprocess_func = partial(
                preprocess_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )

        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    elif stage == "rm":
        preprocess_func = partial(
            preprocess_pairwise_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    elif stage == "kto":
        preprocess_func = partial(
            preprocess_feedback_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    else:
        preprocess_func = partial(
            preprocess_unsupervised_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

    return preprocess_func, print_function
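For context, the removed helper was typically consumed via `datasets.Dataset.map`; a hedged usage sketch (the `dataset` and `column_names` variables are assumptions, not taken from this diff):

preprocess_func, print_function = get_preprocess_and_print_func(
    data_args, stage="sft", template=template, tokenizer=tokenizer, processor=None
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names)
print_function(next(iter(dataset)))  # eyeball one tokenized example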
src/llamafactory/data/processor/__init__.py 0 → 100644 View file @ 317a82e2

from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
from .processor_utils import DatasetProcessor
from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
from .unsupervised import UnsupervisedDatasetProcessor


__all__ = [
    "DatasetProcessor",
    "FeedbackDatasetProcessor",
    "PairwiseDatasetProcessor",
    "PretrainDatasetProcessor",
    "PackedSupervisedDatasetProcessor",
    "SupervisedDatasetProcessor",
    "UnsupervisedDatasetProcessor",
]
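These classes replace the stage-keyed `partial` functions deleted from preprocess.py above. A hypothetical sketch of how a caller (e.g. the updated loader.py, not shown in full here) might select one, assuming the same stage semantics as before:

from llamafactory.data.processor import (
    FeedbackDatasetProcessor,
    PackedSupervisedDatasetProcessor,
    PairwiseDatasetProcessor,
    PretrainDatasetProcessor,
    SupervisedDatasetProcessor,
    UnsupervisedDatasetProcessor,
)

def select_processor_sketch(stage, data_args, template, tokenizer, processor):
    # hypothetical dispatch mirroring the old get_preprocess_and_print_func branches
    if stage == "pt":
        cls = PretrainDatasetProcessor
    elif stage == "rm":
        cls = PairwiseDatasetProcessor
    elif stage == "kto":
        cls = FeedbackDatasetProcessor
    elif stage == "sft":
        cls = PackedSupervisedDatasetProcessor if data_args.packing else SupervisedDatasetProcessor
    else:
        cls = UnsupervisedDatasetProcessor

    return cls(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)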
src/llamafactory/data/processor/feedback.py 0 → 100644 View file @ 317a82e2

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class FeedbackDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        kl_response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
        if response[0]["content"]:  # desired example
            kto_tag = True
            messages = prompt + [response[0]]
        else:  # undesired example
            kto_tag = False
            messages = prompt + [response[1]]

        if kl_response[0]["content"]:
            kl_messages = prompt + [kl_response[0]]
        else:
            kl_messages = prompt + [kl_response[1]]

        messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
        kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor)
        prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
        kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools)

        if self.template.efficient_eos:
            response_ids += [self.tokenizer.eos_token_id]
            kl_response_ids += [self.tokenizer.eos_token_id]

        prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )

        source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len)
        prompt_ids = prompt_ids[:source_len]
        response_ids = response_ids[:target_len]
        kl_source_len, kl_target_len = infer_seqlen(
            len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len
        )
        kl_prompt_ids = kl_prompt_ids[:kl_source_len]
        kl_response_ids = kl_response_ids[:kl_target_len]

        input_ids = prompt_ids + response_ids
        labels = [IGNORE_INDEX] * source_len + response_ids
        kl_input_ids = kl_prompt_ids + kl_response_ids
        kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
        return input_ids, labels, kl_input_ids, kl_labels, kto_tag

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
        kl_response = examples["_response"][::-1]
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                kl_response=kl_response[i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["kl_input_ids"].append(kl_input_ids)
            model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
            model_inputs["kl_labels"].append(kl_labels)
            model_inputs["kto_tags"].append(kto_tag)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
        undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
        if desirable_num == 0 or undesirable_num == 0:
            logger.warning_rank0("Your dataset only has one preference type.")

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
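The `[::-1]` flip in `preprocess_dataset` is the whole KL-pairing trick: each prompt is re-matched with the response from the mirror-image row of the batch, yielding (mostly) unrelated pairs for the KTO KL reference term. A tiny illustration:

responses = [["resp A"], ["resp B"], ["resp C"]]
kl_responses = responses[::-1]
# prompt 0 is now paired with "resp C", prompt 1 with "resp B", prompt 2 with "resp A"
print(list(zip(range(3), kl_responses)))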
src/llamafactory/data/processor/pairwise.py 0 → 100644 View file @ 317a82e2

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class PairwiseDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int], List[int], List[int]]:
        chosen_messages = self.template.mm_plugin.process_messages(
            prompt + [response[0]], images, videos, audios, self.processor
        )
        rejected_messages = self.template.mm_plugin.process_messages(
            prompt + [response[1]], images, videos, audios, self.processor
        )
        prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools)
        _, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools)

        if self.template.efficient_eos:
            chosen_ids += [self.tokenizer.eos_token_id]
            rejected_ids += [self.tokenizer.eos_token_id]

        prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        # consider the response is more important
        source_len, target_len = infer_seqlen(
            len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len
        )
        prompt_ids = prompt_ids[:source_len]
        chosen_ids = chosen_ids[:target_len]
        rejected_ids = rejected_ids[:target_len]

        chosen_input_ids = prompt_ids + chosen_ids
        chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
        rejected_input_ids = prompt_ids + rejected_ids
        rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
        return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["chosen_input_ids"].append(chosen_input_ids)
            model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
            model_inputs["chosen_labels"].append(chosen_labels)
            model_inputs["rejected_input_ids"].append(rejected_input_ids)
            model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
            model_inputs["rejected_labels"].append(rejected_labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
        valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
        print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
        print(
            "chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))
        )
        print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
        print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
        print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
        print(
            "rejected_inputs:\n{}".format(
                self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)
            )
        )
        print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
        print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
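Both responses are cut to one shared `target_len` so the pair stays length-comparable. The exact `infer_seqlen` lives in processor_utils.py and is not shown in this diff; a simplified sketch, assuming a proportional split of `cutoff_len` between source and target:

def infer_seqlen_sketch(source_len: int, target_len: int, cutoff_len: int) -> tuple:
    # assumption: proportional truncation, approximating infer_seqlen (not shown here)
    if source_len + target_len <= cutoff_len:
        return source_len, target_len
    new_target = max(1, int(cutoff_len * target_len / (source_len + target_len)))
    return cutoff_len - new_target, new_target

print(infer_seqlen_sketch(100, 30, 64))  # (50, 14): both sides shrink proportionally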
src/llamafactory/data/processor/pretrain.py 0 → 100644 View file @ 317a82e2

# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List

from .processor_utils import DatasetProcessor


@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
        eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
        text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

        if not self.data_args.packing:
            if getattr(self.tokenizer, "add_bos_token", False):
                text_examples = [self.tokenizer.bos_token + example for example in text_examples]

            result = self.tokenizer(
                text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
            )
        else:
            tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False)
            concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
            total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
            block_size = self.data_args.cutoff_len
            total_length = (total_length // block_size) * block_size
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            if getattr(self.tokenizer, "add_bos_token", False):
                for i in range(len(result["input_ids"])):
                    result["input_ids"][i][0] = self.tokenizer.bos_token_id

        return result

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
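The packing branch concatenates every tokenized text and slices it into fixed blocks, dropping the tail shorter than one block. A self-contained illustration of that grouping step:

from itertools import chain

block_size = 8
tokenized = {"input_ids": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12]]}
concatenated = list(chain(*tokenized["input_ids"]))        # 12 tokens end to end
total_length = (len(concatenated) // block_size) * block_size  # 8: the 4-token tail is dropped
blocks = [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]
print(blocks)  # [[1, 2, 3, 4, 5, 6, 7, 8]]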
src/llamafactory/data/processors/processor_utils.py → src/llamafactory/data/processor/processor_utils.py View file @ 317a82e2

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,7 +13,42 @@
# limitations under the License.

import bisect
-from typing import List, Sequence, Tuple
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..template import Template


@dataclass
class DatasetProcessor(ABC):
    r"""
    A class for data processors.
    """

    template: "Template"
    tokenizer: "PreTrainedTokenizer"
    processor: Optional["ProcessorMixin"]
    data_args: "DataArguments"

    @abstractmethod
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        r"""
        Builds model inputs from the examples.
        """
        ...

    @abstractmethod
    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        r"""
        Print a data example to stdout.
        """
        ...


def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
...
...
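Each concrete processor only has to supply the two abstract methods; the shared state arrives through the dataclass fields. A hypothetical minimal subclass, for illustration only:

from dataclasses import dataclass
from typing import Any, Dict, List

from llamafactory.data.processor.processor_utils import DatasetProcessor


@dataclass
class MinimalDatasetProcessor(DatasetProcessor):
    # illustrative subclass: template, tokenizer, processor and data_args
    # are inherited dataclass fields, so only the abstract methods remain
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        return {"input_ids": [self.tokenizer.encode(m[0]["content"]) for m in examples["_prompt"]]}

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print(self.tokenizer.decode(example["input_ids"]))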
src/llamafactory/data/processor/supervised.py 0 → 100644 View file @ 317a82e2

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int]]:
        messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
        input_ids, labels = self.template.mm_plugin.process_token_ids(
            [], [], images, videos, audios, self.tokenizer, self.processor
        )
        encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
        total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
        if self.data_args.mask_history:
            encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

        for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
            if total_length >= self.data_args.cutoff_len:
                break

            source_len, target_len = infer_seqlen(
                len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
            )
            source_ids = source_ids[:source_len]
            target_ids = target_ids[:target_len]
            total_length += source_len + target_len

            if self.data_args.train_on_prompt:
                source_label = source_ids
            elif self.template.efficient_eos:
                source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
            else:
                source_label = [IGNORE_INDEX] * source_len

            if self.data_args.mask_history and turn_idx != 0:  # train on the last turn only
                target_label = [IGNORE_INDEX] * target_len
            else:
                target_label = target_ids

            if self.data_args.mask_history:  # reversed sequences
                input_ids = source_ids + target_ids + input_ids
                labels = source_label + target_label + labels
            else:
                input_ids += source_ids + target_ids
                labels += source_label + target_label

        if self.template.efficient_eos:
            input_ids += [self.tokenizer.eos_token_id]
            labels += [self.tokenizer.eos_token_id]

        return input_ids, labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")


@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # TODO: use `position_ids` to achieve packing
        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
        valid_num = 0
        batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
        lengths = []
        length2indexes = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            length = len(input_ids)
            if length > self.data_args.cutoff_len:
                logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
            else:
                lengths.append(length)
                length2indexes[length].append(valid_num)
                batch_input_ids.append(input_ids)
                batch_labels.append(labels)
                batch_images.append(examples["_images"][i] or [])
                batch_videos.append(examples["_videos"][i] or [])
                batch_audios.append(examples["_audios"][i] or [])
                valid_num += 1

        model_inputs = defaultdict(list)
        knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
        for knapsack in knapsacks:
            packed_input_ids, packed_attention_masks, packed_labels = [], [], []
            packed_images, packed_videos, packed_audios = [], [], []
            for i, length in enumerate(knapsack):
                index = length2indexes[length].pop()
                packed_input_ids += batch_input_ids[index]
                packed_labels += batch_labels[index]
                packed_images += batch_images[index]
                packed_videos += batch_videos[index]
                packed_audios += batch_audios[index]
                if self.data_args.neat_packing:
                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
                else:
                    packed_attention_masks += [1] * len(batch_input_ids[index])

            if len(packed_input_ids) < self.data_args.cutoff_len + 1:  # avoid flash_attn drops attn mask
                pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
                packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
                packed_labels += [IGNORE_INDEX] * pad_length
                if self.data_args.neat_packing:
                    packed_attention_masks += [0] * pad_length
                else:
                    packed_attention_masks += [1] * pad_length  # more efficient flash_attn

            if len(packed_input_ids) != self.data_args.cutoff_len + 1:
                raise ValueError("The length of packed example should be identical to the cutoff length.")

            model_inputs["input_ids"].append(packed_input_ids)
            model_inputs["attention_mask"].append(packed_attention_masks)
            model_inputs["labels"].append(packed_labels)
            model_inputs["images"].append(packed_images or None)
            model_inputs["videos"].append(packed_videos or None)
            model_inputs["audios"].append(packed_audios or None)

        return model_inputs
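`greedy_knapsack` (processor_utils.py, not shown in this diff) groups example lengths into bins that each fit the cutoff; every bin then becomes one packed sequence, padded up to `cutoff_len + 1`. A simplified sketch, assuming a first-fit-decreasing strategy rather than the exact algorithm:

def greedy_knapsack_sketch(lengths, capacity):
    # assumption: first-fit-decreasing packing, approximating greedy_knapsack
    bins = []
    for length in sorted(lengths, reverse=True):
        for b in bins:
            if sum(b) + length <= capacity:
                b.append(length)
                break
        else:
            bins.append([length])
    return bins

print(greedy_knapsack_sketch([700, 300, 512, 480, 60], 1024))  # [[700, 300], [512, 480], [60]]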
src/llamafactory/data/processor/unsupervised.py 0 → 100644 View file @ 317a82e2

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ..data_utils import Role
from .processor_utils import DatasetProcessor, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class UnsupervisedDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: Sequence[Dict[str, str]],
        response: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> Tuple[List[int], List[int]]:
        if len(response) == 1:
            messages = prompt + response
        else:
            messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

        messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
        input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
        if self.template.efficient_eos:
            labels += [self.tokenizer.eos_token_id]

        input_ids, _ = self.template.mm_plugin.process_token_ids(
            input_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len)
        input_ids = input_ids[:source_len]
        labels = labels[:target_len]
        return input_ids, labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        # build inputs with format `<bos> X` and labels with format `Y <eos>`
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False)))
src/llamafactory/data/processors/__init__.py deleted 100644 → 0 View file @ 37b0ad9f

src/llamafactory/data/processors/feedback.py deleted 100644 → 0 View file @ 37b0ad9f
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_feedback_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    kl_response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
    if response[0]["content"]:  # desired example
        kto_tag = True
        messages = prompt + [response[0]]
    else:  # undesired example
        kto_tag = False
        messages = prompt + [response[1]]

    if kl_response[0]["content"]:
        kl_messages = prompt + [kl_response[0]]
    else:
        kl_messages = prompt + [kl_response[1]]

    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
    kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
    kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)

    if template.efficient_eos:
        response_ids += [tokenizer.eos_token_id]
        kl_response_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)

    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    response_ids = response_ids[:target_len]
    kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
    kl_prompt_ids = kl_prompt_ids[:kl_source_len]
    kl_response_ids = kl_response_ids[:kl_target_len]

    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * source_len + response_ids
    kl_input_ids = kl_prompt_ids + kl_response_ids
    kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
    return input_ids, labels, kl_input_ids, kl_labels, kto_tag


def preprocess_feedback_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
    kl_response = examples["_response"][::-1]
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            kl_response=kl_response[i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["kl_input_ids"].append(kl_input_ids)
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
        model_inputs["kl_labels"].append(kl_labels)
        model_inputs["kto_tags"].append(kto_tag)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])

    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
    if desirable_num == 0 or undesirable_num == 0:
        logger.warning_rank0("Your dataset only has one preference type.")

    return model_inputs
src/llamafactory/data/processors/mm_utils.py deleted 100644 → 0 View file @ 37b0ad9f

from typing import TYPE_CHECKING, List, Sequence

from ...extras.packages import is_pillow_available


if is_pillow_available():
    from PIL import Image


if TYPE_CHECKING:
    from numpy.typing import NDArray
    from PIL.Image import Image as ImageObject
    from transformers import ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor


def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
    # process visual inputs (currently only supports a single image)
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
    return image_processor(image, return_tensors="pt")["pixel_values"][0]  # shape (C, H, W)


def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
    # get paligemma token type ids for computing loss
    image_seq_length = getattr(processor, "image_seq_length")
    return [0] * image_seq_length + [1] * (input_len - image_seq_length)
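The token-type mask above is just a prefix of zeros over the image tokens followed by ones over the text tokens. A tiny illustration with made-up sizes:

image_seq_length, input_len = 4, 7  # illustrative values
token_type_ids = [0] * image_seq_length + [1] * (input_len - image_seq_length)
print(token_type_ids)  # [0, 0, 0, 0, 1, 1, 1]: zeros mark image tokens excluded from the loss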
src/llamafactory/data/processors/pairwise.py deleted 100644 → 0 View file @ 37b0ad9f

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)


def _encode_pairwise_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)

    if template.efficient_eos:
        chosen_ids += [tokenizer.eos_token_id]
        rejected_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
    # consider the response is more important
    source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    chosen_ids = chosen_ids[:target_len]
    rejected_ids = rejected_ids[:target_len]

    chosen_input_ids = prompt_ids + chosen_ids
    chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
    rejected_input_ids = prompt_ids + rejected_ids
    rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels


def preprocess_pairwise_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["chosen_input_ids"].append(chosen_input_ids)
        model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
        model_inputs["chosen_labels"].append(chosen_labels)
        model_inputs["rejected_input_ids"].append(rejected_input_ids)
        model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
        model_inputs["rejected_labels"].append(rejected_labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])

    return model_inputs


def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
    valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
    print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
    print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
    print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
    print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
    print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
    print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
    print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
    print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
src/llamafactory/data/processors/pretrain.py deleted 100644 → 0 View file @ 37b0ad9f

# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from ...hparams import DataArguments


def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

    if not data_args.packing:
        if data_args.template == "gemma":
            text_examples = [tokenizer.bos_token + example for example in text_examples]

        result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
    else:
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
        block_size = data_args.cutoff_len
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        if data_args.template == "gemma":
            for i in range(len(result["input_ids"])):
                result["input_ids"][i][0] = tokenizer.bos_token_id

    return result


def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
src/llamafactory/data/processors/supervised.py
deleted
100644 → 0
View file @
37b0ad9f
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
from
...extras
import
logging
from
...extras.constants
import
IGNORE_INDEX
from
.processor_utils
import
greedy_knapsack
,
infer_seqlen
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
,
ProcessorMixin
from
...hparams
import
DataArguments
from
..mm_plugin
import
ImageInput
,
VideoInput
from
..template
import
Template
logger
=
logging
.
get_logger
(
__name__
)
def
_encode_supervised_example
(
prompt
:
Sequence
[
Dict
[
str
,
str
]],
response
:
Sequence
[
Dict
[
str
,
str
]],
system
:
Optional
[
str
],
tools
:
Optional
[
str
],
images
:
Sequence
[
"ImageInput"
],
videos
:
Sequence
[
"VideoInput"
],
template
:
"Template"
,
tokenizer
:
"PreTrainedTokenizer"
,
processor
:
Optional
[
"ProcessorMixin"
],
cutoff_len
:
int
,
train_on_prompt
:
bool
,
mask_history
:
bool
,
)
->
Tuple
[
List
[
int
],
List
[
int
]]:
messages
=
template
.
mm_plugin
.
process_messages
(
prompt
+
response
,
images
,
videos
,
processor
)
input_ids
,
labels
=
template
.
mm_plugin
.
process_token_ids
([],
[],
images
,
videos
,
tokenizer
,
processor
)
encoded_pairs
=
template
.
encode_multiturn
(
tokenizer
,
messages
,
system
,
tools
)
total_length
=
len
(
input_ids
)
+
(
1
if
template
.
efficient_eos
else
0
)
if
mask_history
:
encoded_pairs
=
encoded_pairs
[::
-
1
]
# high priority for last turns
for
turn_idx
,
(
source_ids
,
target_ids
)
in
enumerate
(
encoded_pairs
):
if
total_length
>=
cutoff_len
:
break
source_len
,
target_len
=
infer_seqlen
(
len
(
source_ids
),
len
(
target_ids
),
cutoff_len
-
total_length
)
source_ids
=
source_ids
[:
source_len
]
target_ids
=
target_ids
[:
target_len
]
total_length
+=
source_len
+
target_len
if
train_on_prompt
:
source_label
=
source_ids
elif
template
.
efficient_eos
:
source_label
=
[
tokenizer
.
eos_token_id
]
+
[
IGNORE_INDEX
]
*
(
source_len
-
1
)
else
:
source_label
=
[
IGNORE_INDEX
]
*
source_len
if
mask_history
and
turn_idx
!=
0
:
# train on the last turn only
target_label
=
[
IGNORE_INDEX
]
*
target_len
else
:
target_label
=
target_ids
if
mask_history
:
# reversed sequences
input_ids
=
source_ids
+
target_ids
+
input_ids
labels
=
source_label
+
target_label
+
labels
else
:
input_ids
+=
source_ids
+
target_ids
labels
+=
source_label
+
target_label
if
template
.
efficient_eos
:
input_ids
+=
[
tokenizer
.
eos_token_id
]
labels
+=
[
tokenizer
.
eos_token_id
]
return
input_ids
,
labels
def preprocess_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_supervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])

    return model_inputs
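For context, this batched processor is normally driven through `datasets.Dataset.map`; the wiring below is a sketch of that call (the `partial` pattern follows how the loader typically invokes these functions, and `template`, `tokenizer`, and `data_args` are assumed to exist already).

from functools import partial

# Sketch: `dataset` is a datasets.Dataset carrying the _prompt/_response/... columns.
preprocess_func = partial(
    preprocess_supervised_dataset,
    template=template,
    tokenizer=tokenizer,
    processor=None,
    data_args=data_args,
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=dataset.column_names)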
def preprocess_packed_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # TODO: use `position_ids` to achieve packing
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
    valid_num = 0
    batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
    lengths = []
    length2indexes = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_supervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        length = len(input_ids)
        if length > data_args.cutoff_len:
            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
        else:
            lengths.append(length)
            length2indexes[length].append(valid_num)
            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            batch_images.append(examples["_images"][i] or [])
            batch_videos.append(examples["_videos"][i] or [])
            valid_num += 1

    model_inputs = defaultdict(list)
    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
    for knapsack in knapsacks:
        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
        packed_images, packed_videos = [], []
        for i, length in enumerate(knapsack):
            index = length2indexes[length].pop()
            packed_input_ids += batch_input_ids[index]
            packed_labels += batch_labels[index]
            packed_images += batch_images[index]
            packed_videos += batch_videos[index]
            if data_args.neat_packing:
                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
            else:
                packed_attention_masks += [1] * len(batch_input_ids[index])

        if len(packed_input_ids) < data_args.cutoff_len:
            pad_length = data_args.cutoff_len - len(packed_input_ids)
            packed_input_ids += [tokenizer.pad_token_id] * pad_length
            packed_labels += [IGNORE_INDEX] * pad_length
            if data_args.neat_packing:
                packed_attention_masks += [0] * pad_length
            else:
                packed_attention_masks += [1] * pad_length  # more efficient flash_attn

        if len(packed_input_ids) != data_args.cutoff_len:
            raise ValueError("The length of packed example should be identical to the cutoff length.")

        model_inputs["input_ids"].append(packed_input_ids)
        model_inputs["attention_mask"].append(packed_attention_masks)
        model_inputs["labels"].append(packed_labels)
        model_inputs["images"].append(packed_images or None)
        model_inputs["videos"].append(packed_videos or None)

    return model_inputs
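`greedy_knapsack` is imported from `.processor_utils` and its body is not part of this diff. A minimal re-implementation of the idea it names -- sort lengths in descending order, then fill each bin greedily up to the capacity -- is sketched below; this approximates the helper rather than reproducing its exact source.

from typing import List


def greedy_knapsack_sketch(lengths: List[int], capacity: int) -> List[List[int]]:
    """Greedily pack example lengths into bins holding at most `capacity` tokens."""
    remaining = sorted(lengths, reverse=True)
    knapsacks = []
    while remaining:
        knapsack, current, leftovers = [], 0, []
        for length in remaining:
            if current + length <= capacity:
                knapsack.append(length)
                current += length
            else:
                leftovers.append(length)

        knapsacks.append(knapsack)
        remaining = leftovers

    return knapsacks


print(greedy_knapsack_sketch([7, 3, 5, 2, 6], capacity=10))  # [[7, 3], [6, 2], [5]]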
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
src/llamafactory/data/processors/unsupervised.py
deleted
100644 → 0
View file @
37b0ad9f
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras import logging
from ..data_utils import Role
from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template


logger = logging.get_logger(__name__)
def _encode_unsupervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["ImageInput"],
    videos: Sequence["VideoInput"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int]]:
    if len(response) == 1:
        messages = prompt + response
    else:
        messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
    if template.efficient_eos:
        labels += [tokenizer.eos_token_id]

    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
    source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
    input_ids = input_ids[:source_len]
    labels = labels[:target_len]
    return input_ids, labels
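`infer_seqlen` also lives in `.processor_utils` and is not shown here; conceptually it splits the remaining token budget between source and target in proportion to their lengths. The sketch below captures that behavior under stated assumptions, not the repo's exact implementation.

from typing import Tuple


def infer_seqlen_sketch(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    """Split `cutoff_len` between source and target while roughly keeping their ratio."""
    if source_len + target_len <= cutoff_len:
        return source_len, target_len  # nothing to truncate

    if source_len < target_len:
        new_source_len = max(1, cutoff_len * source_len // (source_len + target_len))
        new_target_len = cutoff_len - new_source_len
    else:
        new_target_len = max(1, cutoff_len * target_len // (source_len + target_len))
        new_source_len = cutoff_len - new_target_len

    return new_source_len, new_target_len


print(infer_seqlen_sketch(300, 100, 128))  # (96, 32)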
def preprocess_unsupervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # build inputs with format `<bos> X` and labels with format `Y <eos>`
    model_inputs = defaultdict(list)
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1:
            logger.warning_rank0(
                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
            )
            continue

        input_ids, labels = _encode_unsupervised_example(
            prompt=examples["_prompt"][i],
            response=examples["_response"][i],
            system=examples["_system"][i],
            tools=examples["_tools"][i],
            images=examples["_images"][i] or [],
            videos=examples["_videos"][i] or [],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["images"].append(examples["_images"][i])
        model_inputs["videos"].append(examples["_videos"][i])

    return model_inputs
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
src/llamafactory/data/template.py
View file @
317a82e2
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union

from typing_extensions import override
...
...
@@ -47,6 +47,7 @@ class Template:
    format_prefix: "Formatter"
    default_system: str
    stop_words: List[str]
    thought_words: Tuple[str, str]
    efficient_eos: bool
    replace_eos: bool
    replace_jinja_template: bool
...
...
@@ -67,8 +68,8 @@ class Template:
        for encoded_ids in encoded_messages[:-1]:
            prompt_ids += encoded_ids

        answer_ids = encoded_messages[-1]
        return prompt_ids, answer_ids
        response_ids = encoded_messages[-1]
        return prompt_ids, response_ids

    def encode_multiturn(
        self,
...
...
@@ -99,6 +100,27 @@ class Template:
        return list(stop_token_ids)
    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
        r"""
        Converts elements to token ids.
        """
        token_ids = []
        for elem in elements:
            if isinstance(elem, str):
                if len(elem) != 0:
                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
            elif isinstance(elem, dict):
                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
            elif isinstance(elem, set):
                if "bos_token" in elem and tokenizer.bos_token_id is not None:
                    token_ids += [tokenizer.bos_token_id]
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                    token_ids += [tokenizer.eos_token_id]
            else:
                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")

        return token_ids

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
...
...
@@ -109,7 +131,7 @@ class Template:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: prefix + system + query        resp
        Turn t: sep + query                    resp
        Turn t: query                          resp
        """
        system = system or self.default_system
        encoded_messages = []
...
...
@@ -137,26 +159,179 @@ class Template:
        return encoded_messages
    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
        r"""
        Converts elements to token ids.
        """
        token_ids = []
        for elem in elements:
            if isinstance(elem, str):
                if len(elem) != 0:
                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
            elif isinstance(elem, dict):
                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
            elif isinstance(elem, set):
                if "bos_token" in elem and tokenizer.bos_token_id is not None:
                    token_ids += [tokenizer.bos_token_id]
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                    token_ids += [tokenizer.eos_token_id]
            else:
                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")

        return token_ids

    @staticmethod
    def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
        r"""
        Adds or replaces eos token to the tokenizer.
        """
        is_added = tokenizer.eos_token_id is None
        num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
        if is_added:
            logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
        else:
            logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")

        if num_added_tokens > 0:
            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
    def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
        r"""
        Adds eos token and pad token to the tokenizer.
        """
        stop_words = self.stop_words
        if self.replace_eos:
            if not stop_words:
                raise ValueError("Stop words are required to replace the EOS token.")

            self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
            stop_words = stop_words[1:]

        if tokenizer.eos_token_id is None:
            self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

        if stop_words:
            num_added_tokens = tokenizer.add_special_tokens(
                dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
            )
            logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
            if num_added_tokens > 0:
                logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
    @staticmethod
    def _jinja_escape(content: str) -> str:
        r"""
        Escape single quotes in content.
        """
        return content.replace("'", r"\'")
    @staticmethod
    def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
        r"""
        Converts slots to jinja template.
        """
        slot_items = []
        for slot in slots:
            if isinstance(slot, str):
                slot_pieces = slot.split("{{content}}")
                if slot_pieces[0]:
                    slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")

                if len(slot_pieces) > 1:
                    slot_items.append(placeholder)
                    if slot_pieces[1]:
                        slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")

            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
                if "bos_token" in slot and tokenizer.bos_token_id is not None:
                    slot_items.append("'" + tokenizer.bos_token + "'")
                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                    slot_items.append("'" + tokenizer.eos_token + "'")

            elif isinstance(slot, dict):
                raise ValueError("Dict is not supported.")

        return " + ".join(slot_items)
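To make the conversion concrete: a string slot is split on the `{{content}}` placeholder and re-emitted as a jinja concatenation of quoted literals and the placeholder name. A standalone sketch of the string branch (dict and set slots omitted):

def slot_to_jinja_sketch(slot: str, placeholder: str = "content") -> str:
    """Mimic the string branch of _convert_slots_to_jinja for a single slot."""
    pieces = slot.split("{{content}}")
    items = []
    if pieces[0]:
        items.append("'" + pieces[0].replace("'", r"\'") + "'")

    if len(pieces) > 1:
        items.append(placeholder)
        if pieces[1]:
            items.append("'" + pieces[1].replace("'", r"\'") + "'")

    return " + ".join(items)


print(slot_to_jinja_sketch("USER: {{content}} ASSISTANT:"))  # 'USER: ' + content + ' ASSISTANT:'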
    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
        r"""
        Returns the jinja template.
        """
        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
        system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
        user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
        assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
        jinja_template = ""
        if prefix:
            jinja_template += "{{ " + prefix + " }}"

        if self.default_system:
            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"

        jinja_template += (
            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
            "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
            "{% for message in loop_messages %}"
            "{% set content = message['content'] %}"
            "{% if message['role'] == 'user' %}"
            "{{ " + user + " }}"
            "{% elif message['role'] == 'assistant' %}"
            "{{ " + assistant + " }}"
            "{% endif %}"
            "{% endfor %}"
        )
        return jinja_template
    def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
        r"""
        Replaces the jinja template in the tokenizer.
        """
        if tokenizer.chat_template is None or self.replace_jinja_template:
            try:
                tokenizer.chat_template = self._get_jinja_template(tokenizer)
            except ValueError as e:
                logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
    @staticmethod
    def _convert_slots_to_ollama(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
        r"""
        Converts slots to ollama template.
        """
        slot_items = []
        for slot in slots:
            if isinstance(slot, str):
                slot_pieces = slot.split("{{content}}")
                if slot_pieces[0]:
                    slot_items.append(slot_pieces[0])

                if len(slot_pieces) > 1:
                    slot_items.append("{{ " + placeholder + " }}")
                    if slot_pieces[1]:
                        slot_items.append(slot_pieces[1])

            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
                if "bos_token" in slot and tokenizer.bos_token_id is not None:
                    slot_items.append(tokenizer.bos_token)
                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                    slot_items.append(tokenizer.eos_token)

            elif isinstance(slot, dict):
                raise ValueError("Dict is not supported.")

        return "".join(slot_items)
    def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
        r"""
        Returns the ollama template.
        """
        prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
        system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
        user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
        assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
        return (
            f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
            f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
            f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
        )
    def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
        r"""
        Returns the ollama modelfile.
        TODO: support function calling.
        """
        modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
        modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
        if self.default_system:
            modelfile += f'SYSTEM """{self.default_system}"""\n\n'

        for stop_token_id in self.get_stop_token_ids(tokenizer):
            modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'

        modelfile += "PARAMETER num_ctx 4096\n"
        return modelfile
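A sketch of how the modelfile export might be driven end to end; the model name is a placeholder and the `TEMPLATES` lookup assumes the registry defined later in this file.

from transformers import AutoTokenizer

# Hypothetical wiring -- names are placeholders, not part of this diff.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
template = TEMPLATES["qwen"]
with open("Modelfile", "w", encoding="utf-8") as f:
    f.write(template.get_ollama_modelfile(tokenizer))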
@dataclass
...
...
@@ -169,11 +344,6 @@ class Llama2Template(Template):
        system: str,
        tools: str,
    ) -> List[List[int]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: prefix + system + query        resp
        Turn t: sep + query                    resp
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
...
...
@@ -201,11 +371,41 @@ class Llama2Template(Template):
        return encoded_messages
    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
        system_message = self._convert_slots_to_jinja(
            self.format_system.apply(), tokenizer, placeholder="system_message"
        )
        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
        jinja_template = ""
        if prefix:
            jinja_template += "{{ " + prefix + " }}"

        if self.default_system:
            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"

        jinja_template += (
            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
            "{% for message in loop_messages %}"
            "{% if loop.index0 == 0 and system_message is defined %}"
            "{% set content = " + system_message + " + message['content'] %}"
            "{% else %}{% set content = message['content'] %}{% endif %}"
            "{% if message['role'] == 'user' %}"
            "{{ " + user_message + " }}"
            "{% elif message['role'] == 'assistant' %}"
            "{{ " + assistant_message + " }}"
            "{% endif %}"
            "{% endfor %}"
        )
        return jinja_template


TEMPLATES: Dict[str, "Template"] = {}
def _register_template(
def register_template(
    name: str,
    format_user: Optional["Formatter"] = None,
    format_assistant: Optional["Formatter"] = None,
...
...
@@ -216,10 +416,12 @@ def _register_template(
    format_prefix: Optional["Formatter"] = None,
    default_system: str = "",
    stop_words: Optional[Sequence[str]] = None,
    thought_words: Optional[Tuple[str, str]] = None,
    efficient_eos: bool = False,
    replace_eos: bool = False,
    replace_jinja_template: bool = False,
    mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
    template_class: Type["Template"] = Template,
) -> None:
    r"""
    Registers a chat template.
...
...
@@ -234,7 +436,7 @@ def _register_template(
    The corresponding code should be:
    ```
    _register_template(
    register_template(
        name="custom",
        format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
        format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
...
...
@@ -242,7 +444,9 @@ def _register_template(
    )
    ```
    """
    template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template
    if name in TEMPLATES:
        raise ValueError(f"Template {name} already exists.")

    default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
    default_user_formatter = StringFormatter(slots=["{{content}}"])
    default_assistant_formatter = StringFormatter(slots=default_slots)
...
...
@@ -259,6 +463,7 @@ def _register_template(
        format_prefix=format_prefix or default_prefix_formatter,
        default_system=default_system,
        stop_words=stop_words or [],
        thought_words=thought_words or ("<think>", "</think>"),
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
        replace_jinja_template=replace_jinja_template,
...
...
@@ -266,97 +471,83 @@ def _register_template(
)
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
    is_added = tokenizer.eos_token_id is None
    num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
    if is_added:
        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
    else:
        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")

    if num_added_tokens > 0:
        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def _jinja_escape(content: str) -> str:
    return content.replace("'", r"\'")
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
    slot_items = []
    for slot in slots:
        if isinstance(slot, str):
            slot_pieces = slot.split("{{content}}")
            if slot_pieces[0]:
                slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")

            if len(slot_pieces) > 1:
                slot_items.append(placeholder)
                if slot_pieces[1]:
                    slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")

        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
            if "bos_token" in slot and tokenizer.bos_token_id is not None:
                slot_items.append("'" + tokenizer.bos_token + "'")
            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                slot_items.append("'" + tokenizer.eos_token + "'")

        elif isinstance(slot, dict):
            raise ValueError("Dict is not supported.")

    return " + ".join(slot_items)
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
    r"""
    Returns the jinja template.
    """
    jinja_template = ""
    prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
    if prefix:
        jinja_template += "{{ " + prefix + " }}"

    if template.default_system:
        jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"

    jinja_template += (
        "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
        "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
    )
    system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
    if not isinstance(template, Llama2Template):
        jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"

    jinja_template += "{% for message in loop_messages %}"
    jinja_template += "{% set content = message['content'] %}"
    if isinstance(template, Llama2Template):
        jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
        jinja_template += "{% set content = " + system_message + " + message['content'] %}"
        jinja_template += "{% endif %}"

    jinja_template += "{% if message['role'] == 'user' %}"
    user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
    jinja_template += "{{ " + user_message + " }}"
    jinja_template += "{% elif message['role'] == 'assistant' %}"
    assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
    jinja_template += "{{ " + assistant_message + " }}"
    jinja_template += "{% endif %}"
    jinja_template += "{% endfor %}"
    return jinja_template


def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
    r"""
    Extracts a chat template from the tokenizer.
    """

    def find_diff(short_str: str, long_str: str) -> str:
        i, j = 0, 0
        diff = ""
        while i < len(short_str) and j < len(long_str):
            if short_str[i] == long_str[j]:
                i += 1
                j += 1
            else:
                diff += long_str[j]
                j += 1

        return diff

    prefix = tokenizer.decode(tokenizer.encode(""))

    messages = [{"role": "system", "content": "{{content}}"}]
    system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]

    messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
    user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    user_slot_empty_system = user_slot_empty_system[len(prefix) :]

    messages = [{"role": "user", "content": "{{content}}"}]
    user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    user_slot = user_slot[len(prefix) :]

    messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
    assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
    assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]

    if len(user_slot) > len(user_slot_empty_system):
        default_system = find_diff(user_slot_empty_system, user_slot)
        sole_system = system_slot.replace("{{content}}", default_system, 1)
        user_slot = user_slot[len(sole_system) :]
    else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
        default_system = ""

    return Template(
        format_user=StringFormatter(slots=[user_slot]),
        format_assistant=StringFormatter(slots=[assistant_slot]),
        format_system=StringFormatter(slots=[system_slot]),
        format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
        format_observation=StringFormatter(slots=[user_slot]),
        format_tools=ToolFormatter(tool_format="default"),
        format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
        default_system=default_system,
        stop_words=[],
        thought_words=("<think>", "</think>"),
        efficient_eos=False,
        replace_eos=False,
        replace_jinja_template=False,
        mm_plugin=get_mm_plugin(name="base"),
    )
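`find_diff` recovers the characters that the longer rendering carries over the shorter one, which is how `parse_template` pulls a baked-in default system prompt out of a chat template. A toy run with invented strings (chosen so the greedy character match stays aligned):

def find_diff(short_str: str, long_str: str) -> str:
    i, j = 0, 0
    diff = ""
    while i < len(short_str) and j < len(long_str):
        if short_str[i] == long_str[j]:
            i += 1
            j += 1
        else:
            diff += long_str[j]
            j += 1

    return diff


slot_without_system = "<start>user: {{content}}"
slot_with_system = "<start>SYSTEM PROMPT user: {{content}}"
print(find_diff(slot_without_system, slot_with_system))  # "SYSTEM PROMPT "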
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
    r"""
    Gets chat template and fixes the tokenizer.
    """
    if data_args.template is None:
        template = TEMPLATES["empty"]  # placeholder
        if isinstance(tokenizer.chat_template, str):
            logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
            template = parse_template(tokenizer)
        else:
            logger.warning_rank0("`template` was not specified, use `empty` template.")
            template = TEMPLATES["empty"]  # placeholder
    else:
        template = TEMPLATES.get(data_args.template, None)
        if template is None:
        if data_args.template not in TEMPLATES:
            raise ValueError(f"Template {data_args.template} does not exist.")

        template = TEMPLATES[data_args.template]

    if template.mm_plugin.__class__.__name__ != "BasePlugin":
        check_version("transformers>=4.45.0")
...
...
@@ -369,39 +560,12 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
        template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
        template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
    stop_words = template.stop_words
    if template.replace_eos:
        if not stop_words:
            raise ValueError("Stop words are required to replace the EOS token.")

        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
        stop_words = stop_words[1:]

    if tokenizer.eos_token_id is None:
        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

    if stop_words:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
        )
        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
        if num_added_tokens > 0:
            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

    if tokenizer.chat_template is None or template.replace_jinja_template:
        try:
            tokenizer.chat_template = _get_jinja_template(template, tokenizer)
        except ValueError as e:
            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
    template.fix_special_tokens(tokenizer)
    template.fix_jinja_template(tokenizer)
    return template
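A minimal usage sketch: the model name is a placeholder and constructing `DataArguments` directly is an assumption (in the repo it usually arrives via the argument-parsing pipeline).

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")  # placeholder
data_args = DataArguments(template="llama3")  # assumed direct construction
template = get_template_and_fix_tokenizer(tokenizer, data_args)
messages = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)  # assumes system/tools default to None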
_register_template(
register_template(
    name="alpaca",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
...
...
@@ -412,7 +576,7 @@ _register_template(
)
_register_template(
register_template(
    name="aquila",
    format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}###"]),
...
...
@@ -425,7 +589,7 @@ _register_template(
)
_register_template(
register_template(
    name="atom",
    format_user=StringFormatter(
        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
...
...
@@ -434,21 +598,31 @@ _register_template(
)
_register_template(
register_template(
    name="baichuan",
    format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
    efficient_eos=True,
)


_register_template(
register_template(
    name="baichuan2",
    format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
    efficient_eos=True,
)


_register_template(
register_template(
    name="bailing",
    format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),
    format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}"]),
    format_observation=StringFormatter(slots=["<role>OBSERVATION</role>{{content}}<role>ASSISTANT</role>"]),
    stop_words=["<|endoftext|>"],
    efficient_eos=True,
)


register_template(
    name="belle",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
...
...
@@ -456,13 +630,13 @@ _register_template(
)
_register_template(
register_template(
    name="bluelm",
    format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
_register_template(
register_template(
    name="breeze",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
...
...
@@ -470,7 +644,7 @@ _register_template(
)
_register_template(
register_template(
    name="chatglm2",
    format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
...
...
@@ -478,7 +652,7 @@ _register_template(
)
_register_template(
register_template(
    name="chatglm3",
    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
...
...
@@ -494,7 +668,7 @@ _register_template(
)
_register_template(
register_template(
    name="chatml",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -507,7 +681,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
    name="chatml_de",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -520,13 +694,13 @@ _register_template(
)
_register_template(
register_template(
    name="codegeex2",
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)


_register_template(
register_template(
    name="codegeex4",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
...
...
@@ -543,7 +717,7 @@ _register_template(
)
_register_template(
register_template(
    name="cohere",
    format_user=StringFormatter(
        slots=[
...
...
@@ -558,7 +732,7 @@ _register_template(
)
_register_template(
register_template(
    name="cpm",
    format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
...
...
@@ -566,7 +740,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
    name="cpm3",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -577,7 +751,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
    name="dbrx",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -602,7 +776,7 @@ _register_template(
)
_register_template(
register_template(
    name="deepseek",
    format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
...
...
@@ -610,14 +784,14 @@ _register_template(
)
_register_template(
register_template(
    name="deepseek3",
    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
register_template(
    name="deepseekcoder",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
    format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
...
...
@@ -631,7 +805,7 @@ _register_template(
)
_register_template(
register_template(
    name="default",
    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
...
...
@@ -639,13 +813,13 @@ _register_template(
)
_register_template(
register_template(
    name="empty",
    efficient_eos=True,
    format_assistant=StringFormatter(slots=["{{content}}"]),
)


_register_template(
register_template(
    name="exaone",
    format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
...
...
@@ -653,7 +827,7 @@ _register_template(
)
_register_template(
register_template(
    name="falcon",
    format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
    format_assistant=StringFormatter(slots=["{{content}}\n"]),
...
...
@@ -661,14 +835,14 @@ _register_template(
)
_register_template(
register_template(
    name="fewshot",
    format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
    efficient_eos=True,
)


_register_template(
register_template(
    name="gemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
...
...
@@ -679,7 +853,7 @@ _register_template(
)
_register_template(
register_template(
    name="glm4",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
...
...
@@ -693,7 +867,7 @@ _register_template(
)
_register_template(
register_template(
    name="granite3",
    format_user=StringFormatter(
        slots=[
...
...
@@ -705,7 +879,7 @@ _register_template(
)
_register_template(
register_template(
    name="index",
    format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
    format_system=StringFormatter(slots=["<unk>{{content}}"]),
...
...
@@ -713,54 +887,59 @@ _register_template(
)
_register_template(
register_template(
    name="intern",
    format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
    format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
    format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
        "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
        "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
        "- InternLM (书生·浦语) can understand and communicate fluently in the language "
        "chosen by the user such as English and 中文."
    ),
    stop_words=["<eoa>"],
)


_register_template(
register_template(
    name="intern2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
        "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
        "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
        "- InternLM (书生·浦语) can understand and communicate fluently in the language "
        "chosen by the user such as English and 中文."
    ),
    stop_words=["<|im_end|>"],
)


# copied from intern2 template
_register_template(
    name="intern3",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>"],
)


_register_template(
register_template(
    name="llama2",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
    template_class=Llama2Template,
)
# copied from llama2 template
_register_template(
register_template(
    name="llama2_zh",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
    default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
    template_class=Llama2Template,
)
_register_template(
register_template(
    name="llama3",
    format_user=StringFormatter(
        slots=[
...
...
@@ -788,7 +967,7 @@ _register_template(
# copied from llama3 template
_register_template(
register_template(
    name="mllama",
    format_user=StringFormatter(
        slots=[
...
...
@@ -816,8 +995,20 @@ _register_template(
)
register_template(
    name="moonlight",
    format_user=StringFormatter(
        slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
    format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
    default_system="You are a helpful assistant provided by Moonshot-AI.",
    stop_words=["<|im_end|>"],
)
# copied from vicuna template
_register_template(
register_template(
    name="llava",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
...
...
@@ -829,7 +1020,7 @@ _register_template(
# copied from vicuna template
_register_template(
register_template(
    name="llava_next",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
...
...
@@ -841,7 +1032,7 @@ _register_template(
# copied from llama3 template
_register_template(
register_template(
    name="llava_next_llama3",
    format_user=StringFormatter(
        slots=[
...
...
@@ -870,21 +1061,22 @@ _register_template(
# copied from mistral template
_register_template(
register_template(
    name="llava_next_mistral",
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", " {{content}}", {"eos_token"}], tool_format="mistral"),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
    template_class=Llama2Template,
)
# copied from chatml template
_register_template(
# copied from qwen template
register_template(
    name="llava_next_qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -901,7 +1093,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
    name="llava_next_yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -912,7 +1104,7 @@ _register_template(
# copied from vicuna template
_register_template(
register_template(
    name="llava_next_video",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
...
...
@@ -924,21 +1116,22 @@ _register_template(
# copied from mistral template
_register_template(
register_template(
    name="llava_next_video_mistral",
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", " {{content}}", {"eos_token"}], tool_format="mistral"),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
    template_class=Llama2Template,
)
# copied from chatml template
_register_template(
register_template(
    name="llava_next_video_yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -949,7 +1142,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
    name="marco",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -965,43 +1158,83 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
    name="minicpm_v",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
    default_system="You are a helpful assistant.",
    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
)


_register_template(
# copied from minicpm_v template
register_template(
    name="minicpm_o",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
    default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
)


# mistral tokenizer v3 tekken
register_template(
    name="ministral",
    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    template_class=Llama2Template,
)


# mistral tokenizer v3
register_template(
    name="mistral",
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", " {{content}}", {"eos_token"}], tool_format="mistral"),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    template_class=Llama2Template,
)


# mistral tokenizer v7 tekken (copied from ministral)
register_template(
    name="mistral_small",
    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
register_template(
    name="olmo",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)


_register_template(
register_template(
    name="openchat",
    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
register_template(
    name="openchat-3.6",
    format_user=StringFormatter(
        slots=[
...
...
@@ -1017,7 +1250,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
    name="opencoder",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -1028,16 +1261,24 @@ _register_template(
)
_register_template(
register_template(
    name="orion",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


# copied from gemma template
_register_template(
register_template(
    name="paligemma",
    format_user=StringFormatter(slots=["{{content}}\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)


# copied from gemma template
register_template(
    name="paligemma_chat",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_observation=StringFormatter(
...
...
@@ -1048,7 +1289,7 @@ _register_template(
)
_register_template(
register_template(
    name="phi",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
...
...
@@ -1057,7 +1298,7 @@ _register_template(
)
_register_template(
register_template(
    name="phi_small",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
...
...
@@ -1067,7 +1308,7 @@ _register_template(
)
_register_template(
register_template(
    name="phi4",
    format_user=StringFormatter(
        slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
...
...
@@ -1078,17 +1319,22 @@ _register_template(
)
_register_template(
# copied from ministral template
register_template(
    name="pixtral",
    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
    template_class=Llama2Template,
)
# copied from chatml template
_register_template(
register_template(
    name="qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -1104,7 +1350,19 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
    name="qwen2_audio",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
)


# copied from qwen template
register_template(
    name="qwen2_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -1120,7 +1378,7 @@ _register_template(
)
_register_template(
register_template(
    name="sailor",
    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -1134,7 +1392,7 @@ _register_template(
# copied from llama3 template
_register_template(
register_template(
    name="skywork_o1",
    format_user=StringFormatter(
        slots=[
...
...
@@ -1168,7 +1426,7 @@ _register_template(
)
_register_template(
register_template(
    name="solar",
    format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
    format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
...
...
@@ -1176,7 +1434,7 @@ _register_template(
)
_register_template(
register_template(
    name="starchat",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
...
...
@@ -1185,14 +1443,14 @@ _register_template(
)
_register_template(
register_template(
    name="telechat",
    format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
    format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
)


_register_template(
register_template(
    name="telechat2",
    format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
    format_system=StringFormatter(slots=["<_system>{{content}}"]),
...
...
@@ -1202,7 +1460,7 @@ _register_template(
)
_register_template(
register_template(
    name="vicuna",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
...
...
@@ -1213,7 +1471,7 @@ _register_template(
)
_register_template(
register_template(
    name="video_llava",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
...
...
@@ -1224,7 +1482,7 @@ _register_template(
)
_register_template(
register_template(
    name="xuanyuan",
    format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
    default_system=(
...
...
@@ -1235,13 +1493,13 @@ _register_template(
)
_register_template(
register_template(
    name="xverse",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)


_register_template(
register_template(
    name="yayi",
    format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
    format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
...
...
@@ -1262,7 +1520,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
    name="yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...
...
@@ -1271,7 +1529,7 @@ _register_template(
)
_register_template(
register_template(
    name="yi_vl",
    format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}\n"]),
...
...
@@ -1288,7 +1546,7 @@ _register_template(
)
_register_template(
register_template(
    name="yuan",
    format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
    format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
...
...
@@ -1296,7 +1554,7 @@ _register_template(
)
_register_template(
register_template(
    name="zephyr",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
...
...
@@ -1304,7 +1562,7 @@ _register_template(
)
_register_template(
register_template(
    name="ziya",
    format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
    format_assistant=StringFormatter(slots=["{{content}}\n"]),
...
...