OpenDAS / LLaMA-Factory · Commits

Commit 24534501, authored May 21, 2025 by mashun1
parent c4ba4563

    parallel_tool

Changes: 63
Showing 20 changed files with 497 additions and 122 deletions (+497 / -122)
src/llamafactory/model/model_utils/rope.py       +29  -19
src/llamafactory/model/model_utils/visual.py      +3   -2
src/llamafactory/model/patcher.py                 +3   -3
src/llamafactory/train/dpo/trainer.py             +2   -2
src/llamafactory/train/kto/trainer.py             +2   -3
src/llamafactory/train/pt/trainer.py              +2   -2
src/llamafactory/train/pt/workflow.py            +16   -5
src/llamafactory/train/rm/trainer.py              +2   -2
src/llamafactory/train/sft/trainer.py             +2   -2
src/llamafactory/webui/chatter.py                +42  -31
src/llamafactory/webui/common.py                  +8   -0
src/llamafactory/webui/components/chatbot.py      +3   -0
src/llamafactory/webui/components/top.py          +2   -2
src/llamafactory/webui/components/train.py       +41   -3
src/llamafactory/webui/control.py                +11   -0
src/llamafactory/webui/locales.py               +200   -0
src/llamafactory/webui/runner.py                  +9   -6
tests/data/test_formatter.py                     +27   -5
tests/data/test_mm_plugin.py                      +6   -9
tests/data/test_template.py                      +87  -26
src/llamafactory/model/model_utils/rope.py

```diff
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if model_args.rope_scaling is None:
         return
@@ -40,30 +40,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return

-    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
-    if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
+    if hasattr(config, "max_position_embeddings"):
+        old_max_length = getattr(config, "max_position_embeddings", None)
+    else:
+        logger.warning_rank0("Cannot find the max position embeddings in the config.")
+        return
+
+    if model_args.model_max_length is not None:  # training
+        if model_args.model_max_length <= old_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
+
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )

-        current_max_length = getattr(config, "max_position_embeddings", None)
-        if (not current_max_length) or model_args.model_max_length <= current_max_length:
-            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
-            return
-
-        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-        setattr(config, "max_position_embeddings", model_args.model_max_length)
-        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-
-        if model_args.rope_scaling == RopeScaling.DYNAMIC:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == RopeScaling.LLAMA3:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-            rope_kwargs["low_freq_factor"] = 1.0
-            rope_kwargs["high_freq_factor"] = 4.0
-    else:
-        rope_kwargs["factor"] = 2.0
+        rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
+    else:  # inference
+        rope_factor = 2.0
+
+    rope_kwargs = {
+        "rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling),  # handle enum
+        "factor": rope_factor,
+    }
+    setattr(config, "max_position_embeddings", old_max_length * rope_factor)
+    logger.info_rank0(f"Enlarge max model length from {old_max_length} to {old_max_length * rope_factor}.")
+
+    if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+    elif model_args.rope_scaling == RopeScaling.LLAMA3:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+        rope_kwargs["low_freq_factor"] = 1.0
+        rope_kwargs["high_freq_factor"] = 4.0

     setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
```
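The rewritten `configure_rope` derives a single `rope_factor` and enlarges `max_position_embeddings` by that factor, rather than overwriting it with `model_max_length` directly. A standalone sketch of just that arithmetic (this is an illustration, not the LLaMA-Factory function itself):

```python
import math
from typing import Optional


def rope_factor(original_max_length: int, target_max_length: Optional[int]) -> float:
    """Factor used to scale max_position_embeddings in the new code path."""
    if target_max_length is not None:  # training: scale up to the requested length
        return float(math.ceil(target_max_length / original_max_length))
    return 2.0  # inference default


print(rope_factor(32768, 65536))  # 2.0 -> max_position_embeddings becomes 65536
print(rope_factor(4096, 10000))   # 3.0 -> max_position_embeddings becomes 12288
```

Because the factor is rounded up, the resulting context length is always a whole multiple of the model's original `max_position_embeddings` and never smaller than the requested `model_max_length`.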
src/llamafactory/model/model_utils/visual.py

```diff
@@ -24,6 +24,7 @@ import transformers.models
 from transformers.activations import ACT2FN

 from ...extras import logging
+from ...extras.packages import is_transformers_version_greater_than


 if TYPE_CHECKING:
@@ -281,7 +282,7 @@ _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
@@ -290,6 +291,6 @@ _register_composite_model(
     model_type="qwen2_5_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
```
src/llamafactory/model/patcher.py

```diff
@@ -85,8 +85,8 @@ def patch_processor(
         setattr(processor, "video_min_pixels", model_args.video_min_pixels)
         setattr(processor, "video_fps", model_args.video_fps)
         setattr(processor, "video_maxlen", model_args.video_maxlen)
-        setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
         setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
+        setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)


 def patch_config(
@@ -102,8 +102,8 @@ def patch_config(
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

-    configure_attn_implementation(config, model_args, is_trainable)
-    configure_rope(config, model_args, is_trainable)
+    configure_attn_implementation(config, model_args)
+    configure_rope(config, model_args)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
```
src/llamafactory/train/dpo/trainer.py

```diff
@@ -121,11 +121,11 @@ class CustomDPOTrainer(DPOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def get_batch_samples(self, *args, **kwargs):
```
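The same `_get_train_sampler` signature change is applied to the KTO, PT, RM, and SFT trainers below. It appears to be a forward-compatibility fix: if the base `Trainer` starts calling `_get_train_sampler` with an extra argument (as newer transformers releases do), a zero-argument override would raise a `TypeError`. A minimal sketch of the forwarding pattern, using a stand-in base class rather than the real transformers `Trainer`:

```python
from typing import Any, Optional


class Base:
    # Stand-in for transformers.Trainer; some versions call
    # _get_train_sampler(dataset) instead of _get_train_sampler().
    def _get_train_sampler(self, *args: Any, **kwargs: Any) -> Optional[str]:
        return "random"


class Custom(Base):
    disable_shuffling = False

    # Forwarding *args/**kwargs keeps the override valid for both calling conventions.
    def _get_train_sampler(self, *args: Any, **kwargs: Any) -> Optional[str]:
        if self.disable_shuffling:
            return "sequential"
        return super()._get_train_sampler(*args, **kwargs)


print(Custom()._get_train_sampler())           # "random"
print(Custom()._get_train_sampler("dataset"))  # "random" -- extra arg is forwarded, not an error
```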
src/llamafactory/train/kto/trainer.py

```diff
@@ -34,7 +34,6 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, ge
 if TYPE_CHECKING:
     import torch.utils.data
     from transformers import PreTrainedModel, ProcessorMixin

     from ...hparams import FinetuningArguments
@@ -119,12 +118,12 @@ class CustomKTOTrainer(KTOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return Trainer._get_train_sampler(self)
+        return Trainer._get_train_sampler(self, *args, **kwargs)

     @override
     def get_batch_samples(self, *args, **kwargs):
```
src/llamafactory/train/pt/trainer.py

```diff
@@ -70,11 +70,11 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
```
src/llamafactory/train/pt/workflow.py

```diff
@@ -77,12 +77,23 @@ def run_pt(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
-        try:
-            perplexity = math.exp(metrics["eval_loss"])
-        except OverflowError:
-            perplexity = float("inf")
-
-        metrics["perplexity"] = perplexity
+        if isinstance(dataset_module.get("eval_dataset"), dict):
+            for key in dataset_module["eval_dataset"].keys():
+                try:
+                    perplexity = math.exp(metrics[f"eval_{key}_loss"])
+                except OverflowError:
+                    perplexity = float("inf")
+
+                metrics[f"eval_{key}_perplexity"] = perplexity
+        else:
+            try:
+                perplexity = math.exp(metrics["eval_loss"])
+            except OverflowError:
+                perplexity = float("inf")
+
+            metrics["eval_perplexity"] = perplexity

         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
```
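With this change, pre-training evaluation handles multiple eval splits: when `eval_dataset` is a dict, one perplexity is reported per split under `eval_<name>_perplexity`, and the single-dataset key moves from `perplexity` to `eval_perplexity`. A standalone sketch of the metric-key logic with made-up loss values (not real trainer output):

```python
import math

# Hypothetical evaluation output for two eval splits named "wiki" and "code".
metrics = {"eval_wiki_loss": 2.1, "eval_code_loss": 1.3}
eval_dataset = {"wiki": object(), "code": object()}  # stand-in for the dataset dict

if isinstance(eval_dataset, dict):
    for key in eval_dataset:
        try:
            perplexity = math.exp(metrics[f"eval_{key}_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics[f"eval_{key}_perplexity"] = perplexity
else:
    metrics["eval_perplexity"] = math.exp(metrics["eval_loss"])

print(round(metrics["eval_wiki_perplexity"], 2))  # 8.17
print(round(metrics["eval_code_perplexity"], 2))  # 3.67
```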
src/llamafactory/train/rm/trainer.py

```diff
@@ -78,11 +78,11 @@ class PairwiseTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(
```
src/llamafactory/train/sft/trainer.py

```diff
@@ -92,11 +92,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
```
src/llamafactory/webui/chatter.py

```diff
@@ -15,6 +15,7 @@
 import json
 import os
 from collections.abc import Generator
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional

 from transformers.utils import is_torch_npu_available
@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple
     )


+@contextmanager
+def update_attr(obj: Any, name: str, value: Any):
+    old_value = getattr(obj, name, None)
+    setattr(obj, name, value)
+    yield
+    setattr(obj, name, old_value)
+
+
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
@@ -191,40 +200,42 @@ class WebChatModel(ChatModel):
         temperature: float,
         skip_special_tokens: bool,
         escape_html: bool,
+        enable_thinking: bool,
     ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
         r"""Generate output text in stream.

         Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
         Output: infer.chatbot, infer.messages
         """
-        chatbot.append({"role": "assistant", "content": ""})
-        response = ""
-        for new_text in self.stream_chat(
-            messages,
-            system,
-            tools,
-            images=[image] if image else None,
-            videos=[video] if video else None,
-            audios=[audio] if audio else None,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            temperature=temperature,
-            skip_special_tokens=skip_special_tokens,
-        ):
-            response += new_text
-            if tools:
-                result = self.engine.template.extract_tool(response)
-            else:
-                result = response
-
-            if isinstance(result, list):
-                tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
-                tool_calls = json.dumps(tool_calls, ensure_ascii=False)
-                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
-                bot_text = "```json\n" + tool_calls + "\n```"
-            else:
-                output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
-                bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
-
-            chatbot[-1] = {"role": "assistant", "content": bot_text}
-            yield chatbot, output_messages
+        with update_attr(self.engine.template, "enable_thinking", enable_thinking):
+            chatbot.append({"role": "assistant", "content": ""})
+            response = ""
+            for new_text in self.stream_chat(
+                messages,
+                system,
+                tools,
+                images=[image] if image else None,
+                videos=[video] if video else None,
+                audios=[audio] if audio else None,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                temperature=temperature,
+                skip_special_tokens=skip_special_tokens,
+            ):
+                response += new_text
+                if tools:
+                    result = self.engine.template.extract_tool(response)
+                else:
+                    result = response
+
+                if isinstance(result, list):
+                    tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
+                    tool_calls = json.dumps(tool_calls, ensure_ascii=False)
+                    output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
+                    bot_text = "```json\n" + tool_calls + "\n```"
+                else:
+                    output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
+                    bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
+
+                chatbot[-1] = {"role": "assistant", "content": bot_text}
+                yield chatbot, output_messages
```
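The new `update_attr` context manager temporarily overrides an attribute for the duration of the streamed generation, so the per-request "enable thinking" checkbox does not permanently change the shared template object. A small standalone illustration of the same pattern, using a plain object instead of the actual chat template:

```python
from contextlib import contextmanager
from typing import Any


@contextmanager
def update_attr(obj: Any, name: str, value: Any):
    old_value = getattr(obj, name, None)  # remember the previous value
    setattr(obj, name, value)             # apply the temporary override
    yield
    setattr(obj, name, old_value)         # restore on normal exit


class Template:
    enable_thinking = True  # stand-in for the shared chat template


template = Template()
with update_attr(template, "enable_thinking", False):
    print(template.enable_thinking)  # False while handling this request
print(template.enable_thinking)      # True again afterwards
```

Note that, as written, the attribute is restored only on normal exit; wrapping the `yield` in a try/finally would also restore it if generation raises.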
src/llamafactory/webui/common.py

```diff
@@ -205,6 +205,14 @@ def load_eval_results(path: os.PathLike) -> str:
     return f"```json\n{result}\n```\n"


+def calculate_pixels(pixels: str) -> int:
+    r"""Calculate the number of pixels from the expression."""
+    if "*" in pixels:
+        return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
+    else:
+        return int(pixels)
+
+
 def create_ds_config() -> None:
     r"""Create deepspeed config in the current directory."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
```
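`calculate_pixels` converts the textbox expressions used by the new multimodal tab (for example `768*768`) into an integer pixel count before `runner.py` passes them to the training arguments. A quick usage sketch of the same logic:

```python
def calculate_pixels(pixels: str) -> int:
    """Parse either a plain integer or a `width*height` expression."""
    if "*" in pixels:
        return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
    return int(pixels)


print(calculate_pixels("768*768"))  # 589824
print(calculate_pixels("32*32"))    # 1024
print(calculate_pixels("4096"))     # 4096
```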
src/llamafactory/webui/components/chatbot.py

```diff
@@ -79,6 +79,7 @@ def create_chat_box(
             temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
             skip_special_tokens = gr.Checkbox(value=True)
             escape_html = gr.Checkbox(value=True)
+            enable_thinking = gr.Checkbox(value=True)
             clear_btn = gr.Button()

     tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
@@ -103,6 +104,7 @@ def create_chat_box(
             temperature,
             skip_special_tokens,
             escape_html,
+            enable_thinking,
         ],
         [chatbot, messages],
     )
@@ -127,6 +129,7 @@ def create_chat_box(
             temperature=temperature,
             skip_special_tokens=skip_special_tokens,
             escape_html=escape_html,
+            enable_thinking=enable_thinking,
             clear_btn=clear_btn,
         ),
     )
```
src/llamafactory/webui/components/top.py

```diff
@@ -18,7 +18,7 @@ from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
 from ..common import save_config
-from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints
+from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints


 if is_gradio_available():
@@ -49,7 +49,7 @@ def create_top() -> dict[str, "Component"]:
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
-    )
+    ).then(check_template, [lang, template])
     model_name.input(save_config, inputs=[lang, model_name], queue=False)
     model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
```
src/llamafactory/webui/components/train.py

```diff
@@ -106,11 +106,11 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
                 use_llama_pro = gr.Checkbox()

             with gr.Column():
+                enable_thinking = gr.Checkbox(value=True)
                 report_to = gr.Dropdown(
-                    choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
-                    value=["none"],
+                    choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
+                    value="none",
                     allow_custom_value=True,
                     multiselect=True,
                 )

         input_elems.update(
@@ -126,6 +126,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
                 mask_history,
                 resize_vocab,
                 use_llama_pro,
+                enable_thinking,
                 report_to,
             }
         )
@@ -143,6 +144,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
                 mask_history=mask_history,
                 resize_vocab=resize_vocab,
                 use_llama_pro=use_llama_pro,
+                enable_thinking=enable_thinking,
                 report_to=report_to,
             )
         )
@@ -231,6 +233,42 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             )
         )

+        with gr.Accordion(open=False) as mm_tab:
+            with gr.Row():
+                freeze_vision_tower = gr.Checkbox(value=True)
+                freeze_multi_modal_projector = gr.Checkbox(value=True)
+                freeze_language_model = gr.Checkbox(value=False)
+
+            with gr.Row():
+                image_max_pixels = gr.Textbox(value="768*768")
+                image_min_pixels = gr.Textbox(value="32*32")
+                video_max_pixels = gr.Textbox(value="256*256")
+                video_min_pixels = gr.Textbox(value="16*16")
+
+        input_elems.update(
+            {
+                freeze_vision_tower,
+                freeze_multi_modal_projector,
+                freeze_language_model,
+                image_max_pixels,
+                image_min_pixels,
+                video_max_pixels,
+                video_min_pixels,
+            }
+        )
+        elem_dict.update(
+            dict(
+                mm_tab=mm_tab,
+                freeze_vision_tower=freeze_vision_tower,
+                freeze_multi_modal_projector=freeze_multi_modal_projector,
+                freeze_language_model=freeze_language_model,
+                image_max_pixels=image_max_pixels,
+                image_min_pixels=image_min_pixels,
+                video_max_pixels=video_max_pixels,
+                video_min_pixels=video_min_pixels,
+            )
+        )
+
         with gr.Accordion(open=False) as galore_tab:
             with gr.Row():
                 use_galore = gr.Checkbox()
```
src/llamafactory/webui/control.py

```diff
@@ -84,6 +84,17 @@ def get_model_info(model_name: str) -> tuple[str, str]:
     return get_model_path(model_name), get_template(model_name)


+def check_template(lang: str, template: str) -> None:
+    r"""Check if an instruct model is used.
+
+    Please use queue=True to show the warning message.
+
+    Inputs: top.lang, top.template
+    """
+    if template == "default":
+        gr.Warning(ALERTS["warn_no_instruct"][lang])
+
+
 def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
     r"""Get training infomation for monitor.
```
src/llamafactory/webui/locales.py

```diff
@@ -871,6 +871,28 @@ LOCALES = {
             "info": "拡張ブロックのパラメータのみをトレーニングします。",
         },
     },
+    "enable_thinking": {
+        "en": {"label": "Enable thinking", "info": "Whether or not to enable thinking mode for reasoning models."},
+        "ru": {"label": "Включить мысли", "info": "Включить режим мысли для моделей решающего характера."},
+        "zh": {"label": "启用思考模式", "info": "是否启用推理模型的思考模式。"},
+        "ko": {"label": "생각 모드 활성화", "info": "추론 모델의 생각 모드를 활성화할지 여부."},
+        "ja": {"label": "思考モードを有効化", "info": "推論モデルの思考モードを有効にするかどうか。"},
+    },
     "report_to": {
         "en": {
             "label": "Enable external logger",
@@ -1374,6 +1396,177 @@ LOCALES = {
             "info": "PPO トレーニングにおいて報酬スコアをホワイトニング処理します。",
         },
     },
+    "mm_tab": {
+        "en": {"label": "Multimodal configurations"},
+        "ru": {"label": "Конфигурации мультимедиа"},
+        "zh": {"label": "多模态参数设置"},
+        "ko": {"label": "멀티모달 구성"},
+        "ja": {"label": "多モーダル設定"},
+    },
+    "freeze_vision_tower": {
+        "en": {"label": "Freeze vision tower", "info": "Freeze the vision tower in the model."},
+        "ru": {"label": "Заморозить башню визиона", "info": "Заморозить башню визиона в модели."},
+        "zh": {"label": "冻结视觉编码器", "info": "冻结模型中的视觉编码器。"},
+        "ko": {"label": "비전 타워 고정", "info": "모델의 비전 타워를 고정합니다."},
+        "ja": {"label": "ビジョンタワーの固定", "info": "モデルのビジョンタワーを固定します。"},
+    },
+    "freeze_multi_modal_projector": {
+        "en": {"label": "Freeze multi-modal projector", "info": "Freeze the multi-modal projector in the model."},
+        "ru": {"label": "Заморозить мультимодальный проектор", "info": "Заморозить мультимодальный проектор в модели."},
+        "zh": {"label": "冻结多模态投影器", "info": "冻结模型中的多模态投影器。"},
+        "ko": {"label": "멀티모달 프로젝터 고정", "info": "모델의 멀티모달 프로젝터를 고정합니다."},
+        "ja": {"label": "多モーダルプロジェクターの固定", "info": "モデルの多モーダルプロジェクターを固定します。"},
+    },
+    "freeze_language_model": {
+        "en": {"label": "Freeze language model", "info": "Freeze the language model in the model."},
+        "ru": {"label": "Заморозить язык модели", "info": "Заморозить язык модели в модели."},
+        "zh": {"label": "冻结语言模型", "info": "冻结模型中的语言模型。"},
+        "ko": {"label": "언어 모델 고정", "info": "모델의 언어 모델을 고정합니다."},
+        "ja": {"label": "言語モデルの固定", "info": "モデルの言語モデルを固定します。"},
+    },
+    "image_max_pixels": {
+        "en": {"label": "Image max pixels", "info": "The maximum number of pixels of image inputs."},
+        "ru": {"label": "Максимальное количество пикселей изображения", "info": "Максимальное количество пикселей изображения."},
+        "zh": {"label": "图像最大像素", "info": "输入图像的最大像素数。"},
+        "ko": {"label": "이미지 최대 픽셀", "info": "이미지 입력의 최대 픽셀 수입니다."},
+        "ja": {"label": "画像最大ピクセル", "info": "画像入力の最大ピクセル数です。"},
+    },
+    "image_min_pixels": {
+        "en": {"label": "Image min pixels", "info": "The minimum number of pixels of image inputs."},
+        "ru": {"label": "Минимальное количество пикселей изображения", "info": "Минимальное количество пикселей изображения."},
+        "zh": {"label": "图像最小像素", "info": "输入图像的最小像素数。"},
+        "ko": {"label": "이미지 최소 픽셀", "info": "이미지 입력의 최소 픽셀 수입니다."},
+        "ja": {"label": "画像最小ピクセル", "info": "画像入力の最小ピクセル数です。"},
+    },
+    "video_max_pixels": {
+        "en": {"label": "Video max pixels", "info": "The maximum number of pixels of video inputs."},
+        "ru": {"label": "Максимальное количество пикселей видео", "info": "Максимальное количество пикселей видео."},
+        "zh": {"label": "视频最大像素", "info": "输入视频的最大像素数。"},
+        "ko": {"label": "비디오 최대 픽셀", "info": "비디오 입력의 최대 픽셀 수입니다."},
+        "ja": {"label": "ビデオ最大ピクセル", "info": "ビデオ入力の最大ピクセル数です。"},
+    },
+    "video_min_pixels": {
+        "en": {"label": "Video min pixels", "info": "The minimum number of pixels of video inputs."},
+        "ru": {"label": "Минимальное количество пикселей видео", "info": "Минимальное количество пикселей видео."},
+        "zh": {"label": "视频最小像素", "info": "输入视频的最小像素数。"},
+        "ko": {"label": "비디오 최소 픽셀", "info": "비디오 입력의 최소 픽셀 수입니다."},
+        "ja": {"label": "ビデオ最小ピクセル", "info": "ビデオ入力の最小ピクセル数です。"},
+    },
     "galore_tab": {
         "en": {
             "label": "GaLore configurations",
@@ -2779,6 +2972,13 @@ ALERTS = {
         "ko": "출력 디렉토리가 이미 존재합니다. 위 출력 디렉토리에 저장된 학습을 재개합니다.",
         "ja": "出力ディレクトリが既に存在します。このチェックポイントからトレーニングを再開します。",
     },
+    "warn_no_instruct": {
+        "en": "You are using a non-instruct model, please fine-tune it first.",
+        "ru": "Вы используете модель без инструкции, пожалуйста, primeros выполните донастройку этой модели.",
+        "zh": "您正在使用非指令模型,请先对其进行微调。",
+        "ko": "당신은 지시하지 않은 모델을 사용하고 있습니다. 먼저 이를 미세 조정해 주세요.",
+        "ja": "インストラクションモデルを使用していません。まずモデルをアダプターに適合させてください。",
+    },
     "info_aborting": {
         "en": "Aborted, wait for terminating...",
         "ru": "Прервано, ожидание завершения...",
```
src/llamafactory/webui/runner.py

```diff
@@ -29,6 +29,7 @@ from .common import (
     DEFAULT_CACHE_DIR,
     DEFAULT_CONFIG_DIR,
     abort_process,
+    calculate_pixels,
     gen_cmd,
     get_save_dir,
     load_args,
@@ -162,7 +163,15 @@ class Runner:
             mask_history=get("train.mask_history"),
             resize_vocab=get("train.resize_vocab"),
             use_llama_pro=get("train.use_llama_pro"),
+            enable_thinking=get("train.enable_thinking"),
             report_to=get("train.report_to"),
+            freeze_vision_tower=get("train.freeze_vision_tower"),
+            freeze_multi_modal_projector=get("train.freeze_multi_modal_projector"),
+            freeze_language_model=get("train.freeze_language_model"),
+            image_max_pixels=calculate_pixels(get("train.image_max_pixels")),
+            image_min_pixels=calculate_pixels(get("train.image_min_pixels")),
+            video_max_pixels=calculate_pixels(get("train.video_max_pixels")),
+            video_min_pixels=calculate_pixels(get("train.video_min_pixels")),
             use_galore=get("train.use_galore"),
             use_apollo=get("train.use_apollo"),
             use_badam=get("train.use_badam"),
@@ -256,12 +265,6 @@ class Runner:
             args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")

-        # report_to
-        if "none" in args["report_to"]:
-            args["report_to"] = "none"
-        elif "all" in args["report_to"]:
-            args["report_to"] = "all"
-
         # swanlab config
         if get("train.use_swanlab"):
             args["swanlab_project"] = get("train.swanlab_project")
```
tests/data/test_formatter.py

```diff
@@ -50,7 +50,7 @@ def test_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]
@@ -60,7 +60,7 @@ def test_multi_function_formatter():
     tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]
@@ -85,7 +85,7 @@ def test_default_tool_formatter():
 def test_default_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
-    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
+    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@@ -93,7 +93,7 @@ def test_default_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
     result = (
         """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
+        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}"""
     )
     assert formatter.extract(result) == [
         ("test_tool", """{"foo": "bar", "size": 10}"""),
@@ -125,12 +125,22 @@ def test_glm4_tool_extractor():
 def test_llama3_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
-    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+    tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
         """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
     ]


+def test_llama3_multi_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
+    tool_calls = json.dumps([FUNCTION] * 2)
+    assert formatter.apply(content=tool_calls) == [
+        """[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
+        """<|eot_id|>"""
+    ]
+
+
 def test_llama3_tool_formatter():
     formatter = ToolFormatter(tool_format="llama3")
     date = datetime.now().strftime("%d %b %Y")
@@ -150,6 +160,18 @@ def test_llama3_tool_extractor():
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


+def test_llama3_multi_tool_extractor():
+    formatter = ToolFormatter(tool_format="llama3")
+    result = (
+        """[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
+    )
+    assert formatter.extract(result) == [
+        ("test_tool", """{"foo": "bar", "size": 10}"""),
+        ("another_tool", """{"foo": "job", "size": 2}"""),
+    ]
+
+
 def test_mistral_function_formatter():
     formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
     tool_calls = json.dumps(FUNCTION)
```
tests/data/test_mm_plugin.py

```diff
@@ -135,8 +135,7 @@ def _check_plugin(
     expected_mm_inputs: dict[str, Any] = {},
     expected_no_mm_inputs: dict[str, Any] = {},
 ) -> None:
-    # test omni_messages
-    if plugin.__class__.__name__ == "Qwen2OmniPlugin":
+    if plugin.__class__.__name__ == "Qwen2OmniPlugin":  # test omni_messages
         assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
         assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
             expected_input_ids,
@@ -146,8 +145,7 @@ def _check_plugin(
             plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
             expected_mm_inputs,
         )
-    # test mm_messages
-    if plugin.__class__.__name__ != "BasePlugin":
+    elif plugin.__class__.__name__ != "BasePlugin":  # test mm_messages
         assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
         assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
             expected_input_ids,
@@ -201,7 +199,7 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_internvl_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="OpenGVLab/InternVL3-1B-hf")
@@ -219,7 +217,7 @@ def test_internvl_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
 def test_llama4_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
     processor = tokenizer_module["processor"]
@@ -321,10 +319,9 @@ def test_pixtral_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_qwen2_omni_plugin():
-    image_seqlen = 4
-    audio_seqlen = 2
+    image_seqlen, audio_seqlen = 4, 2
     tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
     qwen2_omni_plugin = get_mm_plugin(
         name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
```
tests/data/test_template.py

```diff
@@ -125,6 +125,60 @@ def test_encode_multiturn(use_fast: bool):
     )


+@pytest.mark.parametrize("use_fast", [True, False])
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False, None])
+def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+    input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, input_messages)
+    output_messages = MESSAGES if enable_thinking is False else input_messages
+    prompt_str = (
+        f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        f"{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    )
+    answer_str = f"{output_messages[3]['content']}<|im_end|>\n"
+    if not cot_messages or enable_thinking is False:
+        if enable_thinking:
+            answer_str = "<think>\n\n</think>\n\n" + answer_str
+        else:
+            prompt_str = prompt_str + "<think>\n\n</think>\n\n"
+
+    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
+
+
+@pytest.mark.parametrize("use_fast", [True, False])
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False, None])
+def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+    input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    encoded_pairs = template.encode_multiturn(tokenizer, input_messages)
+    output_messages = MESSAGES if enable_thinking is False else input_messages
+    prompt_str_1 = f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    answer_str_1 = f"{output_messages[1]['content']}<|im_end|>\n"
+    prompt_str_2 = f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    answer_str_2 = f"{output_messages[3]['content']}<|im_end|>\n"
+    if not cot_messages or enable_thinking is False:
+        if enable_thinking:
+            answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
+            answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
+        else:
+            prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
+            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
+
+    _check_tokenization(
+        tokenizer,
+        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
+        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
+    )
+
+
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_jinja_template(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -162,12 +216,12 @@ def test_get_stop_token_ids():
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_gemma_template(use_fast: bool):
     prompt_str = (
-        "<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
-        "<start_of_turn>model\nI am fine!<end_of_turn>\n"
-        "<start_of_turn>user\n你好<end_of_turn>\n"
+        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
+        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
+        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
         "<start_of_turn>model\n"
     )
-    answer_str = "很高兴认识你!<end_of_turn>\n"
+    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
     _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
@@ -175,12 +229,12 @@ def test_gemma_template(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_llama3_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
-        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
-        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
+        f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
+        f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot_id|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
@@ -189,12 +243,12 @@ def test_llama3_template(use_fast: bool):
 )
 def test_llama4_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
-        "<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
-        "<|header_start|>user<|header_end|>\n\n你好<|eot|>"
+        f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
+        f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
+        f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
         "<|header_start|>assistant<|header_end|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot|>"
     _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
@@ -203,12 +257,12 @@ def test_llama4_template(use_fast: bool):
 )
 def test_phi4_template(use_fast: bool):
     prompt_str = (
-        "<|im_start|>user<|im_sep|>How are you<|im_end|>"
-        "<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
-        "<|im_start|>user<|im_sep|>你好<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
+        f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
         "<|im_start|>assistant<|im_sep|>"
     )
-    answer_str = "很高兴认识你!<|im_end|>"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@@ -216,25 +270,30 @@ def test_phi4_template(use_fast: bool):
 def test_qwen2_5_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>\n"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
     _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)


 @pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen3_template(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+def test_qwen3_template(use_fast: bool, cot_messages: bool):
+    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
     prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
+    answer_str = f"{messages[3]['content']}<|im_end|>\n"
+    if not cot_messages:
+        answer_str = "<think>\n\n</think>\n\n" + answer_str
+
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)


 def test_parse_llama3_template():
@@ -253,6 +312,7 @@ def test_parse_llama3_template():
 def test_parse_qwen_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
     template = parse_template(tokenizer)
+    assert template.__class__.__name__ == "Template"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
@@ -263,6 +323,7 @@ def test_parse_qwen_template():
 def test_parse_qwen3_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
     template = parse_template(tokenizer)
+    assert template.__class__.__name__ == "ReasoningTemplate"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
```
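The new reasoning tests encode one rule: when the sample carries no chain-of-thought, or thinking is disabled, the qwen3 template still produces an empty `<think>\n\n</think>\n\n` block, attached to the answer when `enable_thinking` is true and appended to the prompt otherwise. A small illustration of the expected strings (plain string handling only, no tokenizer involved):

```python
# Hypothetical assistant reply without a chain-of-thought.
prompt = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n"
answer = "很高兴认识你!<|im_end|>\n"

empty_think = "<think>\n\n</think>\n\n"
enable_thinking = True

if enable_thinking:
    answer = empty_think + answer  # the model is expected to emit the empty think block
else:
    prompt = prompt + empty_think  # the prompt pre-fills it so the model skips thinking

print(repr(answer))
```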