chenych / llama-grpo · Commits

Commit c7c477c7 ("add grpo")
Authored Sep 24, 2025 by chenych
Pipeline #2942 failed with stages in 0 seconds

Changes: 282 · Pipelines: 1
Showing 20 changed files with 6772 additions and 0 deletions (+6772, -0)
src/llamafactory/chat/vllm_engine.py                 +263   -0
src/llamafactory/cli.py                              +160   -0
src/llamafactory/data/__init__.py                     +37   -0
src/llamafactory/data/collator.py                    +322   -0
src/llamafactory/data/converter.py                   +284   -0
src/llamafactory/data/data_utils.py                  +190   -0
src/llamafactory/data/formatter.py                   +142   -0
src/llamafactory/data/loader.py                      +335   -0
src/llamafactory/data/mm_plugin.py                  +1912   -0
src/llamafactory/data/parser.py                      +147   -0
src/llamafactory/data/processor/__init__.py           +31   -0
src/llamafactory/data/processor/feedback.py          +129   -0
src/llamafactory/data/processor/pairwise.py          +118   -0
src/llamafactory/data/processor/pretrain.py           +57   -0
src/llamafactory/data/processor/processor_utils.py    +88   -0
src/llamafactory/data/processor/supervised.py        +203   -0
src/llamafactory/data/processor/unsupervised.py       +91   -0
src/llamafactory/data/template.py                   +1943   -0
src/llamafactory/data/tool_utils.py                  +320   -0
src/llamafactory/eval/__init__.py                      +0   -0
src/llamafactory/chat/vllm_engine.py (new file, 0 → 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any, Optional, Union

from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response


if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
    from vllm.lora.request import LoRARequest


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class VllmEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.VLLM
        self.model_args = model_args
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
        self.generating_args = generating_args.to_dict()

        engine_args = {
            "model": model_args.model_name_or_path,
            "trust_remote_code": model_args.trust_remote_code,
            "download_dir": model_args.cache_dir,
            "dtype": model_args.infer_dtype,
            "max_model_len": model_args.vllm_maxlen,
            "tensor_parallel_size": get_device_count() or 1,
            "gpu_memory_utilization": model_args.vllm_gpu_util,
            "disable_log_stats": True,
            "disable_log_requests": True,
            "enforce_eager": model_args.vllm_enforce_eager,
            "enable_lora": model_args.adapter_name_or_path is not None,
            "max_lora_rank": model_args.vllm_max_lora_rank,
        }
        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}

        if isinstance(model_args.vllm_config, dict):
            engine_args.update(model_args.vllm_config)

        if getattr(config, "is_yi_vl_derived_model", None):
            import vllm.model_executor.models.llava

            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
        if model_args.adapter_name_or_path is not None:
            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
        else:
            self.lora_request = None

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, images or [], videos or [], audios or [], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if length_penalty is not None:
            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")

        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = SamplingParams(
            n=num_return_sequences,
            repetition_penalty=(
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must > 0
            temperature=temperature if temperature is not None else self.generating_args["temperature"],
            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
            top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
            stop=stop,
            stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
            max_tokens=max_tokens,
            skip_special_tokens=skip_special_tokens
            if skip_special_tokens is not None
            else self.generating_args["skip_special_tokens"],
        )

        if images is not None:  # add image features
            multi_modal_data = {
                "image": self.template.mm_plugin._regularize_images(
                    images,
                    image_max_pixels=self.model_args.image_max_pixels,
                    image_min_pixels=self.model_args.image_min_pixels,
                )["images"]
            }
        elif videos is not None:
            multi_modal_data = {
                "video": self.template.mm_plugin._regularize_videos(
                    videos,
                    image_max_pixels=self.model_args.video_max_pixels,
                    image_min_pixels=self.model_args.video_min_pixels,
                    video_fps=self.model_args.video_fps,
                    video_maxlen=self.model_args.video_maxlen,
                )["videos"]
            }
        elif audios is not None:
            audio_data = self.template.mm_plugin._regularize_audios(
                audios,
                sampling_rate=self.model_args.audio_sampling_rate,
            )
            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
        else:
            multi_modal_data = None

        result_generator = self.model.generate(
            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,
            request_id=request_id,
            lora_request=self.lora_request,
        )
        return result_generator

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        async for request_output in generator:
            final_output = request_output

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text):]
            generated_text = result.outputs[0].text
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        raise NotImplementedError("vLLM engine does not support `get_scores`.")
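Editorial usage sketch, not part of this commit: how the engine above might be driven once the four hparams objects (model_args, data_args, finetuning_args, generating_args) have been built elsewhere, e.g. by the project's argument parser. The message content is illustrative.

import asyncio

async def demo(engine: "VllmEngine") -> None:
    # chat() collects the final vLLM output into Response objects.
    responses = await engine.chat(messages=[{"role": "user", "content": "Hello!"}])
    print(responses[0].response_text)

    # stream_chat() yields incremental text deltas as they are generated.
    async for delta in engine.stream_chat(messages=[{"role": "user", "content": "Hello!"}]):
        print(delta, end="")

# engine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
# asyncio.run(demo(engine))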
src/llamafactory/cli.py (new file, 0 → 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys
from copy import deepcopy
from functools import partial


USAGE = (
    "-" * 70
    + "\n"
    + "| Usage:                                                            |\n"
    + "|   llamafactory-cli api -h: launch an OpenAI-style API server      |\n"
    + "|   llamafactory-cli chat -h: launch a chat interface in CLI        |\n"
    + "|   llamafactory-cli eval -h: evaluate models                       |\n"
    + "|   llamafactory-cli export -h: merge LoRA adapters and export model|\n"
    + "|   llamafactory-cli train -h: train models                         |\n"
    + "|   llamafactory-cli webchat -h: launch a chat interface in Web UI  |\n"
    + "|   llamafactory-cli webui: launch LlamaBoard                       |\n"
    + "|   llamafactory-cli version: show version info                     |\n"
    + "-" * 70
)


def main():
    from . import launcher
    from .api.app import run_api
    from .chat.chat_model import run_chat
    from .eval.evaluator import run_eval
    from .extras import logging
    from .extras.env import VERSION, print_env
    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
    from .train.tuner import export_model, run_exp
    from .webui.interface import run_web_demo, run_web_ui

    logger = logging.get_logger(__name__)

    WELCOME = (
        "-" * 58
        + "\n"
        + f"| Welcome to LLaMA Factory, version {VERSION}"
        + " " * (21 - len(VERSION))
        + "|\n|"
        + " " * 56
        + "|\n"
        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
        + "-" * 58
    )

    COMMAND_MAP = {
        "api": run_api,
        "chat": run_chat,
        "env": print_env,
        "eval": run_eval,
        "export": export_model,
        "train": run_exp,
        "webchat": run_web_demo,
        "webui": run_web_ui,
        "version": partial(print, WELCOME),
        "help": partial(print, USAGE),
    }

    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
        # launch distributed training
        nnodes = os.getenv("NNODES", "1")
        node_rank = os.getenv("NODE_RANK", "0")
        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
        if int(nnodes) > 1:
            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")

        # elastic launch support
        max_restarts = os.getenv("MAX_RESTARTS", "0")
        rdzv_id = os.getenv("RDZV_ID")
        min_nnodes = os.getenv("MIN_NNODES")
        max_nnodes = os.getenv("MAX_NNODES")

        env = deepcopy(os.environ)
        if is_env_enabled("OPTIM_TORCH", "1"):
            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        if rdzv_id is not None:
            # launch elastic job with fault tolerant support when possible
            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
            rdzv_nnodes = nnodes
            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
            if min_nnodes is not None and max_nnodes is not None:
                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"

            process = subprocess.run(
                (
                    "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
                    "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
                    "--max-restarts {max_restarts} {file_name} {args}"
                )
                .format(
                    rdzv_nnodes=rdzv_nnodes,
                    nproc_per_node=nproc_per_node,
                    rdzv_id=rdzv_id,
                    master_addr=master_addr,
                    master_port=master_port,
                    max_restarts=max_restarts,
                    file_name=launcher.__file__,
                    args=" ".join(sys.argv[1:]),
                )
                .split(),
                env=env,
                check=True,
            )
        else:
            # NOTE: DO NOT USE shell=True to avoid security risk
            process = subprocess.run(
                (
                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
                )
                .format(
                    nnodes=nnodes,
                    node_rank=node_rank,
                    nproc_per_node=nproc_per_node,
                    master_addr=master_addr,
                    master_port=master_port,
                    file_name=launcher.__file__,
                    args=" ".join(sys.argv[1:]),
                )
                .split(),
                env=env,
                check=True,
            )

        sys.exit(process.returncode)
    elif command in COMMAND_MAP:
        COMMAND_MAP[command]()
    else:
        print(f"Unknown command: {command}.\n{USAGE}")


if __name__ == "__main__":
    from multiprocessing import freeze_support

    freeze_support()
    main()
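Editorial illustration, not part of this commit: how the non-elastic torchrun command above resolves once the environment variables are filled in. The node counts, addresses, launcher path and YAML argument below are assumptions made only for this example.

cmd = (
    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
    nnodes="2",                                        # NNODES
    node_rank="0",                                     # NODE_RANK
    nproc_per_node="8",                                # NPROC_PER_NODE
    master_addr="192.168.0.1",                         # MASTER_ADDR
    master_port="29500",                               # MASTER_PORT
    file_name="src/llamafactory/launcher.py",          # launcher.__file__ (assumed path)
    args="examples/train_lora/llama3_lora_sft.yaml",   # remaining CLI arguments (assumed)
).split()
print(" ".join(cmd))  # the argv list is passed to subprocess.run without shell=True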
src/llamafactory/data/__init__.py (new file, 0 → 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .collator import (
    KTODataCollatorWithPadding,
    MultiModalDataCollatorForSeq2Seq,
    PairwiseDataCollatorWithPadding,
    SFTDataCollatorWith4DAttentionMask,
)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer


__all__ = [
    "TEMPLATES",
    "KTODataCollatorWithPadding",
    "MultiModalDataCollatorForSeq2Seq",
    "PairwiseDataCollatorWithPadding",
    "Role",
    "SFTDataCollatorWith4DAttentionMask",
    "Template",
    "get_dataset",
    "get_template_and_fix_tokenizer",
    "split_dataset",
]
src/llamafactory/data/collator.py (new file, 0 → 100644)

# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional

import numpy as np
import torch
import torch.nn.functional as F
from peft import PeftModel
from transformers import DataCollatorForSeq2Seq

from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available


if is_pillow_available():
    from PIL import Image


if TYPE_CHECKING:
    from transformers import ProcessorMixin

    from .template import Template


def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""Expand 2d attention mask to 4d attention mask.

    Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
    handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals to `0.0`, `x` equals to `min_dtype`.
    """
    _, seq_len = attention_mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    zero_tensor = torch.tensor(0, dtype=dtype)
    # Create a non-padding mask.
    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
    # Create indices for comparison.
    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
    # Create a lower triangular mask.
    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
    # Invert the attention mask.
    attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
    return attention_mask_4d


@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""Data collator that supports VLMs.

    Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
    """

    template: Optional["Template"] = None
    processor: Optional["ProcessorMixin"] = None

    def __post_init__(self):
        if self.template is None:
            raise ValueError("Template is required for MultiModalDataCollator.")

        if isinstance(self.model, PeftModel):
            self.model = self.model.base_model.model

        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
            self.get_rope_func = self.model.get_rope_index  # transformers < 4.52.0 or qwen2.5 omni
        elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
            self.get_rope_func = self.model.model.get_rope_index  # transformers >= 4.52.0
        else:
            self.get_rope_func = None

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        batch_images, batch_videos, batch_audios = [], [], []
        batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
        for feature in features:
            images = feature.pop("images", None) or []
            videos = feature.pop("videos", None) or []
            audios = feature.pop("audios", None) or []
            batch_images.extend(images)
            batch_videos.extend(videos)
            batch_audios.extend(audios)
            batch_imglens.append(len(images))
            batch_vidlens.append(len(videos))
            batch_audlens.append(len(audios))
            batch_input_ids.append(feature["input_ids"])

        fake_input_ids = []
        if (
            self.template.mm_plugin.image_token is not None
            and sum(batch_imglens) == 0
            and sum(batch_vidlens) == 0
        ):  # avoid process hanging in zero3/fsdp case
            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], [], self.processor)
            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
                _fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
            )
            fake_input_ids.extend(_fake_input_ids)
            batch_images = fake_images
            batch_imglens[0] = 1

        if self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0:
            # avoid process hanging in zero3/fsdp case
            fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
            fake_audios = [np.zeros(1600)]
            fake_messages = self.template.mm_plugin.process_messages(fake_messages, [], [], fake_audios, self.processor)
            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
                _fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
            )
            fake_input_ids.extend(_fake_input_ids)
            batch_audios = fake_audios
            batch_audlens[0] = 1

        if len(fake_input_ids) != 0:
            if self.tokenizer.padding_side == "right":
                features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
                features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
            else:
                features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
                features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
                features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]

            batch_input_ids[0] = features[0]["input_ids"]

        mm_inputs = self.template.mm_plugin.get_mm_inputs(
            batch_images,
            batch_videos,
            batch_audios,
            batch_imglens,
            batch_vidlens,
            batch_audlens,
            batch_input_ids,
            self.processor,
        )
        if "token_type_ids" in mm_inputs:
            token_type_ids = mm_inputs.pop("token_type_ids")
            for i, feature in enumerate(features):
                feature["token_type_ids"] = token_type_ids[i]

        features: dict[str, torch.Tensor] = super().__call__(features)

        if self.get_rope_func is not None:
            rope_index_kwargs = {
                "input_ids": features["input_ids"],
                "image_grid_thw": mm_inputs.get("image_grid_thw"),
                "video_grid_thw": mm_inputs.get("video_grid_thw"),
                "attention_mask": (features["attention_mask"] >= 1).float(),
            }
            if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
            elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2.5 omni
                rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
                if feature_attention_mask is not None:  # FIXME: need to get video image lengths
                    audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
                    rope_index_kwargs["audio_seqlens"] = audio_feature_lengths

                # prepare for input
                features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
                features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(-1)
            else:  # for qwen2vl
                features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)

        if (
            self.model is not None
            and getattr(self.model.config, "model_type", None) in ["glm4v", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
            and ("position_ids" not in features or features["position_ids"].dim() != 3)
        ):
            raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")

        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
            seq_len = features["input_ids"].size(1)
            orig_len = cross_attention_mask.size(1)
            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))

        features.update(mm_inputs)
        if "image_bound" in features:  # for minicpmv inputs
            bsz, seq_length = features["input_ids"].shape
            features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
            return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}

        return features


@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for 4d attention mask."""

    block_diag_attn: bool = False
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        features = super().__call__(features)
        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

        for key, value in features.items():  # cast data dtype for paligemma
            if torch.is_tensor(value) and torch.is_floating_point(value):
                features[key] = value.to(self.compute_dtype)

        return features


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for pairwise data."""

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        r"""Pad batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        concatenated_features = []
        for key in ("chosen", "rejected"):
            for feature in features:
                target_feature = {
                    "input_ids": feature[f"{key}_input_ids"],
                    "attention_mask": feature[f"{key}_attention_mask"],
                    "labels": feature[f"{key}_labels"],
                    "images": feature["images"],
                    "videos": feature["videos"],
                    "audios": feature["audios"],
                }
                concatenated_features.append(target_feature)

        return super().__call__(concatenated_features)


@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for KTO data."""

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            target_feature = {
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
                "images": feature["images"],
                "videos": feature["videos"],
                "audios": feature["audios"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
                "images": feature["images"],
                "videos": feature["videos"],
                "audios": feature["audios"],
            }
            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "cross_attention_mask" in kl_batch:  # for mllama inputs
            batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]

        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch
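Editorial check, not part of this commit: running prepare_4d_attention_mask on the toy input from its docstring, assuming the function is imported from the module added above.

import torch

from llamafactory.data.collator import prepare_4d_attention_mask

# Two packed sequences (indices 1 and 2) plus one padding token (index 0).
mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(mask, torch.float32)
print(mask_4d.shape)                  # torch.Size([1, 1, 6, 6])
print((mask_4d == 0.0).int()[0, 0])   # 1s mark positions a token may attend to, matching the docstring diagram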
src/llamafactory/data/converter.py (new file, 0 → 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .mm_plugin import AudioInput, ImageInput, VideoInput
    from .parser import DatasetAttr

    MediaType = Union[ImageInput, VideoInput, AudioInput]


logger = logging.get_logger(__name__)


@dataclass
class DatasetConverter:
    dataset_attr: "DatasetAttr"
    data_args: "DataArguments"

    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
        r"""Optionally concatenate media path to media dir when loading from local disk."""
        if medias is None:
            return None
        elif not isinstance(medias, list):
            medias = [medias]
        elif len(medias) == 0:
            return None
        else:
            medias = medias[:]

        if self.dataset_attr.load_from in ["script", "file"]:
            if isinstance(medias[0], str):
                for i in range(len(medias)):
                    media_path = os.path.join(self.data_args.media_dir, medias[i])
                    if os.path.isfile(media_path):
                        medias[i] = media_path
                    else:
                        logger.warning_rank0_once(
                            f"Media {medias[i]} does not exist in `media_dir`. Use original path."
                        )
            elif isinstance(medias[0], list):
                # for processed video frames
                # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
                for i in range(len(medias)):
                    for j in range(len(medias[i])):
                        media_path = os.path.join(self.data_args.media_dir, medias[i][j])
                        if os.path.isfile(media_path):
                            medias[i][j] = media_path
                        else:
                            logger.warning_rank0_once(
                                f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
                            )

        return medias

    @abstractmethod
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        r"""Convert a single example in the dataset to the standard format."""
        ...


@dataclass
class AlpacaDatasetConverter(DatasetConverter):
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        prompt = []
        if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
            for old_prompt, old_response in example[self.dataset_attr.history]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        query = []
        if self.dataset_attr.prompt and example[self.dataset_attr.prompt]:
            query.append(example[self.dataset_attr.prompt])

        if self.dataset_attr.query and example[self.dataset_attr.query]:
            query.append(example[self.dataset_attr.query])

        prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

        if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], str)
            and isinstance(example[self.dataset_attr.rejected], str)
        ):  # pairwise example
            response = [
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]},
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]},
            ]
        elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str):  # normal example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
        else:  # unsupervised
            response = []

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": example[self.dataset_attr.system] if self.dataset_attr.system else "",
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


@dataclass
class SharegptDatasetConverter(DatasetConverter):
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        tag_mapping = {
            self.dataset_attr.user_tag: Role.USER.value,
            self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
            self.dataset_attr.observation_tag: Role.OBSERVATION.value,
            self.dataset_attr.function_tag: Role.FUNCTION.value,
            self.dataset_attr.system_tag: Role.SYSTEM.value,
        }
        odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
        even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
        accept_tags = (odd_tags, even_tags)
        messages = example[self.dataset_attr.messages]
        if (
            self.dataset_attr.system_tag
            and len(messages) != 0
            and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
        ):
            system = messages[0][self.dataset_attr.content_tag]
            messages = messages[1:]
        else:
            system = example[self.dataset_attr.system] if self.dataset_attr.system else ""

        aligned_messages = []
        broken_data = False
        for turn_idx, message in enumerate(messages):
            if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
                logger.warning_rank0(f"Invalid role tag in {messages}.")
                broken_data = True
                break

            aligned_messages.append(
                {
                    "role": tag_mapping[message[self.dataset_attr.role_tag]],
                    "content": message[self.dataset_attr.content_tag],
                }
            )

        if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
            self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
        ):
            logger.warning_rank0(f"Invalid message count in {messages}.")
            broken_data = True

        if broken_data:
            logger.warning_rank0("Skipping this abnormal example.")
            prompt, response = [], []
        elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], dict)
            and isinstance(example[self.dataset_attr.rejected], dict)
        ):  # pairwise example
            chosen = example[self.dataset_attr.chosen]
            rejected = example[self.dataset_attr.rejected]
            if (
                chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
                or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
            ):
                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
                broken_data = True

            prompt = aligned_messages
            response = [
                {
                    "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
                    "content": chosen[self.dataset_attr.content_tag],
                },
                {
                    "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
                    "content": rejected[self.dataset_attr.content_tag],
                },
            ]
        else:  # normal example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": system,
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


DATASET_CONVERTERS = {
    "alpaca": AlpacaDatasetConverter,
    "sharegpt": SharegptDatasetConverter,
}


def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
    r"""Register a new dataset converter."""
    if name in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} already exists.")

    DATASET_CONVERTERS[name] = dataset_converter


def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
    r"""Get a dataset converter."""
    if name not in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} not found.")

    return DATASET_CONVERTERS[name](dataset_attr, data_args)


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""Align the dataset to a specific format.

    Aligned dataset:
    _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
    _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
    _system: "..."
    _tools: "..."
    _images: []
    _videos: []
    _audios: []
    """
    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
    return dataset.map(
        dataset_converter,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
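Editorial sketch, not part of this commit: registering a converter for a hypothetical {"question": ..., "answer": ...} schema via register_dataset_converter above. The "qa" name and the field names are assumptions for illustration only.

from dataclasses import dataclass
from typing import Any

from llamafactory.data.converter import DatasetConverter, register_dataset_converter
from llamafactory.data.data_utils import Role


@dataclass
class QADatasetConverter(DatasetConverter):
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        # Map the custom schema onto the aligned format documented in align_dataset().
        return {
            "_prompt": [{"role": Role.USER.value, "content": example["question"]}],
            "_response": [{"role": Role.ASSISTANT.value, "content": example["answer"]}],
            "_system": "",
            "_tools": "",
            "_images": None,
            "_videos": None,
            "_audios": None,
        }


register_dataset_converter("qa", QADatasetConverter)
# align_dataset() would then pick it up when dataset_attr.formatting == "qa".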
src/llamafactory/data/data_utils.py (new file, 0 → 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union

import fsspec
from datasets import DatasetDict, concatenate_datasets, interleave_datasets

from ..extras import logging


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset

    from ..hparams import DataArguments


logger = logging.get_logger(__name__)


SLOTS = list[Union[str, set[str], dict[str, str]]]


@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"


class DatasetModule(TypedDict):
    train_dataset: Optional[Union["Dataset", "IterableDataset"]]
    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]


def merge_dataset(
    all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
    r"""Merge multiple datasets to a unified dataset."""
    if len(all_datasets) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
            logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")

        return concatenate_datasets(all_datasets)
    elif data_args.mix_strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=seed,
            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
        )
    else:
        raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")


def split_dataset(
    dataset: Optional[Union["Dataset", "IterableDataset"]],
    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
    data_args: "DataArguments",
    seed: int,
) -> "DatasetDict":
    r"""Split the dataset and returns a dataset dict containing train set and validation set.

    Support both map dataset and iterable dataset.
    """
    if eval_dataset is not None and data_args.val_size > 1e-6:
        raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")

    dataset_dict = {}
    if dataset is not None:
        if data_args.streaming:
            dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

        if data_args.val_size > 1e-6:
            if data_args.streaming:
                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
            else:
                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
                dataset = dataset.train_test_split(test_size=val_size, seed=seed)
                dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
        else:
            dataset_dict["train"] = dataset

    if eval_dataset is not None:
        if isinstance(eval_dataset, dict):
            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
        else:
            if data_args.streaming:
                eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

            dataset_dict["validation"] = eval_dataset

    return DatasetDict(dataset_dict)


def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
    r"""Convert dataset or dataset dict to dataset module."""
    dataset_module: DatasetModule = {}
    if isinstance(dataset, DatasetDict):  # dataset dict
        if "train" in dataset:
            dataset_module["train_dataset"] = dataset["train"]

        if "validation" in dataset:
            dataset_module["eval_dataset"] = dataset["validation"]
        else:
            eval_dataset = {}
            for key in dataset.keys():
                if key.startswith("validation_"):
                    eval_dataset[key[len("validation_"):]] = dataset[key]

            if len(eval_dataset):
                dataset_module["eval_dataset"] = eval_dataset

    else:  # single dataset
        dataset_module["train_dataset"] = dataset

    return dataset_module


def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
    r"""Set up a filesystem object based on the path protocol."""
    storage_options = {"anon": anon} if anon else {}
    if path.startswith("s3://"):
        fs = fsspec.filesystem("s3", **storage_options)
    elif path.startswith(("gs://", "gcs://")):
        fs = fsspec.filesystem("gcs", **storage_options)
    else:
        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")

    if not fs.exists(path):
        raise ValueError(f"Path does not exist: {path}.")

    return fs


def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
    r"""Helper function to read JSON/JSONL files using fsspec."""
    with fs.open(path, "r") as f:
        if path.endswith(".jsonl"):
            return [json.loads(line) for line in f if line.strip()]
        else:
            return json.load(f)


def read_cloud_json(cloud_path: str) -> list[Any]:
    r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).

    Args:
        cloud_path: str
            Cloud path in the format:
            - 's3://bucket-name/file.json' for AWS S3
            - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
    """
    try:
        fs = setup_fs(cloud_path, anon=True)  # try with anonymous access first
    except Exception:
        fs = setup_fs(cloud_path)  # try again with credentials

    # filter out non-JSON files
    files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
    files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
    if not files:
        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")

    return sum([_read_json_with_fs(fs, file) for file in files], [])
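Editorial illustration, not part of this commit: how get_dataset_module above exposes named validation splits produced by split_dataset when eval_dataset is a dict. The split names are made up for the example.

from datasets import Dataset, DatasetDict

from llamafactory.data.data_utils import get_dataset_module

dd = DatasetDict({
    "train": Dataset.from_list([{"text": "a"}]),
    "validation_alpaca": Dataset.from_list([{"text": "b"}]),
    "validation_sharegpt": Dataset.from_list([{"text": "c"}]),
})
module = get_dataset_module(dd)
print(sorted(module["eval_dataset"].keys()))  # ['alpaca', 'sharegpt'] -- the "validation_" prefix is stripped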
src/llamafactory/data/formatter.py (new file, 0 → 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional, Union

from typing_extensions import override

from .data_utils import SLOTS
from .tool_utils import FunctionCall, get_tool_utils


@dataclass
class Formatter(ABC):
    slots: SLOTS = field(default_factory=list)
    tool_format: Optional[str] = None

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS:
        r"""Forms a list of slots according to the inputs to encode."""
        ...

    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
        r"""Extract a list of tuples from the response message if using tools.

        Each tuple consists of function name and function arguments.
        """
        raise NotImplementedError


@dataclass
class EmptyFormatter(Formatter):
    def __post_init__(self):
        has_placeholder = False
        for slot in filter(lambda s: isinstance(s, str), self.slots):
            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
                has_placeholder = True

        if has_placeholder:
            raise ValueError("Empty formatter should not contain any placeholder.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        return self.slots


@dataclass
class StringFormatter(Formatter):
    def __post_init__(self):
        has_placeholder = False
        for slot in filter(lambda s: isinstance(s, str), self.slots):
            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
                has_placeholder = True

        if not has_placeholder:
            raise ValueError("A placeholder is required in the string formatter.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                for name, value in kwargs.items():
                    if not isinstance(value, str):
                        raise RuntimeError(f"Expected a string, got {value}")

                    slot = slot.replace("{{" + name + "}}", value, 1)

                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")

        return elements


@dataclass
class FunctionFormatter(StringFormatter):
    def __post_init__(self):
        super().__post_init__()
        self.tool_utils = get_tool_utils(self.tool_format)

    @override
    def apply(self, **kwargs) -> SLOTS:
        content: str = kwargs.pop("content")
        regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
        thought = re.search(regex, content)
        if thought:
            content = content.replace(thought.group(0), "")

        functions: list[FunctionCall] = []
        try:
            tool_calls = json.loads(content)
            if not isinstance(tool_calls, list):  # parallel function call
                tool_calls = [tool_calls]

            for tool_call in tool_calls:
                functions.append(
                    FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
                )
        except json.JSONDecodeError:
            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string

        function_str = self.tool_utils.function_formatter(functions)
        if thought:
            function_str = thought.group(0) + function_str

        return super().apply(content=function_str)


@dataclass
class ToolFormatter(Formatter):
    def __post_init__(self):
        self.tool_utils = get_tool_utils(self.tool_format)

    @override
    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
        except json.JSONDecodeError:
            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

    @override
    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
        return self.tool_utils.tool_extractor(content)
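Editorial sketch, not part of this commit: how StringFormatter above fills its placeholder slots. The slot template is illustrative, not a chat template shipped by this commit.

from llamafactory.data.formatter import StringFormatter

formatter = StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"])
print(formatter.apply(content="What is GRPO?"))
# ['<|user|>\nWhat is GRPO?<|assistant|>\n']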
src/llamafactory/data/loader.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Literal, Optional, Union

import numpy as np
from datasets import Dataset, load_dataset, load_from_disk

from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import check_version, has_tokenized_data
from .converter import align_dataset
from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset
from .parser import get_dataset_list
from .processor import (
    FeedbackDatasetProcessor,
    PackedSupervisedDatasetProcessor,
    PairwiseDatasetProcessor,
    PretrainDatasetProcessor,
    SupervisedDatasetProcessor,
    UnsupervisedDatasetProcessor,
)


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

    from ..hparams import DataArguments, ModelArguments
    from .data_utils import DatasetModule
    from .parser import DatasetAttr
    from .processor import DatasetProcessor
    from .template import Template


logger = logging.get_logger(__name__)


def _load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""Load a single dataset and align it to the standard format."""
    logger.info_rank0(f"Loading dataset {dataset_attr}...")
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
        data_path = dataset_attr.dataset_name
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder
    elif dataset_attr.load_from == "script":
        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder
    elif dataset_attr.load_from == "cloud_file":
        data_path = dataset_attr.dataset_name
    elif dataset_attr.load_from == "file":
        data_files = []
        local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        if os.path.isdir(local_path):  # is directory
            for file_name in os.listdir(local_path):
                data_files.append(os.path.join(local_path, file_name))
        elif os.path.isfile(local_path):  # is file
            data_files.append(local_path)
        else:
            raise ValueError(f"File {local_path} not found.")

        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
        if data_path is None:
            raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
            raise ValueError("File types should be identical.")
    else:
        raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

    if dataset_attr.load_from == "ms_hub":
        check_version("modelscope>=1.14.0", mandatory=True)
        from modelscope import MsDataset  # type: ignore
        from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

        cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
        dataset = MsDataset.load(
            dataset_name=data_path,
            subset_name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=cache_dir,
            token=model_args.ms_hub_token,
            use_streaming=data_args.streaming,
        )
        if isinstance(dataset, MsDataset):
            dataset = dataset.to_hf_dataset()

    elif dataset_attr.load_from == "om_hub":
        check_version("openmind>=0.8.0", mandatory=True)
        from openmind import OmDataset  # type: ignore
        from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore

        cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
        dataset = OmDataset.load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=cache_dir,
            token=model_args.om_hub_token,
            streaming=data_args.streaming,
        )
    elif dataset_attr.load_from == "cloud_file":
        dataset = Dataset.from_list(read_cloud_json(data_path), split=dataset_attr.split)
    else:
        dataset = load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=model_args.cache_dir,
            token=model_args.hf_hub_token,
            num_proc=data_args.preprocessing_num_workers,
            trust_remote_code=model_args.trust_remote_code,
            streaming=data_args.streaming and dataset_attr.load_from != "file",
        )
        if data_args.streaming and dataset_attr.load_from == "file":
            dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)

    if dataset_attr.num_samples is not None and not data_args.streaming:
        target_num = dataset_attr.num_samples
        indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
        target_num -= len(indexes)
        if target_num > 0:
            expand_indexes = np.random.choice(len(dataset), target_num)
            indexes = np.concatenate((indexes, expand_indexes), axis=0)

        assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
        dataset = dataset.select(indexes)
        logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")

    if data_args.max_samples is not None:  # truncate dataset
        max_samples = min(data_args.max_samples, len(dataset))
        dataset = dataset.select(range(max_samples))

    return align_dataset(dataset, dataset_attr, data_args, training_args)
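The `num_samples` branch above first draws a permutation (so every example appears at most once) and only pads with `np.random.choice` when the requested count exceeds the dataset size. A minimal standalone sketch of that sampling logic, with hypothetical sizes and a seeded generator instead of the global numpy state used above:

import numpy as np

def sample_indexes(dataset_len: int, num_samples: int, seed: int = 0) -> np.ndarray:
    # Mirror of the over-sampling logic above: permute first, then pad with replacement.
    rng = np.random.default_rng(seed)
    indexes = rng.permutation(dataset_len)[:num_samples]
    remaining = num_samples - len(indexes)
    if remaining > 0:  # requested more samples than the dataset holds
        indexes = np.concatenate((indexes, rng.choice(dataset_len, remaining)))

    return indexes

print(len(sample_indexes(dataset_len=100, num_samples=40)))   # 40 unique rows
print(len(sample_indexes(dataset_len=100, num_samples=250)))  # 250 rows, duplicates allowed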
def _get_merged_dataset(
    dataset_names: Optional[list[str]],
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    return_dict: bool = False,
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
    r"""Return the merged datasets in the standard format."""
    if dataset_names is None:
        return None

    datasets = {}
    for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
        if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
            raise ValueError("The dataset is not applicable in the current training stage.")

        datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)

    if return_dict:
        return datasets
    else:
        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)


def _get_dataset_processor(
    data_args: "DataArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    do_generate: bool = False,
) -> "DatasetProcessor":
    r"""Return the corresponding dataset processor."""
    if stage == "pt":
        dataset_processor_class = PretrainDatasetProcessor
    elif stage == "sft" and not do_generate:
        if data_args.packing:
            if data_args.neat_packing:  # hack datasets to have int32 attention mask
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.pop("type", None),
                        try_type=kwargs.pop("try_type", None),
                        optimized_int_type=kwargs.pop("optimized_int_type", None),
                    )

                OptimizedTypedSequence.__init__ = __init__

            dataset_processor_class = PackedSupervisedDatasetProcessor
        else:
            dataset_processor_class = SupervisedDatasetProcessor
    elif stage == "rm":
        dataset_processor_class = PairwiseDatasetProcessor
    elif stage == "kto":
        dataset_processor_class = FeedbackDatasetProcessor
    else:
        dataset_processor_class = UnsupervisedDatasetProcessor

    return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
def _get_preprocessed_dataset(
    dataset: Optional[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"] = None,
    is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
    r"""Preprocesses the dataset, including format checking and tokenization."""
    if dataset is None:
        return None

    dataset_processor = _get_dataset_processor(
        data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
    )
    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Running tokenizer on dataset",
        )

    dataset = dataset.map(
        dataset_processor.preprocess_dataset,
        batched=True,
        batch_size=data_args.preprocessing_batch_size,
        remove_columns=column_names,
        **kwargs,
    )

    if training_args.should_log:
        try:
            print("eval example:" if is_eval else "training example:")
            dataset_processor.print_data_example(next(iter(dataset)))
        except StopIteration:
            if stage == "pt":
                raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
            else:
                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

    return dataset
def get_dataset(
    template: "Template",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
    r"""Get the train dataset and optionally get the evaluation dataset."""
    # Load tokenized dataset if path exists
    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
            logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
            tokenized_data = load_from_disk(data_args.tokenized_path)
            dataset_module = get_dataset_module(tokenized_data)
            if data_args.streaming:
                dataset_module["train_dataset"] = dataset_module["train_dataset"].to_iterable_dataset()

            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
            return dataset_module

        if data_args.streaming:
            raise ValueError("Turn off `streaming` when saving dataset to disk.")

    # Load and preprocess dataset
    with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
        dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
        eval_dataset = _get_merged_dataset(
            data_args.eval_dataset,
            model_args,
            data_args,
            training_args,
            stage,
            return_dict=data_args.eval_on_each_dataset,
        )

    with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
        dataset = _get_preprocessed_dataset(
            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
        )
        if isinstance(eval_dataset, dict):
            for eval_name, eval_data in eval_dataset.items():
                eval_dataset[eval_name] = _get_preprocessed_dataset(
                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
                )
        else:
            eval_dataset = _get_preprocessed_dataset(
                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
            )

        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
        if data_args.tokenized_path is not None:  # save tokenized dataset to disk
            if training_args.should_save:
                dataset_dict.save_to_disk(data_args.tokenized_path)
                logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
                logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.")

        return get_dataset_module(dataset_dict)
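A rough sketch of how `get_dataset` is typically wired up by the training workflows in this package. The argument objects (`model_args`, `data_args`, `training_args`) are assumed to be already-parsed hparams dataclasses from elsewhere in the repo, so this is an illustrative driver rather than the exact entry point:

# Hypothetical driver; model_args / data_args / training_args are assumed to be parsed elsewhere.
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.loader import get_dataset
from llamafactory.model import load_tokenizer

tokenizer_module = load_tokenizer(model_args)  # returns {"tokenizer": ..., "processor": ...}
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
dataset_module = get_dataset(
    template, model_args, data_args, training_args, stage="sft", **tokenizer_module
)
train_dataset = dataset_module["train_dataset"]  # optionally also "eval_dataset"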
src/llamafactory/data/mm_plugin.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import os
import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union

import numpy as np
import torch
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
from transformers.models.mllama.processing_mllama import (
    convert_sparse_cross_attention_mask_to_dense,
    get_cross_attention_token_mask,
)
from typing_extensions import override

from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import (
    is_librosa_available,
    is_pillow_available,
    is_pyav_available,
    is_transformers_version_greater_than,
)


if is_librosa_available():
    import librosa

if is_pillow_available():
    from PIL import Image
    from PIL.Image import Image as ImageObject

if is_pyav_available():
    import av

if is_transformers_version_greater_than("4.52.0"):
    from transformers.image_utils import make_flat_list_of_images
    from transformers.video_utils import make_batched_videos
else:
    from transformers.image_utils import make_batched_videos, make_flat_list_of_images
if TYPE_CHECKING:
    from av.stream import Stream
    from numpy.typing import NDArray
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
    from transformers.image_processing_utils import BaseImageProcessor

    class EncodedImage(TypedDict):
        path: Optional[str]
        bytes: Optional[bytes]

    ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
    VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
    AudioInput = Union[str, BinaryIO, NDArray]

    class MMProcessor(ProcessorMixin):
        patch_size: int
        image_seq_length: int
        num_additional_image_tokens: int
        vision_feature_select_strategy: Literal["default", "full"]

        def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
            pass
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
    r"""Get paligemma token type ids for computing loss.

    It is slightly different with the original token type ids where the prompt part is 0.

    Returns:
        batch_token_type_ids: shape (batch_size, seq_length)
    """
    batch_token_type_ids = []
    for imglen, seqlen in zip(imglens, seqlens):
        image_seqlen = imglen * processor.image_seq_length
        batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))

    return batch_token_type_ids


def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
    r"""Get gemma3 token type ids for computing loss.

    Returns:
        batch_token_type_ids: shape (batch_size, seq_length)
    """
    image_token_id: int = getattr(processor, "image_token_id")
    batch_token_type_ids = []
    for token_ids in batch_ids:
        token_ids = np.array(token_ids)
        token_type_ids = np.zeros_like(token_ids)
        token_type_ids[token_ids == image_token_id] = 1
        batch_token_type_ids.append(token_type_ids.tolist())

    return batch_token_type_ids
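For intuition, a toy run of the two helpers above using a minimal stand-in processor (real processors carry `image_seq_length` and `image_token_id` themselves; the values here are made up):

from types import SimpleNamespace

toy_processor = SimpleNamespace(image_seq_length=4, image_token_id=99)

# PaliGemma-style: the leading image block is marked 0, the rest of the sequence 1.
print(_get_paligemma_token_type_ids(imglens=[1], seqlens=[10], processor=toy_processor))
# [[0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]

# Gemma3-style: positions holding the image token id are 1, everything else 0.
print(_get_gemma3_token_type_ids(batch_ids=[[5, 99, 99, 7]], processor=toy_processor))
# [[0, 1, 1, 0]]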
def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
    r"""Make nested list of images."""
    batch_images = []
    for imglen in imglens:
        batch_images.append(images[:imglen])
        images = images[imglen:]

    return batch_images


def _check_video_is_nested_images(video: "VideoInput") -> bool:
    r"""Check if the video is nested images."""
    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
@dataclass
class MMPluginMixin:
    image_token: Optional[str]
    video_token: Optional[str]
    audio_token: Optional[str]
    expand_mm_tokens: bool = True

    def _validate_input(
        self,
        processor: Optional["MMProcessor"],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
    ) -> None:
        r"""Validate if this model accepts the input modalities."""
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseImageProcessor = getattr(
            processor, "video_processor", getattr(processor, "image_processor", None)
        )
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
        if len(images) != 0 and self.image_token is None:
            raise ValueError(
                "This model does not support image input. Please check whether the correct `template` is used."
            )

        if len(videos) != 0 and self.video_token is None:
            raise ValueError(
                "This model does not support video input. Please check whether the correct `template` is used."
            )

        if len(audios) != 0 and self.audio_token is None:
            raise ValueError(
                "This model does not support audio input. Please check whether the correct `template` is used."
            )

        if self.image_token is not None and processor is None:
            raise ValueError("Processor was not found, please check and update your model file.")

        if self.image_token is not None and image_processor is None:
            raise ValueError("Image processor was not found, please check and update your model file.")

        if self.video_token is not None and video_processor is None:
            raise ValueError("Video processor was not found, please check and update your model file.")

        if self.audio_token is not None and feature_extractor is None:
            raise ValueError("Audio feature extractor was not found, please check and update your model file.")

    def _validate_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
    ):
        r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
        for message in messages:
            num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
            num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
            num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)

        if len(images) != num_image_tokens:
            raise ValueError(
                f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
            )

        if len(videos) != num_video_tokens:
            raise ValueError(
                f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
            )

        if len(audios) != num_audio_tokens:
            raise ValueError(
                f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
            )

    def _preprocess_image(
        self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
    ) -> "ImageObject":
        r"""Pre-process a single image."""
        if (image.width * image.height) > image_max_pixels:
            resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height))

        if (image.width * image.height) < image_min_pixels:
            resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height))

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image

    def _get_video_sample_indices(
        self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
    ) -> list[int]:
        r"""Compute video sample indices according to fps."""
        total_frames = video_stream.frames
        if total_frames == 0:  # infinite video
            return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)

        sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps))
        sample_frames = min(total_frames, video_maxlen, sample_frames)
        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
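As a sanity check on the sampling arithmetic above: a 10-second stream with 300 frames at `video_fps=2.0` and `video_maxlen=128` yields 20 evenly spaced indices, while a very long video is capped at `video_maxlen`. A standalone sketch with the stream metadata passed in directly (the real method reads `frames`, `duration` and `time_base` from a PyAV stream object):

import math
import numpy as np

def sample_indices(total_frames: int, duration_sec: float, video_fps: float, video_maxlen: int) -> np.ndarray:
    # Same arithmetic as _get_video_sample_indices, without the PyAV stream object.
    sample_frames = max(1, math.floor(duration_sec * video_fps))
    sample_frames = min(total_frames, video_maxlen, sample_frames)
    return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

print(len(sample_indices(total_frames=300, duration_sec=10.0, video_fps=2.0, video_maxlen=128)))      # 20
print(len(sample_indices(total_frames=100000, duration_sec=3600.0, video_fps=2.0, video_maxlen=128)))  # 128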
    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
        r"""Regularize images to avoid error. Including reading and pre-processing."""
        results = []
        for image in images:
            if isinstance(image, (str, BinaryIO)):
                image = Image.open(image)
            elif isinstance(image, bytes):
                image = Image.open(BytesIO(image))
            elif isinstance(image, dict):
                if image["bytes"] is not None:
                    image = Image.open(BytesIO(image["bytes"]))
                else:
                    image = Image.open(image["path"])

            if not isinstance(image, ImageObject):
                raise ValueError(f"Expect input is a list of images, but got {type(image)}.")

            results.append(self._preprocess_image(image, **kwargs))

        return {"images": results}

    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
        r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
        results = []
        for video in videos:
            frames: list[ImageObject] = []
            if _check_video_is_nested_images(video):
                for frame in video:
                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
                        raise ValueError("Invalid image found in video frames.")

                frames = video
            else:
                container = av.open(video, "r")
                video_stream = next(stream for stream in container.streams if stream.type == "video")
                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
                container.seek(0)
                for frame_idx, frame in enumerate(container.decode(video_stream)):
                    if frame_idx in sample_indices:
                        frames.append(frame.to_image())

            frames = self._regularize_images(frames, **kwargs)["images"]
            results.append(frames)

        return {"videos": results}

    def _regularize_audios(
        self, audios: list["AudioInput"], sampling_rate: float, **kwargs
    ) -> dict[str, Union[list["NDArray"], list[float]]]:
        r"""Regularizes audios to avoid error. Including reading and resampling."""
        results, sampling_rates = [], []
        for audio in audios:
            if not isinstance(audio, np.ndarray):
                audio, sampling_rate = librosa.load(audio, sr=sampling_rate)

            results.append(audio)
            sampling_rates.append(sampling_rate)

        return {"audios": results, "sampling_rates": sampling_rates}

    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
        imglens: Optional[list[int]] = None,
    ) -> dict[str, "torch.Tensor"]:
        r"""Process visual inputs.

        Returns: (llava and paligemma)
            pixel_values: tensor with shape (B, C, H, W)

        Returns: (qwen2-vl)
            pixel_values: tensor with shape (num_patches, patch_dim)
            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
                            where num_patches == torch.prod(image_grid_thw)

        Returns: (mllama)
            pixel_values: tensor with shape
                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                          For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
        """
        mm_inputs = {}
        if len(images) != 0:
            image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            if imglens is not None:  # if imglens are provided, make batched images
                images = _make_batched_images(images, imglens)

            image_processor_kwargs = {}
            if getattr(processor, "image_do_pan_and_scan", False):  # gemma3 image processor
                image_processor_kwargs.update(
                    {
                        "do_pan_and_scan": True,
                        "pan_and_scan_min_crop_size": 256,
                        "pan_and_scan_max_num_crops": 4,
                        "pan_and_scan_min_ratio_to_activate": 1.2,
                    }
                )

            mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))

        if len(videos) != 0:
            video_processor: BaseImageProcessor = getattr(
                processor, "video_processor", getattr(processor, "image_processor", None)
            )
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )["videos"]
            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
            else:  # for llava_next_video
                mm_inputs.update(video_processor(videos, return_tensors="pt"))

        if len(audios) != 0:
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
            )["audios"]
            mm_inputs.update(
                feature_extractor(
                    audios,
                    sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                )
            )
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None)  # prevent conflicts

        return mm_inputs
@dataclass
class BasePlugin(MMPluginMixin):
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Pre-process input messages before tokenization for VLMs."""
        self._validate_input(processor, images, videos, audios)
        return messages

    def process_token_ids(
        self,
        input_ids: list[int],
        labels: Optional[list[int]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        r"""Pre-process token ids after tokenization for VLMs."""
        self._validate_input(processor, images, videos, audios)
        return input_ids, labels

    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build batched multimodal inputs for VLMs.

        Arguments:
            images: a list of image inputs, shape (num_images,)
            videos: a list of video inputs, shape (num_videos,)
            audios: a list of audio inputs, shape (num_audios,)
            imglens: number of images in each sample, shape (batch_size,)
            vidlens: number of videos in each sample, shape (batch_size,)
            audlens: number of audios in each sample, shape (batch_size,)
            batch_ids: token ids of input samples, shape (batch_size, seq_len)
            processor: a processor for pre-processing images and videos
        """
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)
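As a rough illustration of the plugin contract: `process_messages` rewrites media placeholders in the text before tokenization, and `get_mm_inputs` later turns the raw media into batched tensors. The base class is effectively a passthrough for text-only templates, while the subclasses below override these hooks per model family. A minimal sketch (a text-only plugin needs no processor):

# Minimal sketch: the base plugin leaves messages untouched and returns empty mm inputs.
plugin = BasePlugin(image_token=None, video_token=None, audio_token=None)
messages = [{"role": "user", "content": "Describe the scene."}]
print(plugin.process_messages(messages, images=[], videos=[], audios=[], processor=None))
# [{'role': 'user', 'content': 'Describe the scene.'}]
print(plugin.get_mm_inputs([], [], [], imglens=[0], vidlens=[0], audlens=[0], batch_ids=[[1, 2, 3]], processor=None))
# {}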
@dataclass
class Gemma3Plugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        boi_token: str = getattr(processor, "boi_token")
        full_image_sequence: str = getattr(processor, "full_image_sequence")
        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
        do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
        if do_pan_and_scan:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if do_pan_and_scan:
                    image_placeholder_str = (
                        "Here is the original image {{image}} and here are some crops to help you see better "
                        + " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
                    )
                else:
                    image_placeholder_str = "{{image}}"

                content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", image_str)

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("num_crops", None)
        mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
        return mm_inputs
class Gemma3nPlugin(Gemma3Plugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)
        boi_token: str = getattr(processor, "boi_token")
        boa_token: str = getattr(processor, "boa_token")
        full_image_sequence: str = getattr(processor, "full_image_sequence")
        full_audio_sequence: str = getattr(processor, "full_audio_sequence")
        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
        audio_str = full_audio_sequence if self.expand_mm_tokens else boa_token

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)

            while AUDIO_PLACEHOLDER in content:
                content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)

            message["content"] = content

        return messages
@dataclass
class InternVLPlugin(BasePlugin):
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "ProcessorMixin",
        **kwargs,
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        image_processor_kwargs = {}
        if getattr(processor, "crop_to_patches", False):
            image_processor_kwargs.update(
                {
                    "crop_to_patches": True,
                    "max_patches": 12,
                    "min_patches": 1,
                }
            )

        mm_inputs = {}
        image_video_patches = []

        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )["videos"]

        if len(images) != 0:
            images = make_flat_list_of_images(images)
            image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
            image_num_patches = image_inputs.pop("num_patches")
            image_pixel_values = image_inputs.pop("pixel_values")
            image_num_patches_indices = np.cumsum(image_num_patches)

        if len(videos) != 0:
            videos = make_batched_videos(videos)
            num_frames_per_video = [len(video) for video in videos]
            patch_indices = np.cumsum(num_frames_per_video)
            image_processor_kwargs["crop_to_patches"] = False
            video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
            video_num_patches = video_inputs.pop("num_patches")
            video_pixel_values = video_inputs.pop("pixel_values")
            video_num_patches_indices = np.cumsum(video_num_patches)

        # NOT SUPPORT IMAGE VIDEO INTERLEAVED
        if len(images) != 0 and image_pixel_values is not None:
            for i in range(len(images)):
                start_index = image_num_patches_indices[i - 1] if i > 0 else 0
                end_index = image_num_patches_indices[i]
                image_video_patches.append(image_pixel_values[start_index:end_index])

        if len(videos) != 0 and video_pixel_values is not None:
            patch_indices_with_prefix = [0] + list(patch_indices)
            for i in range(len(videos)):
                current_patch_index = patch_indices_with_prefix[i]
                end_patch_index = patch_indices_with_prefix[i + 1]
                start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
                end_index = video_num_patches_indices[end_patch_index - 1]
                image_video_patches.append(video_pixel_values[start_index:end_index])

        if len(images) != 0 or len(videos) != 0:
            mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)

        if len(images) != 0:
            mm_inputs.update({"image_num_patches": image_num_patches})

        if len(videos) != 0:
            mm_inputs.update({"video_patch_indices": patch_indices})
            mm_inputs.update({"video_num_patches": video_num_patches})

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        image_pixel_patch_list = mm_inputs.get("image_num_patches")  # patches of images
        video_num_patches = mm_inputs.get("video_num_patches")  # all patches for frames of videos
        video_patch_indices = mm_inputs.get("video_patch_indices")  # num frames of per video

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
                    1,
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
                end_patch_index = video_patch_indices[num_video_tokens]
                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
                video_replaced_prompt = "\n".join(
                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
                    for i in range(len(num_patches))
                )
                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
                num_video_tokens += 1

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("image_num_patches", None)
        mm_inputs.pop("video_patch_indices", None)
        mm_inputs.pop("video_num_patches", None)
        return mm_inputs
class KimiVLPlugin(BasePlugin):
    @override
    def process_messages(self, messages, images, videos, audios, processor):
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_hws = mm_inputs.get("image_grid_hws", [])
        else:
            image_grid_hws = [None] * len(images)

        num_image_tokens = 0
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        merge_length = math.prod(image_processor.merge_kernel_size)
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_seqlen = (
                    image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                )
                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
                    1,
                )
                num_image_tokens += 1

            message["content"] = content

        return messages
@dataclass
class Llama4Plugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
                num_patches_per_chunk = int(
                    (image_height // processor.patch_size)
                    * (image_width // processor.patch_size)
                    // processor.downsample_ratio
                )
                aspect_ratios = mm_inputs.pop("aspect_ratios")

        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            if self.expand_mm_tokens:
                placeholder_count = content.count(IMAGE_PLACEHOLDER)
                prompt_splits = content.split(IMAGE_PLACEHOLDER)
                new_content = []
                for local_image_index, split_part in enumerate(prompt_splits):
                    new_content.append(split_part)
                    if local_image_index < placeholder_count:
                        tokens_for_this_image = processor._prompt_split_image(
                            aspect_ratios[num_image_tokens], num_patches_per_chunk
                        )
                        num_image_tokens += 1
                        new_content.append(tokens_for_this_image)

                content = "".join(new_content)
            else:
                content = content.replace(IMAGE_PLACEHOLDER, self.image_token)

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("aspect_ratios", None)
        return mm_inputs
@dataclass
class LlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
                image_seqlen = (height // processor.patch_size) * (
                    width // processor.patch_size
                ) + processor.num_additional_image_tokens
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1
        else:
            image_seqlen = 1

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        return messages
@dataclass
class LlavaNextPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)

        return messages
@dataclass
class LlavaNextVideoPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        if self.expand_mm_tokens:
            if "pixel_values_videos" in mm_inputs:
                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(one_video[0])
                num_frames = one_video.shape[0]  # frame dim is always after batch dim
                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
        else:
            video_seqlen = 1

        for message in messages:
            content = message["content"]
            while VIDEO_PLACEHOLDER in content:
                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

            message["content"] = content.replace("{{video}}", self.video_token)

        return messages
@dataclass
class MiniCPMVPlugin(BasePlugin):
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
        **kwargs,
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            if "valid_image_nums_ls" in kwargs:
                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
                new_images = []
                idx = 0
                for valid_image_nums in valid_image_nums_ls:
                    new_images.append(images[idx : idx + valid_image_nums])
                    idx += valid_image_nums

                images = new_images

            image_inputs = image_processor(
                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
            )
            mm_inputs.update(image_inputs)

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )["videos"]
            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
            mm_inputs.update(video_inputs)

        if len(audios) != 0:
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
            )["audios"]
            if "valid_audio_nums_ls" in kwargs:
                valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
                audios_ls = []
                idx = 0
                for valid_audio_nums in valid_audio_nums_ls:
                    audios_ls.append(audios[idx : idx + valid_audio_nums])
                    idx += valid_audio_nums
            else:
                audios_ls = [audios]

            audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
                audios_ls,
                chunk_input=True,
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
            )
            audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
            mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
            if kwargs.get("ret_phs", False):
                mm_inputs.update({"audio_phs": audio_phs})

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        mm_inputs, audio_inputs = {}, {}
        if len(images) != 0 and len(videos) != 0:
            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

        if len(videos) != 0:
            max_slice_nums = 2
            use_image_id = False
            mm_inputs = self._get_mm_inputs([], videos, [], processor)
        else:
            max_slice_nums = image_processor.max_slice_nums
            use_image_id = image_processor.use_image_id

        for i, message in enumerate(messages):
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                num_video_tokens += 1

            while AUDIO_PLACEHOLDER in content:
                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                num_audio_tokens += 1

            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
                "{{audio}}", "(<audio>./</audio>)"
            )

        if len(images):
            mm_inputs = self._get_mm_inputs(images, [], [], processor)

        if len(audios):
            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)

        if self.expand_mm_tokens and mm_inputs:
            pattern = "(<image>./</image>)"
            image_sizes = mm_inputs["image_sizes"]
            idx = 0
            for index, message in enumerate(messages):
                text = message["content"]
                image_tags = re.findall(pattern, text)
                text_chunks = text.split(pattern)
                final_text = ""
                for i in range(len(image_tags)):
                    final_text = (
                        final_text
                        + text_chunks[i]
                        + image_processor.get_slice_image_placeholder(
                            image_sizes[0][idx], idx, max_slice_nums, use_image_id
                        )
                    )
                    idx += 1

                final_text += text_chunks[-1]
                messages[index]["content"] = final_text

        if self.expand_mm_tokens and audio_inputs:
            pattern = "(<audio>./</audio>)"
            idx = 0
            for index, message in enumerate(messages):
                text = message["content"]
                audio_tags = re.findall(pattern, text)
                text_chunks = text.split(pattern)
                final_text = ""
                for i in range(len(audio_tags)):
                    audio_placeholder = audio_inputs["audio_phs"][0][idx]
                    final_text = final_text + text_chunks[i] + audio_placeholder
                    idx += 1

                final_text += text_chunks[-1]
                messages[index]["content"] = final_text

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        # image bound
        image_bounds_list = []
        valid_image_nums_ls = []
        for i, input_ids in enumerate(batch_ids):
            input_ids_ = torch.tensor(input_ids)
            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                input_ids_ == processor.tokenizer.slice_start_id
            )
            end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (
                input_ids_ == processor.tokenizer.slice_end_id
            )
            image_start_tokens = torch.where(start_cond)[0]
            image_start_tokens += 1
            image_end_tokens = torch.where(end_cond)[0]
            valid_image_nums_ls.append(imglens[i])
            image_bounds = torch.hstack(
                [
                    image_start_tokens.unsqueeze(-1),
                    image_end_tokens.unsqueeze(-1),
                ]
            )
            image_bounds_list.append(image_bounds)

        mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
        if "tgt_sizes" not in mm_inputs:
            dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
            mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})

        mm_inputs.update({"image_bound": image_bounds_list})

        if len(audios) > 0:
            # audio bound
            audio_bounds_ls = []
            spk_bounds_ls = []
            valid_audio_nums_ls = []
            for input_ids, audiolen in zip(batch_ids, audlens):
                input_ids_ = torch.tensor(input_ids)
                audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
                audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
                assert len(audio_start_idx) == len(audio_end_idx)
                audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
                audio_bounds_ls.append(audio_bounds)
                valid_audio_nums_ls.append(audiolen)

                spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
                spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
                assert len(spk_start_idx) == len(spk_end_idx)
                spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
                spk_bounds_ls.append(spk_bounds)

            audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
            mm_inputs.update(audio_inputs)
            mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})

        return mm_inputs
@dataclass
class MllamaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
        if mm_inputs:
            num_tiles = mm_inputs.pop("num_tiles")
            image_token_id: int = getattr(processor, "image_token_id")
            max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
            cross_attention_token_mask = [
                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
            ]
            mm_inputs["cross_attention_mask"] = torch.from_numpy(
                convert_sparse_cross_attention_mask_to_dense(
                    cross_attention_token_mask,
                    num_tiles=num_tiles,
                    max_num_tiles=max_image_tiles,
                    length=max(len(input_ids) for input_ids in batch_ids),
                )
            )  # shape: (batch_size, length, max_num_images, max_num_tiles)

        return mm_inputs
@dataclass
class PaliGemmaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "", 1)
                num_image_tokens += 1

            message["content"] = content

        return messages

    @override
    def process_token_ids(
        self,
        input_ids: list[int],
        labels: Optional[list[int]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        self._validate_input(processor, images, videos, audios)
        num_images = len(images)
        image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0  # skip mm token
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        input_ids = [image_token_id] * num_images * image_seqlen + input_ids
        if labels is not None:
            labels = [IGNORE_INDEX] * num_images * image_seqlen + labels

        return input_ids, labels

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        seqlens = [len(input_ids) for input_ids in batch_ids]
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
        return mm_inputs
@dataclass
class PixtralPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                # BC for transformers < 4.49.0
                if isinstance(mm_inputs["image_sizes"], list):
                    image_sizes = iter(mm_inputs["image_sizes"][0])
                else:
                    image_sizes = iter(mm_inputs["image_sizes"].tolist())

                image_break_token: str = getattr(processor, "image_break_token")
                image_end_token: str = getattr(processor, "image_end_token")

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1)
                    height, width = next(image_sizes)
                    num_height_tokens = height // patch_size
                    num_width_tokens = width // patch_size
                    replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                    replace_tokens[-1] = image_end_token
                    replace_str = "".join(replace_tokens)
                else:
                    replace_str = self.image_token

                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        # ref to this commit https://github.com/huggingface/transformers/pull/35122
        # after transformers 4.49.0, the `image_sizes` is mandatory as an input parameter for Pixtral VisionEncoder forwarding.
        # it can be passed into `LlavaConditionalGeneration` as a parameter.
        if not is_transformers_version_greater_than("4.49.0"):
            mm_inputs.pop("image_sizes", None)

        return mm_inputs
@
dataclass
class
Qwen2AudioPlugin
(
BasePlugin
):
@
override
def
process_messages
(
self
,
messages
:
list
[
dict
[
str
,
str
]],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
Optional
[
"MMProcessor"
],
)
->
list
[
dict
[
str
,
str
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
self
.
_validate_messages
(
messages
,
images
,
videos
,
audios
)
bos_token
:
str
=
getattr
(
processor
,
"audio_bos_token"
)
eos_token
:
str
=
getattr
(
processor
,
"audio_eos_token"
)
messages
=
deepcopy
(
messages
)
if
self
.
expand_mm_tokens
:
mm_inputs
=
self
.
_get_mm_inputs
([],
[],
audios
,
processor
)
if
"feature_attention_mask"
in
mm_inputs
:
audio_lengths
=
mm_inputs
[
"feature_attention_mask"
].
sum
(
-
1
).
tolist
()
for
message
in
messages
:
content
=
message
[
"content"
]
while
AUDIO_PLACEHOLDER
in
content
:
if
self
.
expand_mm_tokens
:
audio_length
=
audio_lengths
.
pop
(
0
)
input_length
=
(
audio_length
-
1
)
//
2
+
1
audio_seqlen
=
(
input_length
-
2
)
//
2
+
1
else
:
audio_seqlen
=
1
content
=
content
.
replace
(
AUDIO_PLACEHOLDER
,
f
"
{
bos_token
}{
self
.
audio_token
*
audio_seqlen
}{
eos_token
}
"
,
1
)
message
[
"content"
]
=
content
return
messages
@
override
def
get_mm_inputs
(
self
,
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
imglens
:
list
[
int
],
vidlens
:
list
[
int
],
audlens
:
list
[
int
],
batch_ids
:
list
[
list
[
int
]],
processor
:
Optional
[
"MMProcessor"
],
)
->
dict
[
str
,
Union
[
list
[
int
],
"torch.Tensor"
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
return
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
@
dataclass
class
Qwen2VLPlugin
(
BasePlugin
):
@
override
def
_preprocess_image
(
self
,
image
:
"ImageObject"
,
**
kwargs
)
->
"ImageObject"
:
image
=
super
().
_preprocess_image
(
image
,
**
kwargs
)
if
min
(
image
.
width
,
image
.
height
)
<
28
:
width
,
height
=
max
(
image
.
width
,
28
),
max
(
image
.
height
,
28
)
image
=
image
.
resize
((
width
,
height
))
if
image
.
width
/
image
.
height
>
200
:
width
,
height
=
image
.
height
*
180
,
image
.
height
image
=
image
.
resize
((
width
,
height
))
if
image
.
height
/
image
.
width
>
200
:
width
,
height
=
image
.
width
,
image
.
width
*
180
image
=
image
.
resize
((
width
,
height
))
return
image
@
override
def
_regularize_videos
(
self
,
videos
:
list
[
"VideoInput"
],
**
kwargs
)
->
dict
[
str
,
Union
[
list
[
list
[
"ImageObject"
]],
list
[
float
]]]:
results
,
fps_per_video
=
[],
[]
for
video
in
videos
:
frames
:
list
[
ImageObject
]
=
[]
if
_check_video_is_nested_images
(
video
):
for
frame
in
video
:
if
not
is_valid_image
(
frame
)
and
not
isinstance
(
frame
,
dict
)
and
not
os
.
path
.
exists
(
frame
):
raise
ValueError
(
"Invalid image found in video frames."
)
frames
=
video
fps_per_video
.
append
(
kwargs
.
get
(
"video_fps"
,
2.0
))
else
:
container
=
av
.
open
(
video
,
"r"
)
video_stream
=
next
(
stream
for
stream
in
container
.
streams
if
stream
.
type
==
"video"
)
sample_indices
=
self
.
_get_video_sample_indices
(
video_stream
,
**
kwargs
)
container
.
seek
(
0
)
for
frame_idx
,
frame
in
enumerate
(
container
.
decode
(
video_stream
)):
if
frame_idx
in
sample_indices
:
frames
.
append
(
frame
.
to_image
())
if
video_stream
.
duration
is
None
:
fps_per_video
.
append
(
kwargs
.
get
(
"video_fps"
,
2.0
))
else
:
fps_per_video
.
append
(
len
(
sample_indices
)
/
float
(
video_stream
.
duration
*
video_stream
.
time_base
))
if
len
(
frames
)
%
2
!=
0
:
frames
.
append
(
frames
[
-
1
])
frames
=
self
.
_regularize_images
(
frames
,
**
kwargs
)[
"images"
]
results
.
append
(
frames
)
return
{
"videos"
:
results
,
"fps_per_video"
:
fps_per_video
}
@
override
def
_get_mm_inputs
(
self
,
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
"MMProcessor"
,
)
->
dict
[
str
,
"torch.Tensor"
]:
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
,
None
)
mm_inputs
=
{}
if
len
(
images
)
!=
0
:
images
=
self
.
_regularize_images
(
images
,
image_max_pixels
=
getattr
(
processor
,
"image_max_pixels"
,
768
*
768
),
image_min_pixels
=
getattr
(
processor
,
"image_min_pixels"
,
32
*
32
),
)[
"images"
]
mm_inputs
.
update
(
image_processor
(
images
,
return_tensors
=
"pt"
))
if
len
(
videos
)
!=
0
:
video_data
=
self
.
_regularize_videos
(
videos
,
image_max_pixels
=
getattr
(
processor
,
"video_max_pixels"
,
256
*
256
),
image_min_pixels
=
getattr
(
processor
,
"video_min_pixels"
,
16
*
16
),
video_fps
=
getattr
(
processor
,
"video_fps"
,
2.0
),
video_maxlen
=
getattr
(
processor
,
"video_maxlen"
,
128
),
)
mm_inputs
.
update
(
image_processor
(
images
=
None
,
videos
=
video_data
[
"videos"
],
return_tensors
=
"pt"
))
temporal_patch_size
:
int
=
getattr
(
image_processor
,
"temporal_patch_size"
,
2
)
if
"second_per_grid_ts"
in
processor
.
model_input_names
:
mm_inputs
[
"second_per_grid_ts"
]
=
[
temporal_patch_size
/
fps
for
fps
in
video_data
[
"fps_per_video"
]]
return
mm_inputs
@
override
def
process_messages
(
self
,
messages
:
list
[
dict
[
str
,
str
]],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
Optional
[
"MMProcessor"
],
)
->
list
[
dict
[
str
,
str
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
self
.
_validate_messages
(
messages
,
images
,
videos
,
audios
)
num_image_tokens
,
num_video_tokens
=
0
,
0
messages
=
deepcopy
(
messages
)
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
)
merge_length
:
int
=
getattr
(
image_processor
,
"merge_size"
)
**
2
if
self
.
expand_mm_tokens
:
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
image_grid_thw
=
mm_inputs
.
get
(
"image_grid_thw"
,
[])
video_grid_thw
=
mm_inputs
.
get
(
"video_grid_thw"
,
[])
else
:
image_grid_thw
=
[
None
]
*
len
(
images
)
video_grid_thw
=
[
None
]
*
len
(
videos
)
for
message
in
messages
:
content
=
message
[
"content"
]
while
IMAGE_PLACEHOLDER
in
content
:
image_seqlen
=
image_grid_thw
[
num_image_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
f
"<|vision_start|>
{
self
.
image_token
*
image_seqlen
}
<|vision_end|>"
,
1
)
num_image_tokens
+=
1
while
VIDEO_PLACEHOLDER
in
content
:
video_seqlen
=
video_grid_thw
[
num_video_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
VIDEO_PLACEHOLDER
,
f
"<|vision_start|>
{
self
.
video_token
*
video_seqlen
}
<|vision_end|>"
,
1
)
num_video_tokens
+=
1
message
[
"content"
]
=
content
return
messages
@
dataclass
class
GLM4VPlugin
(
Qwen2VLPlugin
):
@
override
def
_get_mm_inputs
(
self
,
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
"MMProcessor"
,
)
->
dict
[
str
,
"torch.Tensor"
]:
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
,
None
)
video_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"video_processor"
,
None
)
mm_inputs
=
{}
if
len
(
images
)
!=
0
:
images
=
self
.
_regularize_images
(
images
,
image_max_pixels
=
getattr
(
processor
,
"image_max_pixels"
,
768
*
768
),
image_min_pixels
=
getattr
(
processor
,
"image_min_pixels"
,
32
*
32
),
)[
"images"
]
mm_inputs
.
update
(
image_processor
(
images
,
return_tensors
=
"pt"
))
if
len
(
videos
)
!=
0
:
video_data
=
self
.
_regularize_videos
(
videos
,
image_max_pixels
=
getattr
(
processor
,
"video_max_pixels"
,
256
*
256
),
image_min_pixels
=
getattr
(
processor
,
"video_min_pixels"
,
16
*
16
),
video_fps
=
getattr
(
processor
,
"video_fps"
,
2.0
),
video_maxlen
=
getattr
(
processor
,
"video_maxlen"
,
128
),
)
# prepare video metadata
video_metadata
=
[
{
"fps"
:
2
,
"duration"
:
len
(
video
),
"total_frames"
:
len
(
video
)}
for
video
in
video_data
[
"videos"
]
]
mm_inputs
.
update
(
video_processor
(
images
=
None
,
videos
=
video_data
[
"videos"
],
video_metadata
=
video_metadata
))
return
mm_inputs
@
override
def
process_messages
(
self
,
messages
:
list
[
dict
[
str
,
str
]],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
Optional
[
"MMProcessor"
],
)
->
list
[
dict
[
str
,
str
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
self
.
_validate_messages
(
messages
,
images
,
videos
,
audios
)
num_image_tokens
,
num_video_tokens
=
0
,
0
messages
=
deepcopy
(
messages
)
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
)
merge_length
:
int
=
getattr
(
image_processor
,
"merge_size"
)
**
2
if
self
.
expand_mm_tokens
:
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
image_grid_thw
=
mm_inputs
.
get
(
"image_grid_thw"
,
[])
video_grid_thw
=
mm_inputs
.
get
(
"video_grid_thw"
,
[])
num_frames
=
video_grid_thw
[
0
][
0
]
if
len
(
video_grid_thw
)
>
0
else
0
# hard code for now
timestamps
=
mm_inputs
.
get
(
"timestamps"
,
[])
if
hasattr
(
timestamps
,
"tolist"
):
timestamps
=
timestamps
.
tolist
()
if
not
timestamps
:
timestamps_list
=
[]
elif
isinstance
(
timestamps
[
0
],
list
):
timestamps_list
=
timestamps
[
0
]
else
:
timestamps_list
=
timestamps
unique_timestamps
=
timestamps_list
.
copy
()
selected_timestamps
=
unique_timestamps
[:
num_frames
]
while
len
(
selected_timestamps
)
<
num_frames
:
selected_timestamps
.
append
(
selected_timestamps
[
-
1
]
if
selected_timestamps
else
0
)
else
:
image_grid_thw
=
[
None
]
*
len
(
images
)
video_grid_thw
=
[
None
]
*
len
(
videos
)
num_frames
=
0
selected_timestamps
=
[
0
]
for
message
in
messages
:
content
=
message
[
"content"
]
while
IMAGE_PLACEHOLDER
in
content
:
image_seqlen
=
image_grid_thw
[
num_image_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
f
"<|begin_of_image|>
{
self
.
image_token
*
image_seqlen
}
<|end_of_image|>"
,
1
)
num_image_tokens
+=
1
while
VIDEO_PLACEHOLDER
in
content
:
video_structure
=
""
for
frame_index
in
range
(
num_frames
):
video_seqlen
=
(
video_grid_thw
[
num_video_tokens
][
1
:].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
)
timestamp_sec
=
selected_timestamps
[
frame_index
]
frame_structure
=
(
f
"<|begin_of_image|>
{
self
.
image_token
*
video_seqlen
}
<|end_of_image|>
{
timestamp_sec
}
"
)
video_structure
+=
frame_structure
content
=
content
.
replace
(
VIDEO_PLACEHOLDER
,
f
"<|begin_of_video|>
{
video_structure
}
<|end_of_video|>"
,
1
)
num_video_tokens
+=
1
message
[
"content"
]
=
content
return
messages
@
override
def
get_mm_inputs
(
self
,
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
imglens
:
list
[
int
],
vidlens
:
list
[
int
],
audlens
:
list
[
int
],
batch_ids
:
list
[
list
[
int
]],
processor
:
Optional
[
"ProcessorMixin"
],
)
->
dict
[
str
,
Union
[
list
[
int
],
"torch.Tensor"
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
mm_inputs
.
pop
(
"timestamps"
,
None
)
return
mm_inputs
class
Qwen2OmniPlugin
(
Qwen2VLPlugin
):
@
override
def
_get_mm_inputs
(
self
,
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
"MMProcessor"
,
)
->
dict
[
str
,
"torch.Tensor"
]:
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
,
None
)
feature_extractor
:
SequenceFeatureExtractor
=
getattr
(
processor
,
"feature_extractor"
,
None
)
mm_inputs
=
{}
if
len
(
images
)
!=
0
:
images
=
self
.
_regularize_images
(
images
,
image_max_pixels
=
getattr
(
processor
,
"image_max_pixels"
,
768
*
768
),
image_min_pixels
=
getattr
(
processor
,
"image_min_pixels"
,
32
*
32
),
)[
"images"
]
mm_inputs
.
update
(
image_processor
(
images
,
return_tensors
=
"pt"
))
if
len
(
videos
)
!=
0
:
video_dict
=
self
.
_regularize_videos
(
videos
,
image_max_pixels
=
getattr
(
processor
,
"video_max_pixels"
,
256
*
256
),
image_min_pixels
=
getattr
(
processor
,
"video_min_pixels"
,
16
*
16
),
video_fps
=
getattr
(
processor
,
"video_fps"
,
2.0
),
video_maxlen
=
getattr
(
processor
,
"video_maxlen"
,
128
),
)
mm_inputs
.
update
(
image_processor
(
images
=
None
,
videos
=
video_dict
[
"videos"
],
return_tensors
=
"pt"
))
temporal_patch_size
:
int
=
getattr
(
image_processor
,
"temporal_patch_size"
,
2
)
mm_inputs
[
"video_second_per_grid"
]
=
torch
.
tensor
(
[
temporal_patch_size
/
fps
for
fps
in
video_dict
[
"fps_per_video"
]]
)
if
len
(
audios
)
!=
0
:
audios
=
self
.
_regularize_audios
(
audios
,
sampling_rate
=
getattr
(
processor
,
"audio_sampling_rate"
,
16000
),
)[
"audios"
]
mm_inputs
.
update
(
feature_extractor
(
audios
,
sampling_rate
=
getattr
(
processor
,
"audio_sampling_rate"
,
16000
),
return_attention_mask
=
True
,
padding
=
"max_length"
,
return_tensors
=
"pt"
,
)
)
mm_inputs
[
"feature_attention_mask"
]
=
mm_inputs
.
pop
(
"attention_mask"
)
# prevent conflicts
return
mm_inputs
@
override
def
process_messages
(
self
,
messages
:
list
[
dict
[
str
,
str
]],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
Optional
[
"MMProcessor"
],
)
->
list
[
dict
[
str
,
str
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
self
.
_validate_messages
(
messages
,
images
,
videos
,
audios
)
num_image_tokens
,
num_video_tokens
,
num_audio_tokens
=
0
,
0
,
0
messages
=
deepcopy
(
messages
)
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
,
None
)
merge_length
=
processor
.
image_processor
.
merge_size
**
2
use_audio_in_video
=
getattr
(
processor
,
"use_audio_in_video"
,
False
)
if
self
.
expand_mm_tokens
:
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
image_grid_thw
=
mm_inputs
.
get
(
"image_grid_thw"
,
[])
video_grid_thw
=
mm_inputs
.
get
(
"video_grid_thw"
,
[])
if
"feature_attention_mask"
in
mm_inputs
:
input_lengths
=
(
mm_inputs
[
"feature_attention_mask"
].
sum
(
-
1
).
numpy
()
-
1
)
//
2
+
1
audio_lengths
=
(
input_lengths
-
2
)
//
2
+
1
else
:
mm_inputs
=
{}
image_grid_thw
=
[
None
]
*
len
(
images
)
video_grid_thw
=
[
None
]
*
len
(
videos
)
audio_lengths
=
[
None
]
*
len
(
audios
)
for
message
in
messages
:
content
=
message
[
"content"
]
while
IMAGE_PLACEHOLDER
in
content
:
image_seqlen
=
image_grid_thw
[
num_image_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
f
"<|vision_bos|>
{
self
.
image_token
*
image_seqlen
}
<|vision_eos|>"
,
1
)
num_image_tokens
+=
1
if
(
use_audio_in_video
and
len
(
audios
)
and
len
(
videos
)
):
# if use the audio of video # deal video token and audio token togather
if
len
(
videos
)
!=
len
(
audios
):
raise
ValueError
(
f
"Number of videos (
{
len
(
videos
)
}
) must match number of audios (
{
len
(
audios
)
}
) when using audio in video."
)
while
VIDEO_PLACEHOLDER
in
content
:
video_pos
=
content
.
find
(
VIDEO_PLACEHOLDER
)
audio_pos
=
content
.
find
(
AUDIO_PLACEHOLDER
,
video_pos
)
if
audio_pos
==
-
1
or
audio_pos
<
video_pos
:
raise
ValueError
(
f
"Each
{
VIDEO_PLACEHOLDER
}
must be followed by an
{
AUDIO_PLACEHOLDER
}
when using audio in video."
)
audio_t_index
=
torch
.
arange
(
audio_lengths
[
num_audio_tokens
])
video_t_index
=
(
torch
.
arange
(
video_grid_thw
[
num_video_tokens
][
0
])
.
view
(
-
1
,
1
,
1
)
.
expand
(
-
1
,
video_grid_thw
[
num_video_tokens
][
1
]
//
image_processor
.
merge_size
,
video_grid_thw
[
num_video_tokens
][
2
]
//
image_processor
.
merge_size
,
)
.
flatten
()
*
mm_inputs
[
"video_second_per_grid"
][
num_video_tokens
]
*
25
# FIXME hardcode of position_id_per_seconds=25
).
long
()
t_ntoken_per_chunk
=
50
# FIXME hardcode: [25 * 2]
video_chunk_indices
=
processor
.
get_chunked_index
(
video_t_index
,
t_ntoken_per_chunk
)
audio_chunk_indices
=
processor
.
get_chunked_index
(
audio_t_index
,
t_ntoken_per_chunk
)
placeholder_string
=
""
placeholder_string
+=
"<|vision_bos|>"
+
"<|audio_bos|>"
for
j
in
range
(
max
(
len
(
video_chunk_indices
),
len
(
audio_chunk_indices
))):
video_chunk_index
=
video_chunk_indices
[
j
]
if
j
<
len
(
video_chunk_indices
)
else
None
audio_chunk_index
=
audio_chunk_indices
[
j
]
if
j
<
len
(
audio_chunk_indices
)
else
None
if
video_chunk_index
is
not
None
:
placeholder_string
+=
self
.
video_token
*
(
video_chunk_index
[
1
]
-
video_chunk_index
[
0
])
if
audio_chunk_index
is
not
None
:
placeholder_string
+=
self
.
audio_token
*
(
audio_chunk_index
[
1
]
-
audio_chunk_index
[
0
])
placeholder_string
+=
"<|audio_eos|>"
+
"<|vision_eos|>"
content
=
content
.
replace
(
VIDEO_PLACEHOLDER
,
placeholder_string
,
1
)
content
=
content
.
replace
(
AUDIO_PLACEHOLDER
,
""
,
1
)
num_audio_tokens
+=
1
num_video_tokens
+=
1
else
:
while
AUDIO_PLACEHOLDER
in
content
:
audio_seqlen
=
audio_lengths
[
num_audio_tokens
]
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
AUDIO_PLACEHOLDER
,
f
"<|audio_bos|>
{
self
.
audio_token
*
audio_seqlen
}
<|audio_eos|>"
,
1
)
num_audio_tokens
+=
1
while
VIDEO_PLACEHOLDER
in
content
:
video_seqlen
=
(
video_grid_thw
[
num_video_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
)
content
=
content
.
replace
(
VIDEO_PLACEHOLDER
,
f
"<|vision_bos|>
{
self
.
video_token
*
video_seqlen
}
<|vision_eos|>"
,
1
)
num_video_tokens
+=
1
message
[
"content"
]
=
content
return
messages
@
dataclass
class
VideoLlavaPlugin
(
BasePlugin
):
@
override
def
process_messages
(
self
,
messages
:
list
[
dict
[
str
,
str
]],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
Optional
[
"MMProcessor"
],
)
->
list
[
dict
[
str
,
str
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
self
.
_validate_messages
(
messages
,
images
,
videos
,
audios
)
num_image_tokens
,
num_video_tokens
=
0
,
0
messages
=
deepcopy
(
messages
)
num_frames
=
0
if
self
.
expand_mm_tokens
:
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
if
"pixel_values_images"
in
mm_inputs
:
height
,
width
=
get_image_size
(
to_numpy_array
(
mm_inputs
[
"pixel_values_images"
][
0
]))
num_frames
=
1
if
"pixel_values_videos"
in
mm_inputs
:
one_video
=
to_numpy_array
(
mm_inputs
[
"pixel_values_videos"
][
0
])
height
,
width
=
get_image_size
(
one_video
[
0
])
num_frames
=
one_video
.
shape
[
0
]
# frame dim is always after batch dim
if
"pixel_values_images"
in
mm_inputs
or
"pixel_values_videos"
in
mm_inputs
:
image_seqlen
=
(
height
//
processor
.
patch_size
)
*
(
width
//
processor
.
patch_size
)
+
processor
.
num_additional_image_tokens
video_seqlen
=
image_seqlen
*
num_frames
if
processor
.
vision_feature_select_strategy
==
"default"
:
image_seqlen
-=
1
else
:
image_seqlen
,
video_seqlen
=
1
,
1
for
message
in
messages
:
content
=
message
[
"content"
]
while
IMAGE_PLACEHOLDER
in
content
:
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
"{{image}}"
*
image_seqlen
,
1
)
num_image_tokens
+=
1
while
VIDEO_PLACEHOLDER
in
content
:
content
=
content
.
replace
(
VIDEO_PLACEHOLDER
,
"{{video}}"
*
video_seqlen
,
1
)
num_video_tokens
+=
1
content
=
content
.
replace
(
"{{image}}"
,
self
.
image_token
)
message
[
"content"
]
=
content
.
replace
(
"{{video}}"
,
self
.
video_token
)
return
messages
PLUGINS
=
{
"base"
:
BasePlugin
,
"gemma3"
:
Gemma3Plugin
,
"glm4v"
:
GLM4VPlugin
,
"gemma3n"
:
Gemma3nPlugin
,
"intern_vl"
:
InternVLPlugin
,
"kimi_vl"
:
KimiVLPlugin
,
"llama4"
:
Llama4Plugin
,
"llava"
:
LlavaPlugin
,
"llava_next"
:
LlavaNextPlugin
,
"llava_next_video"
:
LlavaNextVideoPlugin
,
"minicpm_v"
:
MiniCPMVPlugin
,
"mllama"
:
MllamaPlugin
,
"paligemma"
:
PaliGemmaPlugin
,
"pixtral"
:
PixtralPlugin
,
"qwen2_audio"
:
Qwen2AudioPlugin
,
"qwen2_omni"
:
Qwen2OmniPlugin
,
"qwen2_vl"
:
Qwen2VLPlugin
,
"video_llava"
:
VideoLlavaPlugin
,
}
def
register_mm_plugin
(
name
:
str
,
plugin_class
:
type
[
"BasePlugin"
])
->
None
:
r
"""Register a multimodal plugin."""
if
name
in
PLUGINS
:
raise
ValueError
(
f
"Multimodal plugin
{
name
}
already exists."
)
PLUGINS
[
name
]
=
plugin_class
def
get_mm_plugin
(
name
:
str
,
image_token
:
Optional
[
str
]
=
None
,
video_token
:
Optional
[
str
]
=
None
,
audio_token
:
Optional
[
str
]
=
None
,
)
->
"BasePlugin"
:
r
"""Get plugin for multimodal inputs."""
if
name
not
in
PLUGINS
:
raise
ValueError
(
f
"Multimodal plugin `
{
name
}
` not found."
)
return
PLUGINS
[
name
](
image_token
,
video_token
,
audio_token
)
src/llamafactory/data/parser.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
os
from
dataclasses
import
dataclass
from
typing
import
Any
,
Literal
,
Optional
from
huggingface_hub
import
hf_hub_download
from
..extras.constants
import
DATA_CONFIG
from
..extras.misc
import
use_modelscope
,
use_openmind
@
dataclass
class
DatasetAttr
:
r
"""Dataset attributes."""
# basic configs
load_from
:
Literal
[
"hf_hub"
,
"ms_hub"
,
"om_hub"
,
"script"
,
"file"
]
dataset_name
:
str
formatting
:
Literal
[
"alpaca"
,
"sharegpt"
]
=
"alpaca"
ranking
:
bool
=
False
# extra configs
subset
:
Optional
[
str
]
=
None
split
:
str
=
"train"
folder
:
Optional
[
str
]
=
None
num_samples
:
Optional
[
int
]
=
None
# common columns
system
:
Optional
[
str
]
=
None
tools
:
Optional
[
str
]
=
None
images
:
Optional
[
str
]
=
None
videos
:
Optional
[
str
]
=
None
audios
:
Optional
[
str
]
=
None
# dpo columns
chosen
:
Optional
[
str
]
=
None
rejected
:
Optional
[
str
]
=
None
kto_tag
:
Optional
[
str
]
=
None
# alpaca columns
prompt
:
Optional
[
str
]
=
"instruction"
query
:
Optional
[
str
]
=
"input"
response
:
Optional
[
str
]
=
"output"
history
:
Optional
[
str
]
=
None
# sharegpt columns
messages
:
Optional
[
str
]
=
"conversations"
# sharegpt tags
role_tag
:
Optional
[
str
]
=
"from"
content_tag
:
Optional
[
str
]
=
"value"
user_tag
:
Optional
[
str
]
=
"human"
assistant_tag
:
Optional
[
str
]
=
"gpt"
observation_tag
:
Optional
[
str
]
=
"observation"
function_tag
:
Optional
[
str
]
=
"function_call"
system_tag
:
Optional
[
str
]
=
"system"
def
__repr__
(
self
)
->
str
:
return
self
.
dataset_name
def
set_attr
(
self
,
key
:
str
,
obj
:
dict
[
str
,
Any
],
default
:
Optional
[
Any
]
=
None
)
->
None
:
setattr
(
self
,
key
,
obj
.
get
(
key
,
default
))
def
join
(
self
,
attr
:
dict
[
str
,
Any
])
->
None
:
self
.
set_attr
(
"formatting"
,
attr
,
default
=
"alpaca"
)
self
.
set_attr
(
"ranking"
,
attr
,
default
=
False
)
self
.
set_attr
(
"subset"
,
attr
)
self
.
set_attr
(
"split"
,
attr
,
default
=
"train"
)
self
.
set_attr
(
"folder"
,
attr
)
self
.
set_attr
(
"num_samples"
,
attr
)
if
"columns"
in
attr
:
column_names
=
[
"prompt"
,
"query"
,
"response"
,
"history"
,
"messages"
,
"system"
,
"tools"
]
column_names
+=
[
"images"
,
"videos"
,
"audios"
,
"chosen"
,
"rejected"
,
"kto_tag"
]
for
column_name
in
column_names
:
self
.
set_attr
(
column_name
,
attr
[
"columns"
])
if
"tags"
in
attr
:
tag_names
=
[
"role_tag"
,
"content_tag"
]
tag_names
+=
[
"user_tag"
,
"assistant_tag"
,
"observation_tag"
,
"function_tag"
,
"system_tag"
]
for
tag
in
tag_names
:
self
.
set_attr
(
tag
,
attr
[
"tags"
])
def
get_dataset_list
(
dataset_names
:
Optional
[
list
[
str
]],
dataset_dir
:
str
)
->
list
[
"DatasetAttr"
]:
r
"""Get the attributes of the datasets."""
if
dataset_names
is
None
:
dataset_names
=
[]
if
dataset_dir
==
"ONLINE"
:
dataset_info
=
None
else
:
if
dataset_dir
.
startswith
(
"REMOTE:"
):
config_path
=
hf_hub_download
(
repo_id
=
dataset_dir
[
7
:],
filename
=
DATA_CONFIG
,
repo_type
=
"dataset"
)
else
:
config_path
=
os
.
path
.
join
(
dataset_dir
,
DATA_CONFIG
)
try
:
with
open
(
config_path
)
as
f
:
dataset_info
=
json
.
load
(
f
)
except
Exception
as
err
:
if
len
(
dataset_names
)
!=
0
:
raise
ValueError
(
f
"Cannot open
{
config_path
}
due to
{
str
(
err
)
}
."
)
dataset_info
=
None
dataset_list
:
list
[
DatasetAttr
]
=
[]
for
name
in
dataset_names
:
if
dataset_info
is
None
:
# dataset_dir is ONLINE
load_from
=
"ms_hub"
if
use_modelscope
()
else
"om_hub"
if
use_openmind
()
else
"hf_hub"
dataset_attr
=
DatasetAttr
(
load_from
,
dataset_name
=
name
)
dataset_list
.
append
(
dataset_attr
)
continue
if
name
not
in
dataset_info
:
raise
ValueError
(
f
"Undefined dataset
{
name
}
in
{
DATA_CONFIG
}
."
)
has_hf_url
=
"hf_hub_url"
in
dataset_info
[
name
]
has_ms_url
=
"ms_hub_url"
in
dataset_info
[
name
]
has_om_url
=
"om_hub_url"
in
dataset_info
[
name
]
if
has_hf_url
or
has_ms_url
or
has_om_url
:
if
has_ms_url
and
(
use_modelscope
()
or
not
has_hf_url
):
dataset_attr
=
DatasetAttr
(
"ms_hub"
,
dataset_name
=
dataset_info
[
name
][
"ms_hub_url"
])
elif
has_om_url
and
(
use_openmind
()
or
not
has_hf_url
):
dataset_attr
=
DatasetAttr
(
"om_hub"
,
dataset_name
=
dataset_info
[
name
][
"om_hub_url"
])
else
:
dataset_attr
=
DatasetAttr
(
"hf_hub"
,
dataset_name
=
dataset_info
[
name
][
"hf_hub_url"
])
elif
"script_url"
in
dataset_info
[
name
]:
dataset_attr
=
DatasetAttr
(
"script"
,
dataset_name
=
dataset_info
[
name
][
"script_url"
])
elif
"cloud_file_name"
in
dataset_info
[
name
]:
dataset_attr
=
DatasetAttr
(
"cloud_file"
,
dataset_name
=
dataset_info
[
name
][
"cloud_file_name"
])
else
:
dataset_attr
=
DatasetAttr
(
"file"
,
dataset_name
=
dataset_info
[
name
][
"file_name"
])
dataset_attr
.
join
(
dataset_info
[
name
])
dataset_list
.
append
(
dataset_attr
)
return
dataset_list
src/llamafactory/data/processor/__init__.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.feedback
import
FeedbackDatasetProcessor
from
.pairwise
import
PairwiseDatasetProcessor
from
.pretrain
import
PretrainDatasetProcessor
from
.processor_utils
import
DatasetProcessor
from
.supervised
import
PackedSupervisedDatasetProcessor
,
SupervisedDatasetProcessor
from
.unsupervised
import
UnsupervisedDatasetProcessor
__all__
=
[
"DatasetProcessor"
,
"FeedbackDatasetProcessor"
,
"PackedSupervisedDatasetProcessor"
,
"PairwiseDatasetProcessor"
,
"PretrainDatasetProcessor"
,
"SupervisedDatasetProcessor"
,
"UnsupervisedDatasetProcessor"
,
]
src/llamafactory/data/processor/feedback.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
...extras
import
logging
from
...extras.constants
import
IGNORE_INDEX
from
.processor_utils
import
DatasetProcessor
,
infer_seqlen
if
TYPE_CHECKING
:
from
..mm_plugin
import
AudioInput
,
ImageInput
,
VideoInput
logger
=
logging
.
get_logger
(
__name__
)
class
FeedbackDatasetProcessor
(
DatasetProcessor
):
def
_encode_data_example
(
self
,
prompt
:
list
[
dict
[
str
,
str
]],
response
:
list
[
dict
[
str
,
str
]],
kl_response
:
list
[
dict
[
str
,
str
]],
system
:
Optional
[
str
],
tools
:
Optional
[
str
],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
)
->
tuple
[
list
[
int
],
list
[
int
],
list
[
int
],
list
[
int
],
bool
]:
if
response
[
0
][
"content"
]:
# desired example
kto_tag
=
True
messages
=
prompt
+
[
response
[
0
]]
else
:
# undesired example
kto_tag
=
False
messages
=
prompt
+
[
response
[
1
]]
if
kl_response
[
0
][
"content"
]:
kl_messages
=
prompt
+
[
kl_response
[
0
]]
else
:
kl_messages
=
prompt
+
[
kl_response
[
1
]]
messages
=
self
.
template
.
mm_plugin
.
process_messages
(
messages
,
images
,
videos
,
audios
,
self
.
processor
)
kl_messages
=
self
.
template
.
mm_plugin
.
process_messages
(
kl_messages
,
images
,
videos
,
audios
,
self
.
processor
)
prompt_ids
,
response_ids
=
self
.
template
.
encode_oneturn
(
self
.
tokenizer
,
messages
,
system
,
tools
)
kl_prompt_ids
,
kl_response_ids
=
self
.
template
.
encode_oneturn
(
self
.
tokenizer
,
kl_messages
,
system
,
tools
)
if
self
.
template
.
efficient_eos
:
response_ids
+=
[
self
.
tokenizer
.
eos_token_id
]
kl_response_ids
+=
[
self
.
tokenizer
.
eos_token_id
]
prompt_ids
,
_
=
self
.
template
.
mm_plugin
.
process_token_ids
(
prompt_ids
,
None
,
images
,
videos
,
audios
,
self
.
tokenizer
,
self
.
processor
)
kl_prompt_ids
,
_
=
self
.
template
.
mm_plugin
.
process_token_ids
(
kl_prompt_ids
,
None
,
images
,
videos
,
audios
,
self
.
tokenizer
,
self
.
processor
)
source_len
,
target_len
=
infer_seqlen
(
len
(
prompt_ids
),
len
(
response_ids
),
self
.
data_args
.
cutoff_len
)
prompt_ids
=
prompt_ids
[:
source_len
]
response_ids
=
response_ids
[:
target_len
]
kl_source_len
,
kl_target_len
=
infer_seqlen
(
len
(
kl_prompt_ids
),
len
(
kl_response_ids
),
self
.
data_args
.
cutoff_len
)
kl_prompt_ids
=
kl_prompt_ids
[:
kl_source_len
]
kl_response_ids
=
kl_response_ids
[:
kl_target_len
]
input_ids
=
prompt_ids
+
response_ids
labels
=
[
IGNORE_INDEX
]
*
source_len
+
response_ids
kl_input_ids
=
kl_prompt_ids
+
kl_response_ids
kl_labels
=
[
IGNORE_INDEX
]
*
kl_source_len
+
kl_response_ids
return
input_ids
,
labels
,
kl_input_ids
,
kl_labels
,
kto_tag
def
preprocess_dataset
(
self
,
examples
:
dict
[
str
,
list
[
Any
]])
->
dict
[
str
,
list
[
Any
]]:
# Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
kl_response
=
[
examples
[
"_response"
][
-
1
]]
+
examples
[
"_response"
][:
-
1
]
model_inputs
=
defaultdict
(
list
)
for
i
in
range
(
len
(
examples
[
"_prompt"
])):
if
len
(
examples
[
"_prompt"
][
i
])
%
2
!=
1
or
len
(
examples
[
"_response"
][
i
])
<
2
:
logger
.
warning_rank0
(
"Dropped invalid example: {}"
.
format
(
examples
[
"_prompt"
][
i
]
+
examples
[
"_response"
][
i
])
)
continue
input_ids
,
labels
,
kl_input_ids
,
kl_labels
,
kto_tag
=
self
.
_encode_data_example
(
prompt
=
examples
[
"_prompt"
][
i
],
response
=
examples
[
"_response"
][
i
],
kl_response
=
kl_response
[
i
],
system
=
examples
[
"_system"
][
i
],
tools
=
examples
[
"_tools"
][
i
],
images
=
examples
[
"_images"
][
i
]
or
[],
videos
=
examples
[
"_videos"
][
i
]
or
[],
audios
=
examples
[
"_audios"
][
i
]
or
[],
)
model_inputs
[
"input_ids"
].
append
(
input_ids
)
model_inputs
[
"attention_mask"
].
append
([
1
]
*
len
(
input_ids
))
model_inputs
[
"labels"
].
append
(
labels
)
model_inputs
[
"kl_input_ids"
].
append
(
kl_input_ids
)
model_inputs
[
"kl_attention_mask"
].
append
([
1
]
*
len
(
kl_input_ids
))
model_inputs
[
"kl_labels"
].
append
(
kl_labels
)
model_inputs
[
"kto_tags"
].
append
(
kto_tag
)
model_inputs
[
"images"
].
append
(
examples
[
"_images"
][
i
])
model_inputs
[
"videos"
].
append
(
examples
[
"_videos"
][
i
])
model_inputs
[
"audios"
].
append
(
examples
[
"_audios"
][
i
])
desirable_num
=
sum
([
1
for
tag
in
model_inputs
[
"kto_tags"
]
if
tag
])
undesirable_num
=
len
(
model_inputs
[
"kto_tags"
])
-
desirable_num
if
desirable_num
==
0
or
undesirable_num
==
0
:
logger
.
warning_rank0
(
"Your dataset only has one preference type."
)
return
model_inputs
def
print_data_example
(
self
,
example
:
dict
[
str
,
list
[
int
]])
->
None
:
valid_labels
=
list
(
filter
(
lambda
x
:
x
!=
IGNORE_INDEX
,
example
[
"labels"
]))
print
(
"input_ids:
\n
{}"
.
format
(
example
[
"input_ids"
]))
print
(
"inputs:
\n
{}"
.
format
(
self
.
tokenizer
.
decode
(
example
[
"input_ids"
],
skip_special_tokens
=
False
)))
print
(
"label_ids:
\n
{}"
.
format
(
example
[
"labels"
]))
print
(
f
"labels:
\n
{
self
.
tokenizer
.
decode
(
valid_labels
,
skip_special_tokens
=
False
)
}
"
)
src/llamafactory/data/processor/pairwise.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
...extras
import
logging
from
...extras.constants
import
IGNORE_INDEX
from
.processor_utils
import
DatasetProcessor
,
infer_seqlen
if
TYPE_CHECKING
:
from
..mm_plugin
import
AudioInput
,
ImageInput
,
VideoInput
logger
=
logging
.
get_logger
(
__name__
)
class
PairwiseDatasetProcessor
(
DatasetProcessor
):
def
_encode_data_example
(
self
,
prompt
:
list
[
dict
[
str
,
str
]],
response
:
list
[
dict
[
str
,
str
]],
system
:
Optional
[
str
],
tools
:
Optional
[
str
],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
)
->
tuple
[
list
[
int
],
list
[
int
],
list
[
int
],
list
[
int
]]:
chosen_messages
=
self
.
template
.
mm_plugin
.
process_messages
(
prompt
+
[
response
[
0
]],
images
,
videos
,
audios
,
self
.
processor
)
rejected_messages
=
self
.
template
.
mm_plugin
.
process_messages
(
prompt
+
[
response
[
1
]],
images
,
videos
,
audios
,
self
.
processor
)
prompt_ids
,
chosen_ids
=
self
.
template
.
encode_oneturn
(
self
.
tokenizer
,
chosen_messages
,
system
,
tools
)
_
,
rejected_ids
=
self
.
template
.
encode_oneturn
(
self
.
tokenizer
,
rejected_messages
,
system
,
tools
)
if
self
.
template
.
efficient_eos
:
chosen_ids
+=
[
self
.
tokenizer
.
eos_token_id
]
rejected_ids
+=
[
self
.
tokenizer
.
eos_token_id
]
prompt_ids
,
_
=
self
.
template
.
mm_plugin
.
process_token_ids
(
prompt_ids
,
None
,
images
,
videos
,
audios
,
self
.
tokenizer
,
self
.
processor
)
# consider the response is more important
source_len
,
target_len
=
infer_seqlen
(
len
(
prompt_ids
),
max
(
len
(
chosen_ids
),
len
(
rejected_ids
)),
self
.
data_args
.
cutoff_len
)
prompt_ids
=
prompt_ids
[:
source_len
]
chosen_ids
=
chosen_ids
[:
target_len
]
rejected_ids
=
rejected_ids
[:
target_len
]
chosen_input_ids
=
prompt_ids
+
chosen_ids
chosen_labels
=
[
IGNORE_INDEX
]
*
source_len
+
chosen_ids
rejected_input_ids
=
prompt_ids
+
rejected_ids
rejected_labels
=
[
IGNORE_INDEX
]
*
source_len
+
rejected_ids
return
chosen_input_ids
,
chosen_labels
,
rejected_input_ids
,
rejected_labels
def
preprocess_dataset
(
self
,
examples
:
dict
[
str
,
list
[
Any
]])
->
dict
[
str
,
list
[
Any
]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs
=
defaultdict
(
list
)
for
i
in
range
(
len
(
examples
[
"_prompt"
])):
if
len
(
examples
[
"_prompt"
][
i
])
%
2
!=
1
or
len
(
examples
[
"_response"
][
i
])
<
2
:
logger
.
warning_rank0
(
"Dropped invalid example: {}"
.
format
(
examples
[
"_prompt"
][
i
]
+
examples
[
"_response"
][
i
])
)
continue
chosen_input_ids
,
chosen_labels
,
rejected_input_ids
,
rejected_labels
=
self
.
_encode_data_example
(
prompt
=
examples
[
"_prompt"
][
i
],
response
=
examples
[
"_response"
][
i
],
system
=
examples
[
"_system"
][
i
],
tools
=
examples
[
"_tools"
][
i
],
images
=
examples
[
"_images"
][
i
]
or
[],
videos
=
examples
[
"_videos"
][
i
]
or
[],
audios
=
examples
[
"_audios"
][
i
]
or
[],
)
model_inputs
[
"chosen_input_ids"
].
append
(
chosen_input_ids
)
model_inputs
[
"chosen_attention_mask"
].
append
([
1
]
*
len
(
chosen_input_ids
))
model_inputs
[
"chosen_labels"
].
append
(
chosen_labels
)
model_inputs
[
"rejected_input_ids"
].
append
(
rejected_input_ids
)
model_inputs
[
"rejected_attention_mask"
].
append
([
1
]
*
len
(
rejected_input_ids
))
model_inputs
[
"rejected_labels"
].
append
(
rejected_labels
)
model_inputs
[
"images"
].
append
(
examples
[
"_images"
][
i
])
model_inputs
[
"videos"
].
append
(
examples
[
"_videos"
][
i
])
model_inputs
[
"audios"
].
append
(
examples
[
"_audios"
][
i
])
return
model_inputs
def
print_data_example
(
self
,
example
:
dict
[
str
,
list
[
int
]])
->
None
:
valid_chosen_labels
=
list
(
filter
(
lambda
x
:
x
!=
IGNORE_INDEX
,
example
[
"chosen_labels"
]))
valid_rejected_labels
=
list
(
filter
(
lambda
x
:
x
!=
IGNORE_INDEX
,
example
[
"rejected_labels"
]))
print
(
"chosen_input_ids:
\n
{}"
.
format
(
example
[
"chosen_input_ids"
]))
print
(
"chosen_inputs:
\n
{}"
.
format
(
self
.
tokenizer
.
decode
(
example
[
"chosen_input_ids"
],
skip_special_tokens
=
False
))
)
print
(
"chosen_label_ids:
\n
{}"
.
format
(
example
[
"chosen_labels"
]))
print
(
f
"chosen_labels:
\n
{
self
.
tokenizer
.
decode
(
valid_chosen_labels
,
skip_special_tokens
=
False
)
}
"
)
print
(
"rejected_input_ids:
\n
{}"
.
format
(
example
[
"rejected_input_ids"
]))
print
(
"rejected_inputs:
\n
{}"
.
format
(
self
.
tokenizer
.
decode
(
example
[
"rejected_input_ids"
],
skip_special_tokens
=
False
)
)
)
print
(
"rejected_label_ids:
\n
{}"
.
format
(
example
[
"rejected_labels"
]))
print
(
f
"rejected_labels:
\n
{
self
.
tokenizer
.
decode
(
valid_rejected_labels
,
skip_special_tokens
=
False
)
}
"
)
src/llamafactory/data/processor/pretrain.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
from
itertools
import
chain
from
typing
import
Any
from
.processor_utils
import
DatasetProcessor
@
dataclass
class
PretrainDatasetProcessor
(
DatasetProcessor
):
def
preprocess_dataset
(
self
,
examples
:
dict
[
str
,
list
[
Any
]])
->
dict
[
str
,
list
[
Any
]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token
=
"<|end_of_text|>"
if
self
.
data_args
.
template
==
"llama3"
else
self
.
tokenizer
.
eos_token
text_examples
=
[
messages
[
0
][
"content"
]
+
eos_token
for
messages
in
examples
[
"_prompt"
]]
if
not
self
.
data_args
.
packing
:
if
getattr
(
self
.
tokenizer
,
"add_bos_token"
,
False
):
text_examples
=
[
self
.
tokenizer
.
bos_token
+
example
for
example
in
text_examples
]
result
=
self
.
tokenizer
(
text_examples
,
add_special_tokens
=
False
,
truncation
=
True
,
max_length
=
self
.
data_args
.
cutoff_len
)
else
:
tokenized_examples
=
self
.
tokenizer
(
text_examples
,
add_special_tokens
=
False
)
concatenated_examples
=
{
k
:
list
(
chain
(
*
tokenized_examples
[
k
]))
for
k
in
tokenized_examples
.
keys
()}
total_length
=
len
(
concatenated_examples
[
list
(
concatenated_examples
.
keys
())[
0
]])
block_size
=
self
.
data_args
.
cutoff_len
total_length
=
(
total_length
//
block_size
)
*
block_size
result
=
{
k
:
[
t
[
i
:
i
+
block_size
]
for
i
in
range
(
0
,
total_length
,
block_size
)]
for
k
,
t
in
concatenated_examples
.
items
()
}
if
getattr
(
self
.
tokenizer
,
"add_bos_token"
,
False
):
for
i
in
range
(
len
(
result
[
"input_ids"
])):
result
[
"input_ids"
][
i
][
0
]
=
self
.
tokenizer
.
bos_token_id
return
result
def
print_data_example
(
self
,
example
:
dict
[
str
,
list
[
int
]])
->
None
:
print
(
"input_ids:
\n
{}"
.
format
(
example
[
"input_ids"
]))
print
(
"inputs:
\n
{}"
.
format
(
self
.
tokenizer
.
decode
(
example
[
"input_ids"
],
skip_special_tokens
=
False
)))
src/llamafactory/data/processor/processor_utils.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
bisect
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
,
ProcessorMixin
from
...hparams
import
DataArguments
from
..template
import
Template
@
dataclass
class
DatasetProcessor
(
ABC
):
r
"""A class for data processors."""
template
:
"Template"
tokenizer
:
"PreTrainedTokenizer"
processor
:
Optional
[
"ProcessorMixin"
]
data_args
:
"DataArguments"
@
abstractmethod
def
preprocess_dataset
(
self
,
examples
:
dict
[
str
,
list
[
Any
]])
->
dict
[
str
,
list
[
Any
]]:
r
"""Build model inputs from the examples."""
...
@
abstractmethod
def
print_data_example
(
self
,
example
:
dict
[
str
,
list
[
int
]])
->
None
:
r
"""Print a data example to stdout."""
...
def
search_for_fit
(
numbers
:
list
[
int
],
capacity
:
int
)
->
int
:
r
"""Find the index of largest number that fits into the knapsack with the given capacity."""
index
=
bisect
.
bisect
(
numbers
,
capacity
)
return
-
1
if
index
==
0
else
(
index
-
1
)
def
greedy_knapsack
(
numbers
:
list
[
int
],
capacity
:
int
)
->
list
[
list
[
int
]]:
r
"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
numbers
.
sort
()
# sort numbers in ascending order for binary search
knapsacks
=
[]
while
numbers
:
current_knapsack
=
[]
remaining_capacity
=
capacity
while
True
:
index
=
search_for_fit
(
numbers
,
remaining_capacity
)
if
index
==
-
1
:
break
# no more numbers fit in this knapsack
remaining_capacity
-=
numbers
[
index
]
# update the remaining capacity
current_knapsack
.
append
(
numbers
.
pop
(
index
))
# add the number to knapsack
knapsacks
.
append
(
current_knapsack
)
return
knapsacks
def
infer_seqlen
(
source_len
:
int
,
target_len
:
int
,
cutoff_len
:
int
)
->
tuple
[
int
,
int
]:
r
"""Compute the real sequence length after truncation by the cutoff_len."""
if
target_len
*
2
<
cutoff_len
:
# truncate source
max_target_len
=
cutoff_len
elif
source_len
*
2
<
cutoff_len
:
# truncate target
max_target_len
=
cutoff_len
-
source_len
else
:
# truncate both
max_target_len
=
int
(
cutoff_len
*
(
target_len
/
(
source_len
+
target_len
)))
new_target_len
=
min
(
max_target_len
,
target_len
)
max_source_len
=
max
(
cutoff_len
-
new_target_len
,
0
)
new_source_len
=
min
(
max_source_len
,
source_len
)
return
new_source_len
,
new_target_len
src/llamafactory/data/processor/supervised.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
...extras
import
logging
from
...extras.constants
import
IGNORE_INDEX
from
.processor_utils
import
DatasetProcessor
,
greedy_knapsack
,
infer_seqlen
if
TYPE_CHECKING
:
from
..mm_plugin
import
AudioInput
,
ImageInput
,
VideoInput
logger
=
logging
.
get_logger
(
__name__
)
@
dataclass
class
SupervisedDatasetProcessor
(
DatasetProcessor
):
def
_encode_data_example
(
self
,
prompt
:
list
[
dict
[
str
,
str
]],
response
:
list
[
dict
[
str
,
str
]],
system
:
Optional
[
str
],
tools
:
Optional
[
str
],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
)
->
tuple
[
list
[
int
],
list
[
int
]]:
messages
=
self
.
template
.
mm_plugin
.
process_messages
(
prompt
+
response
,
images
,
videos
,
audios
,
self
.
processor
)
input_ids
,
labels
=
self
.
template
.
mm_plugin
.
process_token_ids
(
[],
[],
images
,
videos
,
audios
,
self
.
tokenizer
,
self
.
processor
)
encoded_pairs
=
self
.
template
.
encode_multiturn
(
self
.
tokenizer
,
messages
,
system
,
tools
)
total_length
=
len
(
input_ids
)
+
(
1
if
self
.
template
.
efficient_eos
else
0
)
if
self
.
data_args
.
mask_history
:
encoded_pairs
=
encoded_pairs
[::
-
1
]
# high priority for last turns
for
turn_idx
,
(
source_ids
,
target_ids
)
in
enumerate
(
encoded_pairs
):
if
total_length
>=
self
.
data_args
.
cutoff_len
:
break
source_len
,
target_len
=
infer_seqlen
(
len
(
source_ids
),
len
(
target_ids
),
self
.
data_args
.
cutoff_len
-
total_length
)
source_ids
=
source_ids
[:
source_len
]
target_ids
=
target_ids
[:
target_len
]
total_length
+=
source_len
+
target_len
if
self
.
data_args
.
train_on_prompt
:
source_label
=
source_ids
elif
self
.
template
.
efficient_eos
:
source_label
=
[
self
.
tokenizer
.
eos_token_id
]
+
[
IGNORE_INDEX
]
*
(
source_len
-
1
)
else
:
source_label
=
[
IGNORE_INDEX
]
*
source_len
if
self
.
data_args
.
mask_history
and
turn_idx
!=
0
:
# train on the last turn only
target_label
=
[
IGNORE_INDEX
]
*
target_len
else
:
target_label
=
target_ids
if
self
.
data_args
.
mask_history
:
# reversed sequences
input_ids
=
source_ids
+
target_ids
+
input_ids
labels
=
source_label
+
target_label
+
labels
else
:
input_ids
+=
source_ids
+
target_ids
labels
+=
source_label
+
target_label
if
self
.
template
.
efficient_eos
:
input_ids
+=
[
self
.
tokenizer
.
eos_token_id
]
labels
+=
[
self
.
tokenizer
.
eos_token_id
]
return
input_ids
,
labels
def
preprocess_dataset
(
self
,
examples
:
dict
[
str
,
list
[
Any
]])
->
dict
[
str
,
list
[
Any
]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs
=
defaultdict
(
list
)
for
i
in
range
(
len
(
examples
[
"_prompt"
])):
if
len
(
examples
[
"_prompt"
][
i
])
%
2
!=
1
or
len
(
examples
[
"_response"
][
i
])
!=
1
:
logger
.
warning_rank0
(
"Dropped invalid example: {}"
.
format
(
examples
[
"_prompt"
][
i
]
+
examples
[
"_response"
][
i
])
)
continue
input_ids
,
labels
=
self
.
_encode_data_example
(
prompt
=
examples
[
"_prompt"
][
i
],
response
=
examples
[
"_response"
][
i
],
system
=
examples
[
"_system"
][
i
],
tools
=
examples
[
"_tools"
][
i
],
images
=
examples
[
"_images"
][
i
]
or
[],
videos
=
examples
[
"_videos"
][
i
]
or
[],
audios
=
examples
[
"_audios"
][
i
]
or
[],
)
model_inputs
[
"input_ids"
].
append
(
input_ids
)
model_inputs
[
"attention_mask"
].
append
([
1
]
*
len
(
input_ids
))
model_inputs
[
"labels"
].
append
(
labels
)
model_inputs
[
"images"
].
append
(
examples
[
"_images"
][
i
])
model_inputs
[
"videos"
].
append
(
examples
[
"_videos"
][
i
])
model_inputs
[
"audios"
].
append
(
examples
[
"_audios"
][
i
])
return
model_inputs
def
print_data_example
(
self
,
example
:
dict
[
str
,
list
[
int
]])
->
None
:
valid_labels
=
list
(
filter
(
lambda
x
:
x
!=
IGNORE_INDEX
,
example
[
"labels"
]))
print
(
"input_ids:
\n
{}"
.
format
(
example
[
"input_ids"
]))
print
(
"inputs:
\n
{}"
.
format
(
self
.
tokenizer
.
decode
(
example
[
"input_ids"
],
skip_special_tokens
=
False
)))
print
(
"label_ids:
\n
{}"
.
format
(
example
[
"labels"
]))
print
(
f
"labels:
\n
{
self
.
tokenizer
.
decode
(
valid_labels
,
skip_special_tokens
=
False
)
}
"
)
@
dataclass
class
PackedSupervisedDatasetProcessor
(
SupervisedDatasetProcessor
):
def
preprocess_dataset
(
self
,
examples
:
dict
[
str
,
list
[
Any
]])
->
dict
[
str
,
list
[
Any
]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num
=
0
batch_input_ids
,
batch_labels
,
batch_images
,
batch_videos
,
batch_audios
=
[],
[],
[],
[],
[]
lengths
=
[]
length2indexes
=
defaultdict
(
list
)
for
i
in
range
(
len
(
examples
[
"_prompt"
])):
if
len
(
examples
[
"_prompt"
][
i
])
%
2
!=
1
or
len
(
examples
[
"_response"
][
i
])
!=
1
:
logger
.
warning_rank0
(
"Dropped invalid example: {}"
.
format
(
examples
[
"_prompt"
][
i
]
+
examples
[
"_response"
][
i
])
)
continue
input_ids
,
labels
=
self
.
_encode_data_example
(
prompt
=
examples
[
"_prompt"
][
i
],
response
=
examples
[
"_response"
][
i
],
system
=
examples
[
"_system"
][
i
],
tools
=
examples
[
"_tools"
][
i
],
images
=
examples
[
"_images"
][
i
]
or
[],
videos
=
examples
[
"_videos"
][
i
]
or
[],
audios
=
examples
[
"_audios"
][
i
]
or
[],
)
length
=
len
(
input_ids
)
if
length
>
self
.
data_args
.
cutoff_len
:
logger
.
warning_rank0
(
f
"Dropped lengthy example with length
{
length
}
>
{
self
.
data_args
.
cutoff_len
}
."
)
else
:
lengths
.
append
(
length
)
length2indexes
[
length
].
append
(
valid_num
)
batch_input_ids
.
append
(
input_ids
)
batch_labels
.
append
(
labels
)
batch_images
.
append
(
examples
[
"_images"
][
i
]
or
[])
batch_videos
.
append
(
examples
[
"_videos"
][
i
]
or
[])
batch_audios
.
append
(
examples
[
"_audios"
][
i
]
or
[])
valid_num
+=
1
model_inputs
=
defaultdict
(
list
)
knapsacks
=
greedy_knapsack
(
lengths
,
self
.
data_args
.
cutoff_len
)
for
knapsack
in
knapsacks
:
packed_input_ids
,
packed_attention_masks
,
packed_position_ids
,
packed_labels
=
[],
[],
[],
[]
packed_images
,
packed_videos
,
packed_audios
=
[],
[],
[]
for
i
,
length
in
enumerate
(
knapsack
):
index
=
length2indexes
[
length
].
pop
()
packed_input_ids
+=
batch_input_ids
[
index
]
packed_position_ids
+=
list
(
range
(
len
(
batch_input_ids
[
index
])))
# NOTE: pad_to_multiple_of ignore this
packed_labels
+=
batch_labels
[
index
]
packed_images
+=
batch_images
[
index
]
packed_videos
+=
batch_videos
[
index
]
packed_audios
+=
batch_audios
[
index
]
if
self
.
data_args
.
neat_packing
:
packed_attention_masks
+=
[
i
+
1
]
*
len
(
batch_input_ids
[
index
])
# start from 1
else
:
packed_attention_masks
+=
[
1
]
*
len
(
batch_input_ids
[
index
])
if
len
(
packed_input_ids
)
<
self
.
data_args
.
cutoff_len
+
1
:
# avoid flash_attn drops attn mask
pad_length
=
self
.
data_args
.
cutoff_len
-
len
(
packed_input_ids
)
+
1
packed_input_ids
+=
[
self
.
tokenizer
.
pad_token_id
]
*
pad_length
packed_position_ids
+=
[
0
]
*
pad_length
packed_labels
+=
[
IGNORE_INDEX
]
*
pad_length
if
self
.
data_args
.
neat_packing
:
packed_attention_masks
+=
[
0
]
*
pad_length
else
:
packed_attention_masks
+=
[
1
]
*
pad_length
# more efficient flash_attn
if
len
(
packed_input_ids
)
!=
self
.
data_args
.
cutoff_len
+
1
:
raise
ValueError
(
"The length of packed example should be identical to the cutoff length."
)
model_inputs
[
"input_ids"
].
append
(
packed_input_ids
)
model_inputs
[
"attention_mask"
].
append
(
packed_attention_masks
)
model_inputs
[
"position_ids"
].
append
(
packed_position_ids
)
model_inputs
[
"labels"
].
append
(
packed_labels
)
model_inputs
[
"images"
].
append
(
packed_images
or
None
)
model_inputs
[
"videos"
].
append
(
packed_videos
or
None
)
model_inputs
[
"audios"
].
append
(
packed_audios
or
None
)
return
model_inputs
src/llamafactory/data/processor/unsupervised.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
...extras
import
logging
from
..data_utils
import
Role
from
.processor_utils
import
DatasetProcessor
,
infer_seqlen
if
TYPE_CHECKING
:
from
..mm_plugin
import
AudioInput
,
ImageInput
,
VideoInput
logger
=
logging
.
get_logger
(
__name__
)

class UnsupervisedDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: list[dict[str, str]],
        response: list[dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
    ) -> tuple[list[int], list[int]]:
        if len(response) == 1:
            messages = prompt + response
        else:
            messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

        messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
        input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
        if self.template.efficient_eos:
            labels += [self.tokenizer.eos_token_id]

        input_ids, _ = self.template.mm_plugin.process_token_ids(
            input_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len)
        input_ids = input_ids[:source_len]
        labels = labels[:target_len]
        return input_ids, labels

    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        # build inputs with format `<bos> X` and labels with format `Y <eos>`
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        return model_inputs

    def print_data_example(self, example: dict[str, list[int]]) -> None:
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False)))
src/llamafactory/data/template.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

from typing_extensions import override

from ..extras import logging
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from ..hparams import DataArguments
    from .formatter import SLOTS, Formatter
    from .mm_plugin import BasePlugin
    from .tool_utils import FunctionCall


logger = logging.get_logger(__name__)

@dataclass
class Template:
    format_user: "Formatter"
    format_assistant: "Formatter"
    format_system: "Formatter"
    format_function: "Formatter"
    format_observation: "Formatter"
    format_tools: "Formatter"
    format_prefix: "Formatter"
    default_system: str
    stop_words: list[str]
    thought_words: tuple[str, str]
    efficient_eos: bool
    replace_eos: bool
    replace_jinja_template: bool
    enable_thinking: Optional[bool]
    mm_plugin: "BasePlugin"

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
    ) -> tuple[list[int], list[int]]:
        r"""Return a single pair of token ids representing prompt and response respectively."""
        encoded_messages = self._encode(tokenizer, messages, system, tools)
        prompt_ids = []
        for encoded_ids in encoded_messages[:-1]:
            prompt_ids += encoded_ids

        response_ids = encoded_messages[-1]
        return prompt_ids, response_ids

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
    ) -> list[tuple[list[int], list[int]]]:
        r"""Return multiple pairs of token ids representing prompts and responses respectively."""
        encoded_messages = self._encode(tokenizer, messages, system, tools)
        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

    def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
        r"""Extract tool message."""
        return self.format_tools.extract(content)

    def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
        r"""Return stop token ids."""
        stop_token_ids = {tokenizer.eos_token_id}
        for token in self.stop_words:
            stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))

        return list(stop_token_ids)

    def add_thought(self, content: str = "") -> str:
        r"""Add empty thought to assistant message."""
        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content

    def remove_thought(self, content: str) -> str:
        r"""Remove thought from assistant message."""
        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
        return re.sub(pattern, "", content).lstrip("\n")

    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
        r"""Get the token ids of thought words."""
        return tokenizer.encode(self.add_thought(), add_special_tokens=False)

    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
        r"""Convert elements to token ids."""
        token_ids = []
        for elem in elements:
            if isinstance(elem, str):
                if len(elem) != 0:
                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
            elif isinstance(elem, dict):
                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
            elif isinstance(elem, set):
                if "bos_token" in elem and tokenizer.bos_token_id is not None:
                    token_ids += [tokenizer.bos_token_id]
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                    token_ids += [tokenizer.eos_token_id]
            else:
                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")

        return token_ids

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
    ) -> list[list[int]]:
        r"""Encode formatted inputs to pairs of token ids.

        Turn 0: prefix + system + query    resp
        Turn t: query                      resp.
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []

            if i == 0:
                elements += self.format_prefix.apply()
                if system or tools:
                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                    elements += self.format_system.apply(content=(system + tool_text))

            if message["role"] == Role.USER:
                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return encoded_messages

    @staticmethod
    def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
        r"""Add or replace the eos token in the tokenizer."""
        if tokenizer.eos_token == eos_token:
            return

        is_added = tokenizer.eos_token_id is None
        num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
        if is_added:
            logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
        else:
            logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")

        if num_added_tokens > 0:
            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

    def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
        r"""Add eos token and pad token to the tokenizer."""
        stop_words = self.stop_words
        if self.replace_eos:
            if not stop_words:
                raise ValueError("Stop words are required to replace the EOS token.")

            self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
            stop_words = stop_words[1:]

        if tokenizer.eos_token_id is None:
            self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

        if stop_words:
            num_added_tokens = tokenizer.add_special_tokens(
                dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
            )
            logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
            if num_added_tokens > 0:
                logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

    @staticmethod
    def _jinja_escape(content: str) -> str:
        r"""Escape single quotes in content."""
        return content.replace("'", r"\'")

    @staticmethod
    def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
        r"""Convert slots to jinja template."""
        slot_items = []
        for slot in slots:
            if isinstance(slot, str):
                slot_pieces = slot.split("{{content}}")
                if slot_pieces[0]:
                    slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
                if len(slot_pieces) > 1:
                    slot_items.append(placeholder)
                    if slot_pieces[1]:
                        slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
                if "bos_token" in slot and tokenizer.bos_token_id is not None:
                    slot_items.append("'" + tokenizer.bos_token + "'")
                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                    slot_items.append("'" + tokenizer.eos_token + "'")
            elif isinstance(slot, dict):
                raise ValueError("Dict is not supported.")

        return " + ".join(slot_items)

    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
        r"""Return the jinja template."""
        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
        system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
        user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
        assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
        jinja_template = ""
        if prefix:
            jinja_template += "{{ " + prefix + " }}"

        if self.default_system:
            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"

        jinja_template += (
            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
            "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
            "{% for message in loop_messages %}"
            "{% set content = message['content'] %}"
            "{% if message['role'] == 'user' %}"
            "{{ " + user + " }}"
            "{% elif message['role'] == 'assistant' %}"
            "{{ " + assistant + " }}"
            "{% endif %}"
            "{% endfor %}"
        )
        return jinja_template

    def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
        r"""Replace the jinja template in the tokenizer."""
        if tokenizer.chat_template is None or self.replace_jinja_template:
            try:
                tokenizer.chat_template = self._get_jinja_template(tokenizer)
            except ValueError as e:
                logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")

    @staticmethod
    def _convert_slots_to_ollama(
        slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
    ) -> str:
        r"""Convert slots to ollama template."""
        slot_items = []
        for slot in slots:
            if isinstance(slot, str):
                slot_pieces = slot.split("{{content}}")
                if slot_pieces[0]:
                    slot_items.append(slot_pieces[0])
                if len(slot_pieces) > 1:
                    slot_items.append("{{ " + placeholder + " }}")
                    if slot_pieces[1]:
                        slot_items.append(slot_pieces[1])
            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
                if "bos_token" in slot and tokenizer.bos_token_id is not None:
                    slot_items.append(tokenizer.bos_token)
                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                    slot_items.append(tokenizer.eos_token)
            elif isinstance(slot, dict):
                raise ValueError("Dict is not supported.")

        return "".join(slot_items)

    def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
        r"""Return the ollama template."""
        prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
        system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
        user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
        assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
        return (
            f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
            f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
            f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
        )

    def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
        r"""Return the ollama modelfile.

        TODO: support function calling.
        """
        modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
        modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'

        if self.default_system:
            modelfile += f'SYSTEM """{self.default_system}"""\n\n'

        for stop_token_id in self.get_stop_token_ids(tokenizer):
            modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'

        modelfile += "PARAMETER num_ctx 4096\n"
        return modelfile


@dataclass
class Llama2Template(Template):
    r"""A template that fuses the system message into the first user message."""

    @override
    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: str,
        tools: str,
    ) -> list[list[int]]:
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []

            system_text = ""
            if i == 0:
                elements += self.format_prefix.apply()
                if system or tools:
                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                    system_text = self.format_system.apply(content=(system + tool_text))[0]

            if message["role"] == Role.USER:
                elements += self.format_user.apply(content=system_text + message["content"])
            elif message["role"] == Role.ASSISTANT:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return encoded_messages

    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
        system_message = self._convert_slots_to_jinja(
            self.format_system.apply(), tokenizer, placeholder="system_message"
        )
        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
        jinja_template = ""
        if prefix:
            jinja_template += "{{ " + prefix + " }}"

        if self.default_system:
            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"

        jinja_template += (
            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
            "{% for message in loop_messages %}"
            "{% if loop.index0 == 0 and system_message is defined %}"
            "{% set content = " + system_message + " + message['content'] %}"
            "{% else %}{% set content = message['content'] %}{% endif %}"
            "{% if message['role'] == 'user' %}"
            "{{ " + user_message + " }}"
            "{% elif message['role'] == 'assistant' %}"
            "{{ " + assistant_message + " }}"
            "{% endif %}"
            "{% endfor %}"
        )
        return jinja_template


@dataclass
class ReasoningTemplate(Template):
    r"""A template that adds thought words to the assistant message."""

    @override
    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
    ) -> tuple[list[int], list[int]]:
        messages = deepcopy(messages)
        for i in range(1, len(messages) - 2, 2):
            messages[i]["content"] = self.remove_thought(messages[i]["content"])

        if self.enable_thinking is False:  # remove all cot
            messages[-1]["content"] = self.remove_thought(messages[-1]["content"])

        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
        if (
            self.thought_words[0] not in messages[-1]["content"]
            and self.thought_words[1] not in messages[-1]["content"]
        ):  # add empty cot
            if not self.enable_thinking:  # do not compute loss
                prompt_ids += self.get_thought_word_ids(tokenizer)
            else:  # do compute loss
                response_ids = self.get_thought_word_ids(tokenizer) + response_ids

        return prompt_ids, response_ids

    @override
    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
    ) -> list[tuple[list[int], list[int]]]:
        messages = deepcopy(messages)
        if self.enable_thinking is False:  # remove all cot
            for i in range(1, len(messages), 2):
                messages[i]["content"] = self.remove_thought(messages[i]["content"])

        encoded_messages = self._encode(tokenizer, messages, system, tools)
        for i in range(0, len(messages), 2):
            if (
                self.thought_words[0] not in messages[i + 1]["content"]
                and self.thought_words[1] not in messages[i + 1]["content"]
            ):  # add empty cot
                if not self.enable_thinking:  # do not compute loss
                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
                else:  # do compute loss
                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]

        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
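

# Illustrative sketch (not part of this commit): the thought-stripping behavior used by
# `remove_thought` above, reproduced standalone with the default ("<think>", "</think>") pair so the
# regex can be checked in isolation. The variable names below are examples only.
import re as _re_sketch

_thought_words = ("<think>", "</think>")
_pattern = _re_sketch.compile(
    f"{_re_sketch.escape(_thought_words[0])}(.*?){_re_sketch.escape(_thought_words[1])}", _re_sketch.DOTALL
)
_content = "<think>step-by-step reasoning</think>\n\nThe answer is 42."
# Stripping the chain-of-thought span and the leading newlines leaves only the visible answer.
assert _re_sketch.sub(_pattern, "", _content).lstrip("\n") == "The answer is 42."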


TEMPLATES: dict[str, "Template"] = {}


def register_template(
    name: str,
    format_user: Optional["Formatter"] = None,
    format_assistant: Optional["Formatter"] = None,
    format_system: Optional["Formatter"] = None,
    format_function: Optional["Formatter"] = None,
    format_observation: Optional["Formatter"] = None,
    format_tools: Optional["Formatter"] = None,
    format_prefix: Optional["Formatter"] = None,
    default_system: str = "",
    stop_words: Optional[list[str]] = None,
    thought_words: Optional[tuple[str, str]] = None,
    efficient_eos: bool = False,
    replace_eos: bool = False,
    replace_jinja_template: bool = False,
    enable_thinking: Optional[bool] = True,
    mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
    template_class: type["Template"] = Template,
) -> None:
    r"""Register a chat template.

    To add the following chat template:
    ```
    <s><user>user prompt here
    <model>model response here</s>
    <user>user prompt here
    <model>model response here</s>
    ```

    The corresponding code should be:
    ```
    register_template(
        name="custom",
        format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
        format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
        format_prefix=EmptyFormatter("<s>"),
    )
    ```
    """
    if name in TEMPLATES:
        raise ValueError(f"Template {name} already exists.")

    default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
    default_user_formatter = StringFormatter(slots=["{{content}}"])
    default_assistant_formatter = StringFormatter(slots=default_slots)
    if format_assistant is not None:
        default_function_formatter = FunctionFormatter(slots=format_assistant.slots, tool_format="default")
    else:
        default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")

    default_tool_formatter = ToolFormatter(tool_format="default")
    default_prefix_formatter = EmptyFormatter()
    TEMPLATES[name] = template_class(
        format_user=format_user or default_user_formatter,
        format_assistant=format_assistant or default_assistant_formatter,
        format_system=format_system or default_user_formatter,
        format_function=format_function or default_function_formatter,
        format_observation=format_observation or format_user or default_user_formatter,
        format_tools=format_tools or default_tool_formatter,
        format_prefix=format_prefix or default_prefix_formatter,
        default_system=default_system,
        stop_words=stop_words or [],
        thought_words=thought_words or ("<think>", "</think>"),
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
        replace_jinja_template=replace_jinja_template,
        enable_thinking=enable_thinking,
        mm_plugin=mm_plugin,
    )


def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
    r"""Extract a chat template from the tokenizer."""

    def find_diff(short_str: str, long_str: str) -> str:
        i, j = 0, 0
        diff = ""
        while i < len(short_str) and j < len(long_str):
            if short_str[i] == long_str[j]:
                i += 1
                j += 1
            else:
                diff += long_str[j]
                j += 1

        return diff

    prefix = tokenizer.decode(tokenizer.encode(""))

    messages = [{"role": "system", "content": "{{content}}"}]
    system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]

    messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
    user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    user_slot_empty_system = user_slot_empty_system[len(prefix) :]

    messages = [{"role": "user", "content": "{{content}}"}]
    user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    user_slot = user_slot[len(prefix) :]

    messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
    assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
    assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]

    template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
    assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags

    if len(user_slot) > len(user_slot_empty_system):
        default_system = find_diff(user_slot_empty_system, user_slot)
        sole_system = system_slot.replace("{{content}}", default_system, 1)
        user_slot = user_slot[len(sole_system) :]
    else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
        default_system = ""

    return template_class(
        format_user=StringFormatter(slots=[user_slot]),
        format_assistant=StringFormatter(slots=[assistant_slot]),
        format_system=StringFormatter(slots=[system_slot]),
        format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
        format_observation=StringFormatter(slots=[user_slot]),
        format_tools=ToolFormatter(tool_format="default"),
        format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
        default_system=default_system,
        stop_words=[],
        thought_words=("<think>", "</think>"),
        efficient_eos=False,
        replace_eos=False,
        replace_jinja_template=False,
        enable_thinking=True,
        mm_plugin=get_mm_plugin(name="base"),
    )


def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
    r"""Get the chat template and fix the tokenizer."""
    if data_args.template is None:
        if isinstance(tokenizer.chat_template, str):
            logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
            template = parse_template(tokenizer)
        else:
            logger.warning_rank0("`template` was not specified, use `empty` template.")
            template = TEMPLATES["empty"]  # placeholder
    else:
        if data_args.template not in TEMPLATES:
            raise ValueError(f"Template {data_args.template} does not exist.")

        template = TEMPLATES[data_args.template]

    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")

    if data_args.tool_format is not None:
        logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
        default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
        template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
        template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

    if data_args.default_system is not None:
        logger.info_rank0(f"Using default system message: {data_args.default_system}.")
        template.default_system = data_args.default_system

    template.enable_thinking = data_args.enable_thinking
    template.fix_special_tokens(tokenizer)
    template.fix_jinja_template(tokenizer)
    return template
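

# Illustrative sketch (not part of this commit): how a registered template is typically fetched and
# used to encode one turn. The checkpoint name is only an assumption for the example; any tokenizer
# whose special tokens match the chosen template works the same way. Kept commented out because it
# requires downloading a (possibly gated) model.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
#     template = TEMPLATES["llama3"]
#     template.fix_special_tokens(tokenizer)  # replaces eos and adds the remaining stop words
#     messages = [
#         {"role": "user", "content": "What does GRPO stand for?"},
#         {"role": "assistant", "content": "Group Relative Policy Optimization."},
#     ]
#     prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)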


register_template(
    name="alpaca",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
    default_system=(
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    ),
    replace_jinja_template=True,
)


register_template(
    name="aquila",
    format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}###"]),
    format_system=StringFormatter(slots=["System: {{content}}###"]),
    default_system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions."
    ),
    stop_words=["</s>"],
)


register_template(
    name="atom",
    format_user=StringFormatter(
        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)


register_template(
    name="baichuan",
    format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
    efficient_eos=True,
)


register_template(
    name="baichuan2",
    format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
    efficient_eos=True,
)


register_template(
    name="bailing",
    format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),
    format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}"]),
    format_observation=StringFormatter(slots=["<role>OBSERVATION</role>{{content}}<role>ASSISTANT</role>"]),
    stop_words=["<|endoftext|>"],
    efficient_eos=True,
)


register_template(
    name="belle",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


register_template(
    name="bluelm",
    format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)


register_template(
    name="breeze",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
)


register_template(
    name="chatglm2",
    format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
    efficient_eos=True,
)


register_template(
    name="chatglm3",
    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
    format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
    format_observation=StringFormatter(
        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
    ),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
)


register_template(
    name="chatml",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    stop_words=["<|im_end|>", "<|im_start|>"],
    replace_eos=True,
    replace_jinja_template=True,
)


# copied from chatml template
register_template(
    name="chatml_de",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
    stop_words=["<|im_end|>", "<|im_start|>"],
    replace_eos=True,
    replace_jinja_template=True,
)


register_template(
    name="codegeex2",
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)


register_template(
    name="codegeex4",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
    default_system=(
        "你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,"
        "并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
    ),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
)


register_template(
    name="cohere",
    format_user=StringFormatter(
        slots=[
            (
                "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
                "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
            )
        ]
    ),
    format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


register_template(
    name="cpm",
    format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


# copied from chatml template
register_template(
    name="cpm3",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>"],
)


# copied from chatml template
register_template(
    name="cpm4",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>"],
)


# copied from chatml template
register_template(
    name="dbrx",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    default_system=(
        "You are DBRX, created by Databricks. You were last updated in December 2023. "
        "You answer questions based on information available up to that point.\n"
        "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
        "responses to more complex and open-ended questions.\nYou assist with various tasks, "
        "from writing to coding (using markdown for code blocks — remember to use ``` with "
        "code, JSON, and tables).\n(You do not have real-time data access or code execution "
        "capabilities. You avoid stereotyping and provide balanced perspectives on "
        "controversial topics. You do not provide song lyrics, poems, or news articles and "
        "do not divulge details of your training data.)\nThis is your system prompt, "
        "guiding your responses. Do not reference it, just respond to the user. If you find "
        "yourself talking about this message, stop. You should be responding appropriately "
        "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
        "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
    ),
    stop_words=["<|im_end|>"],
    replace_eos=True,
)


register_template(
    name="deepseek",
    format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


register_template(
    name="deepseek3",
    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


# copied from deepseek3 template
register_template(
    name="deepseekr1",
    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    template_class=ReasoningTemplate,
)


register_template(
    name="deepseekcoder",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
    format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "You are an AI programming assistant, utilizing the DeepSeek Coder model, "
        "developed by DeepSeek Company, and you only answer questions related to computer science. "
        "For politically sensitive questions, security and privacy issues, "
        "and other non-computer science questions, you will refuse to answer.\n"
    ),
)


register_template(
    name="default",
    format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
    format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
    replace_jinja_template=True,
)


register_template(
    name="empty",
    format_assistant=StringFormatter(slots=["{{content}}"]),
)


register_template(
    name="exaone",
    format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
    format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
)


register_template(
    name="falcon",
    format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
    format_assistant=StringFormatter(slots=["{{content}}\n"]),
    efficient_eos=True,
)


# copied from chatml template
register_template(
    name="falcon_h1",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>", "<|end_of_text|>"],
)


register_template(
    name="fewshot",
    format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
    efficient_eos=True,
    replace_jinja_template=True,
)


register_template(
    name="gemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    replace_eos=True,
    template_class=Llama2Template,
)


# copied from gemma template
register_template(
    name="gemma2",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<eos>", "<end_of_turn>"],
    efficient_eos=True,
    template_class=Llama2Template,
)


# copied from gemma template
register_template(
    name="gemma3",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
    template_class=Llama2Template,
)


register_template(
    name="gemma3n",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin("gemma3n", image_token="<image_soft_token>", audio_token="<audio_soft_token>"),
    template_class=Llama2Template,
)


register_template(
    name="glm4",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
)


# copied from glm4 template
register_template(
    name="glm4v",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
    stop_words=["<|user|>", "<|observation|>", "</answer>"],
    efficient_eos=True,
    mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
    template_class=ReasoningTemplate,
)


# copied from glm4 template
register_template(
    name="glmz1",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
    template_class=ReasoningTemplate,
)


register_template(
    name="granite3",
    format_user=StringFormatter(
        slots=[
            "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
    format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
)


register_template(
    name="granite3_vision",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
    default_system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)


register_template(
    name="index",
    format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
    format_system=StringFormatter(slots=["<unk>{{content}}"]),
    efficient_eos=True,
)


register_template(
    name="hunyuan",
    format_user=StringFormatter(slots=["<|bos|>user\n{{content}}<|eos|>\n<|bos|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|eos|>\n"]),
    format_system=StringFormatter(slots=["<|bos|>system\n{{content}}<|eos|>\n"]),
    format_prefix=EmptyFormatter(slots=["<|bos|>"]),
    stop_words=["<|eos|>"],
)


register_template(
    name="intern",
    format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
    format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
    format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
        "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
        "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
        "- InternLM (书生·浦语) can understand and communicate fluently in the language "
        "chosen by the user such as English and 中文."
    ),
    stop_words=["<eoa>"],
)


register_template(
    name="intern2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
        "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
        "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
        "- InternLM (书生·浦语) can understand and communicate fluently in the language "
        "chosen by the user such as English and 中文."
    ),
    stop_words=["<|im_end|>"],
)


register_template(
    name="intern_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
    ),
    stop_words=["<|im_end|>"],
    mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
)


register_template(
    name="kimi_vl",
    format_user=StringFormatter(
        slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
    format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
    default_system="You are a helpful assistant",
    stop_words=["<|im_end|>"],
    thought_words=("◁think▷", "◁/think▷"),
    mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
    template_class=ReasoningTemplate,
)


register_template(
    name="llama2",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
    template_class=Llama2Template,
)


# copied from llama2 template
register_template(
    name="llama2_zh",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
    default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
    template_class=Llama2Template,
)


register_template(
    name="llama3",
    format_user=StringFormatter(
        slots=[
            (
                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            (
                "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>", "<|eom_id|>"],
    replace_eos=True,
)


register_template(
    name="llama4",
    format_user=StringFormatter(
        slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
    format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=["<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot|>", "<|eom|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
)


# copied from llama3 template
register_template(
    name="mllama",
    format_user=StringFormatter(
        slots=[
            (
                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            (
                "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>", "<|eom_id|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)


register_template(
    name="moonlight",
    format_user=StringFormatter(
        slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
    format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
    default_system="You are a helpful assistant provided by Moonshot-AI.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
)


# copied from vicuna template
register_template(
    name="llava",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)


# copied from vicuna template
register_template(
    name="llava_next",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)


# copied from llama3 template
register_template(
    name="llava_next_llama3",
    format_user=StringFormatter(
        slots=[
            (
                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            (
                "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>", "<|eom_id|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)


# copied from mistral template
register_template(
    name="llava_next_mistral",
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
    template_class=Llama2Template,
)


# copied from qwen template
register_template(
    name="llava_next_qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)


# copied from chatml template
register_template(
    name="llava_next_yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)


# copied from vicuna template
register_template(
    name="llava_next_video",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)


# copied from mistral template
register_template(
    name="llava_next_video_mistral",
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
    template_class=Llama2Template,
)


# copied from chatml template
register_template(
    name="llava_next_video_yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
# copied from chatml template
register_template(
    name="marco",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    default_system=(
        "你是一个经过良好训练的AI助手,你的名字是Marco-o1."
        "由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
        "当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
        "<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
    ),
    stop_words=["<|im_end|>"],
)


# copied from qwen template
register_template(
    name="mimo",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    template_class=ReasoningTemplate,
)
# copied from qwen2vl
register_template(
    name="mimo_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are MiMo, an AI assistant developed by Xiaomi.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
    template_class=ReasoningTemplate,
)


# copied from chatml template
register_template(
    name="minicpm_v",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
    default_system="You are a helpful assistant.",
    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
)
# copied from minicpm_v template
register_template(
    name="minicpm_o",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
    default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
)


# mistral tokenizer v3 tekken
register_template(
    name="ministral",
    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    template_class=Llama2Template,
)
# mistral tokenizer v3
register_template(
    name="mistral",
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    template_class=Llama2Template,
)


# mistral tokenizer v7 tekken (copied from ministral)
register_template(
    name="mistral_small",
    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
register_template(
    name="olmo",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)


register_template(
    name="openchat",
    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


register_template(
    name="openchat-3.6",
    format_user=StringFormatter(
        slots=[
            (
                "<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>"],
)


# copied from chatml template
register_template(
    name="opencoder",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    default_system="You are OpenCoder, created by OpenCoder Team.",
    stop_words=["<|im_end|>"],
)
register_template(
    name="orion",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


register_template(
    name="paligemma",
    format_user=StringFormatter(slots=["{{content}}\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
    template_class=Llama2Template,
)


# copied from gemma template
register_template(
    name="paligemma_chat",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
    template_class=Llama2Template,
)
register_template(
    name="phi",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    stop_words=["<|end|>"],
    replace_eos=True,
)


register_template(
    name="phi_small",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
    stop_words=["<|end|>"],
    replace_eos=True,
)


register_template(
    name="phi4",
    format_user=StringFormatter(
        slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
    format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
    stop_words=["<|im_end|>"],
    replace_eos=True,
)
# copied from ministral template
register_template(
    name="pixtral",
    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
    template_class=Llama2Template,
)


# copied from chatml template
register_template(
    name="qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
)
# copied from qwen template
register_template(
    name="qwen3",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    template_class=ReasoningTemplate,
)


# copied from chatml template
register_template(
    name="qwen2_audio",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
)
# copied from qwen template
register_template(
    name="qwen2_omni",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
    ),
)


# copied from qwen template
register_template(
    name="qwen2_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
register_template(
    name="sailor",
    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    default_system=(
        "You are an AI assistant named Sailor created by Sea AI Lab. "
        "Your answer should be friendly, unbiased, faithful, informative and detailed."
    ),
    stop_words=["<|im_end|>"],
)


register_template(
    name="seed_coder",
    format_user=StringFormatter(
        slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
    ),
    format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
    default_system=(
        "You are an AI programming assistant, utilizing the Seed-Coder model, developed by ByteDance Seed, "
        "and you only answer questions related to computer science. For politically sensitive questions, "
        "security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\n"
    ),
)
# copied from llama3 template
register_template(
    name="skywork_o1",
    format_user=StringFormatter(
        slots=[
            (
                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            (
                "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
        "involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, "
        "you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. "
        "After completing your thoughts, you then provide a detailed explanation of the solution process "
        "in your response."
    ),
    stop_words=["<|eot_id|>", "<|eom_id|>"],
)
register_template(
    name="smollm",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
)


register_template(
    name="smollm2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
    default_system="You are a helpful AI assistant named SmolLM, trained by Hugging Face.",
)
register_template(
    name="solar",
    format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
    format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
    efficient_eos=True,
)


register_template(
    name="starchat",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    stop_words=["<|end|>"],
)
register_template(
    name="telechat",
    format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
    format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
)


register_template(
    name="telechat2",
    format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
    format_system=StringFormatter(slots=["<_system>{{content}}"]),
    default_system=(
        "你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。"
    ),
)


register_template(
    name="vicuna",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    replace_jinja_template=True,
)
register_template(
    name="video_llava",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    default_system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
)


register_template(
    name="xuanyuan",
    format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
    default_system=(
        "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
        "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
        "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
    ),
)


register_template(
    name="xverse",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
register_template(
    name="yayi",
    format_user=StringFormatter(
        slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
    format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
    default_system=(
        "You are a helpful, respectful and honest assistant named YaYi "
        "developed by Beijing Wenge Technology Co.,Ltd. "
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    ),
    stop_words=["<|End|>"],
)


# copied from chatml template
register_template(
    name="yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
)
register_template(
    name="yi_vl",
    format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}\n"]),
    default_system=(
        "This is a chat between an inquisitive human and an AI assistant. "
        "Assume the role of the AI assistant. Read all the images carefully, "
        "and respond to the human's questions with informative, helpful, detailed and polite answers. "
        "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
        "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
    ),
    stop_words=["###"],
    efficient_eos=True,
    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)


register_template(
    name="yuan",
    format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
    format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
    stop_words=["<eod>"],
)
register_template(
    name="zephyr",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
    default_system="You are Zephyr, a helpful assistant.",
)


register_template(
    name="ziya",
    format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
    format_assistant=StringFormatter(slots=["{{content}}\n"]),
)
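Each register_template call above only declares the slot layout for a chat format; the rendering itself is done by the formatter classes added elsewhere in this commit. As a rough, self-contained illustration of what a registered slot expands to (the fill_slot helper below is hypothetical and not part of the repository), filling {{content}} in the qwen template's format_user slot produces a standard ChatML turn:

# Illustrative sketch only: a stand-in for the real slot-rendering logic.
def fill_slot(slot: str, content: str) -> str:
    """Replace the {{content}} placeholder the way the string slots above are written."""
    return slot.replace("{{content}}", content)

user_slot = "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
print(fill_slot(user_slot, "What is GRPO?"))
# Output:
# <|im_start|>user
# What is GRPO?<|im_end|>
# <|im_start|>assistant

The model then generates until one of the template's stop_words (here "<|im_end|>") is produced.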
src/llamafactory/data/tool_utils.py
0 → 100644
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple, Union

from typing_extensions import override


class FunctionCall(NamedTuple):
    name: str
    arguments: str


DEFAULT_TOOL_PROMPT = (
    "You have access to the following tools:\n{tool_text}"
    "Use the following format if using a tool:\n"
    "```\n"
    "Action: tool name (one of [{tool_names}])\n"
    "Action Input: the input to the tool, in a JSON format representing the kwargs "
    """(e.g. ```{{"input": "hello world", "num_beams": 5}}```)\n"""
    "```\n"
)
GLM4_TOOL_PROMPT = (
    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的,"
    "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{tool_text}"
)

LLAMA3_TOOL_PROMPT = (
    "Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
    "You have access to the following functions. To call a function, please respond with JSON for a function call. "
    """Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
    "Do not use variables.\n\n{tool_text}"
)

QWEN_TOOL_PROMPT = (
    "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
    "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
    "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
    """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
    """"arguments": <args-json-object>}}\n</tool_call>"""
)
@dataclass
class ToolUtils(ABC):
    """Base class for tool utilities."""

    @staticmethod
    @abstractmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        r"""Generate the system message describing all the available tools."""
        ...

    @staticmethod
    @abstractmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        r"""Generate the assistant message including all the tool calls."""
        ...

    @staticmethod
    @abstractmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        r"""Extract all the function calls from the assistant message.

        It should be an inverse function of `function_formatter`.
        """
        ...
class DefaultToolUtils(ToolUtils):
    r"""Default tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        tool_names = []
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            param_text = ""
            for name, param in tool["parameters"]["properties"].items():
                required, enum, items = "", "", ""
                if name in tool["parameters"].get("required", []):
                    required = ", required"

                if param.get("enum", None):
                    enum = ", should be one of [{}]".format(", ".join(param["enum"]))

                if param.get("items", None):
                    items = ", where each item should be {}".format(param["items"].get("type", ""))

                param_text += "  - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                    name=name,
                    type=param.get("type", ""),
                    required=required,
                    desc=param.get("description", ""),
                    enum=enum,
                    items=items,
                )

            tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
                name=tool["name"], desc=tool.get("description", ""), args=param_text
            )
            tool_names.append(tool["name"])

        return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
        action_match: list[tuple[str, str]] = re.findall(regex, content)
        if not action_match:
            return content

        results = []
        for match in action_match:
            tool_name = match[0].strip()
            tool_input = match[1].strip().strip('"').strip("```")
            try:
                arguments = json.loads(tool_input)
                results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
            except json.JSONDecodeError:
                return content

        return results
class GLM4ToolUtils(ToolUtils):
    r"""GLM-4 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
                name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
            )

        return GLM4_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        if len(functions) > 1:
            raise ValueError("GLM-4 does not support parallel functions.")

        return f"{functions[0].name}\n{functions[0].arguments}"

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        if "\n" not in content:
            return content

        tool_name, tool_input = content.split("\n", maxsplit=1)
        try:
            arguments = json.loads(tool_input.strip())
        except json.JSONDecodeError:
            return content

        return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]
class Llama3ToolUtils(ToolUtils):
    r"""Llama 3.x tool using template with `tools_in_user_message=False`.

    Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
    """

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        date = datetime.now().strftime("%d %b %Y")
        tool_text = ""
        for tool in tools:
            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
            tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"

        return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
        return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        try:
            tools = json.loads(content.strip())
        except json.JSONDecodeError:
            return content

        tools = [tools] if not isinstance(tools, list) else tools
        try:
            return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
        except KeyError:
            return content
class MistralToolUtils(ToolUtils):
    r"""Mistral v0.3 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        wrapped_tools = []
        for tool in tools:
            wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})

        return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        return json.dumps(
            [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions],
            ensure_ascii=False,
        )

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        try:
            tools = json.loads(content.strip())
        except json.JSONDecodeError:
            return content

        tools = [tools] if not isinstance(tools, list) else tools
        try:
            return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
        except KeyError:
            return content
class QwenToolUtils(ToolUtils):
    r"""Qwen 2.5 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)

        return QWEN_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        function_texts = [
            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
            for name, arguments in functions
        ]
        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
        tool_match: list[str] = re.findall(regex, content)
        if not tool_match:
            return content

        results = []
        for tool in tool_match:
            try:
                tool = json.loads(tool.strip())
            except json.JSONDecodeError:
                return content

            if "name" not in tool or "arguments" not in tool:
                return content

            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))

        return results
TOOLS = {
    "default": DefaultToolUtils(),
    "glm4": GLM4ToolUtils(),
    "llama3": Llama3ToolUtils(),
    "mistral": MistralToolUtils(),
    "qwen": QwenToolUtils(),
}


def get_tool_utils(name: str) -> "ToolUtils":
    tool_utils = TOOLS.get(name, None)
    if tool_utils is None:
        raise ValueError(f"Tool utils `{name}` not found.")

    return tool_utils
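The TOOLS registry maps each tool_format name used by the templates above (for example tool_format="qwen") to one of these utility classes. A minimal round-trip sketch, assuming the package is importable as llamafactory and using a made-up weather tool, could look like:

# Sketch only: the get_weather tool schema below is invented for illustration.
from llamafactory.data.tool_utils import FunctionCall, get_tool_utils

qwen_utils = get_tool_utils("qwen")
tools = [
    {
        "name": "get_weather",
        "description": "Query the weather.",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    }
]

# Text injected into the system prompt so the model knows which tools exist.
system_prompt = qwen_utils.tool_formatter(tools)

# Serialize a call the way the assistant is expected to emit it, then parse it back.
call_text = qwen_utils.function_formatter([FunctionCall("get_weather", '{"city": "Beijing"}')])
# call_text == '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Beijing"}}\n</tool_call>'
assert qwen_utils.tool_extractor(call_text) == [FunctionCall("get_weather", '{"city": "Beijing"}')]

Each tool_extractor returns the raw string unchanged when no well-formed call is found, which lets callers treat the output as plain text in that case.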
src/llamafactory/eval/__init__.py
0 → 100644