OpenDAS / LLaMA-Factory · Commits

Commit 0722acf1, authored Jun 04, 2025 by chenych
Update 0604
Parent: c4ba4563

Showing 8 changed files with 362 additions and 58 deletions (+362 −58)
src/llamafactory/webui/locales.py       +200  −0
src/llamafactory/webui/runner.py        +13   −7
tests/data/test_formatter.py            +27   −5
tests/data/test_mm_plugin.py            +6    −9
tests/data/test_template.py             +98   −27
tests/e2e/test_sglang.py                +1    −1
tests/model/model_utils/test_visual.py  +16   −8
tests/version.txt                       +1    −1
src/llamafactory/webui/locales.py  (view file @ 0722acf1)

@@ -871,6 +871,28 @@ LOCALES = {
             "info": "拡張ブロックのパラメータのみをトレーニングします。",
         },
     },
+    "enable_thinking": {
+        "en": {
+            "label": "Enable thinking",
+            "info": "Whether or not to enable thinking mode for reasoning models.",
+        },
+        "ru": {
+            "label": "Включить режим размышлений",
+            "info": "Включать ли режим размышлений для рассуждающих моделей.",
+        },
+        "zh": {
+            "label": "启用思考模式",
+            "info": "是否启用推理模型的思考模式。",
+        },
+        "ko": {
+            "label": "생각 모드 활성화",
+            "info": "추론 모델의 생각 모드를 활성화할지 여부.",
+        },
+        "ja": {
+            "label": "思考モードを有効化",
+            "info": "推論モデルの思考モードを有効にするかどうか。",
+        },
+    },
     "report_to": {
         "en": {
             "label": "Enable external logger",
@@ -1374,6 +1396,177 @@ LOCALES = {
             "info": "PPO トレーニングにおいて報酬スコアをホワイトニング処理します。",
         },
     },
+    "mm_tab": {
+        "en": {
+            "label": "Multimodal configurations",
+        },
+        "ru": {
+            "label": "Мультимодальные настройки",
+        },
+        "zh": {
+            "label": "多模态参数设置",
+        },
+        "ko": {
+            "label": "멀티모달 구성",
+        },
+        "ja": {
+            "label": "多モーダル設定",
+        },
+    },
+    "freeze_vision_tower": {
+        "en": {
+            "label": "Freeze vision tower",
+            "info": "Freeze the vision tower in the model.",
+        },
+        "ru": {
+            "label": "Заморозить визуальный энкодер",
+            "info": "Заморозить визуальный энкодер в модели.",
+        },
+        "zh": {
+            "label": "冻结视觉编码器",
+            "info": "冻结模型中的视觉编码器。",
+        },
+        "ko": {
+            "label": "비전 타워 고정",
+            "info": "모델의 비전 타워를 고정합니다.",
+        },
+        "ja": {
+            "label": "ビジョンタワーの固定",
+            "info": "モデルのビジョンタワーを固定します。",
+        },
+    },
+    "freeze_multi_modal_projector": {
+        "en": {
+            "label": "Freeze multi-modal projector",
+            "info": "Freeze the multi-modal projector in the model.",
+        },
+        "ru": {
+            "label": "Заморозить мультимодальный проектор",
+            "info": "Заморозить мультимодальный проектор в модели.",
+        },
+        "zh": {
+            "label": "冻结多模态投影器",
+            "info": "冻结模型中的多模态投影器。",
+        },
+        "ko": {
+            "label": "멀티모달 프로젝터 고정",
+            "info": "모델의 멀티모달 프로젝터를 고정합니다.",
+        },
+        "ja": {
+            "label": "多モーダルプロジェクターの固定",
+            "info": "モデルの多モーダルプロジェクターを固定します。",
+        },
+    },
+    "freeze_language_model": {
+        "en": {
+            "label": "Freeze language model",
+            "info": "Freeze the language model in the model.",
+        },
+        "ru": {
+            "label": "Заморозить языковую модель",
+            "info": "Заморозить языковую модель в модели.",
+        },
+        "zh": {
+            "label": "冻结语言模型",
+            "info": "冻结模型中的语言模型。",
+        },
+        "ko": {
+            "label": "언어 모델 고정",
+            "info": "모델의 언어 모델을 고정합니다.",
+        },
+        "ja": {
+            "label": "言語モデルの固定",
+            "info": "モデルの言語モデルを固定します。",
+        },
+    },
+    "image_max_pixels": {
+        "en": {
+            "label": "Image max pixels",
+            "info": "The maximum number of pixels of image inputs.",
+        },
+        "ru": {
+            "label": "Максимальное количество пикселей изображения",
+            "info": "Максимальное количество пикселей изображения.",
+        },
+        "zh": {
+            "label": "图像最大像素",
+            "info": "输入图像的最大像素数。",
+        },
+        "ko": {
+            "label": "이미지 최대 픽셀",
+            "info": "이미지 입력의 최대 픽셀 수입니다.",
+        },
+        "ja": {
+            "label": "画像最大ピクセル",
+            "info": "画像入力の最大ピクセル数です。",
+        },
+    },
+    "image_min_pixels": {
+        "en": {
+            "label": "Image min pixels",
+            "info": "The minimum number of pixels of image inputs.",
+        },
+        "ru": {
+            "label": "Минимальное количество пикселей изображения",
+            "info": "Минимальное количество пикселей изображения.",
+        },
+        "zh": {
+            "label": "图像最小像素",
+            "info": "输入图像的最小像素数。",
+        },
+        "ko": {
+            "label": "이미지 최소 픽셀",
+            "info": "이미지 입력의 최소 픽셀 수입니다.",
+        },
+        "ja": {
+            "label": "画像最小ピクセル",
+            "info": "画像入力の最小ピクセル数です。",
+        },
+    },
+    "video_max_pixels": {
+        "en": {
+            "label": "Video max pixels",
+            "info": "The maximum number of pixels of video inputs.",
+        },
+        "ru": {
+            "label": "Максимальное количество пикселей видео",
+            "info": "Максимальное количество пикселей видео.",
+        },
+        "zh": {
+            "label": "视频最大像素",
+            "info": "输入视频的最大像素数。",
+        },
+        "ko": {
+            "label": "비디오 최대 픽셀",
+            "info": "비디오 입력의 최대 픽셀 수입니다.",
+        },
+        "ja": {
+            "label": "ビデオ最大ピクセル",
+            "info": "ビデオ入力の最大ピクセル数です。",
+        },
+    },
+    "video_min_pixels": {
+        "en": {
+            "label": "Video min pixels",
+            "info": "The minimum number of pixels of video inputs.",
+        },
+        "ru": {
+            "label": "Минимальное количество пикселей видео",
+            "info": "Минимальное количество пикселей видео.",
+        },
+        "zh": {
+            "label": "视频最小像素",
+            "info": "输入视频的最小像素数。",
+        },
+        "ko": {
+            "label": "비디오 최소 픽셀",
+            "info": "비디오 입력의 최소 픽셀 수입니다.",
+        },
+        "ja": {
+            "label": "ビデオ最小ピクセル",
+            "info": "ビデオ入力の最小ピクセル数です。",
+        },
+    },
     "galore_tab": {
         "en": {
             "label": "GaLore configurations",
@@ -2779,6 +2972,13 @@ ALERTS = {
         "ko": "출력 디렉토리가 이미 존재합니다. 위 출력 디렉토리에 저장된 학습을 재개합니다.",
         "ja": "出力ディレクトリが既に存在します。このチェックポイントからトレーニングを再開します。",
     },
+    "warn_no_instruct": {
+        "en": "You are using a non-instruct model, please fine-tune it first.",
+        "ru": "Вы используете модель без инструкций, пожалуйста, сначала выполните её донастройку.",
+        "zh": "您正在使用非指令模型,请先对其进行微调。",
+        "ko": "지시 튜닝되지 않은 모델을 사용 중입니다. 먼저 미세 조정해 주세요.",
+        "ja": "インストラクションモデルを使用していません。まずファインチューニングしてください。",
+    },
     "info_aborting": {
         "en": "Aborted, wait for terminating...",
         "ru": "Прервано, ожидание завершения...",
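Each LOCALES entry above keys a WebUI component id to per-language label/info strings. A minimal sketch of how such an entry could drive a Gradio widget (the `create_checkbox` helper and the lookup shape are illustrative assumptions, not the project's actual API):

import gradio as gr

# Trimmed copy of one entry from locales.py above.
LOCALES = {
    "enable_thinking": {
        "en": {
            "label": "Enable thinking",
            "info": "Whether or not to enable thinking mode for reasoning models.",
        },
    },
}

def create_checkbox(elem_id: str, lang: str = "en") -> gr.Checkbox:
    # Hypothetical helper: resolve the localized label/info for a component
    # id and build the corresponding widget.
    locale = LOCALES[elem_id][lang]
    return gr.Checkbox(label=locale["label"], info=locale["info"], elem_id=elem_id)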
src/llamafactory/webui/runner.py  (view file @ 0722acf1)
@@ -22,13 +22,14 @@ from typing import TYPE_CHECKING, Any, Optional
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_npu_available

-from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
+from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_accelerator_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
     DEFAULT_CONFIG_DIR,
     abort_process,
+    calculate_pixels,
     gen_cmd,
     get_save_dir,
     load_args,

@@ -162,6 +163,7 @@ class Runner:
             mask_history=get("train.mask_history"),
             resize_vocab=get("train.resize_vocab"),
             use_llama_pro=get("train.use_llama_pro"),
+            enable_thinking=get("train.enable_thinking"),
             report_to=get("train.report_to"),
             use_galore=get("train.use_galore"),
             use_apollo=get("train.use_apollo"),

@@ -235,6 +237,16 @@ class Runner:
         args["pref_ftx"] = get("train.pref_ftx")
         args["pref_loss"] = get("train.pref_loss")

+        # multimodal config
+        if model_name in MULTIMODAL_SUPPORTED_MODELS:
+            args["freeze_vision_tower"] = get("train.freeze_vision_tower")
+            args["freeze_multi_modal_projector"] = get("train.freeze_multi_modal_projector")
+            args["freeze_language_model"] = get("train.freeze_language_model")
+            args["image_max_pixels"] = calculate_pixels(get("train.image_max_pixels"))
+            args["image_min_pixels"] = calculate_pixels(get("train.image_min_pixels"))
+            args["video_max_pixels"] = calculate_pixels(get("train.video_max_pixels"))
+            args["video_min_pixels"] = calculate_pixels(get("train.video_min_pixels"))
+
         # galore config
         if args["use_galore"]:
             args["galore_rank"] = get("train.galore_rank")

@@ -256,12 +268,6 @@ class Runner:
         args["badam_switch_interval"] = get("train.badam_switch_interval")
         args["badam_update_ratio"] = get("train.badam_update_ratio")

-        # report_to
-        if "none" in args["report_to"]:
-            args["report_to"] = "none"
-        elif "all" in args["report_to"]:
-            args["report_to"] = "all"
-
         # swanlab config
         if get("train.use_swanlab"):
             args["swanlab_project"] = get("train.swanlab_project")
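`calculate_pixels` is imported from `.common` and applied to the raw WebUI pixel fields before they are written into the training args. Its implementation is not part of this diff; a plausible sketch, assuming the UI field holds either a bare integer or a "width*height" expression (an assumption, not the confirmed behavior):

def calculate_pixels(value: str) -> int:
    # Hypothetical reimplementation: accept "W*H" expressions or bare
    # integers coming from the WebUI text field.
    if "*" in value:
        width, height = value.split("*")
        return int(width) * int(height)
    return int(value)

# e.g. calculate_pixels("768*768") == 589824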
tests/data/test_formatter.py  (view file @ 0722acf1)
@@ -50,7 +50,7 @@ def test_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]

@@ -60,7 +60,7 @@ def test_multi_function_formatter():
     tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]

@@ -85,7 +85,7 @@ def test_default_tool_formatter():
 def test_default_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
-    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
+    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]

@@ -93,7 +93,7 @@ def test_default_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
     result = (
         """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
+        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}"""
     )
     assert formatter.extract(result) == [
         ("test_tool", """{"foo": "bar", "size": 10}"""),

@@ -125,12 +125,22 @@ def test_glm4_tool_extractor():
 def test_llama3_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
-    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+    tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
         """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
     ]


+def test_llama3_multi_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
+    tool_calls = json.dumps([FUNCTION] * 2)
+    assert formatter.apply(content=tool_calls) == [
+        """[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
+        """<|eot_id|>"""
+    ]
+
+
 def test_llama3_tool_formatter():
     formatter = ToolFormatter(tool_format="llama3")
     date = datetime.now().strftime("%d %b %Y")

@@ -150,6 +160,18 @@ def test_llama3_tool_extractor():
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


+def test_llama3_multi_tool_extractor():
+    formatter = ToolFormatter(tool_format="llama3")
+    result = (
+        """[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
+    )
+    assert formatter.extract(result) == [
+        ("test_tool", """{"foo": "bar", "size": 10}"""),
+        ("another_tool", """{"foo": "job", "size": 2}"""),
+    ]
+
+
 def test_mistral_function_formatter():
     formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
     tool_calls = json.dumps(FUNCTION)
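The `FUNCTION` fixture used throughout these tests is defined outside the visible hunks; judging from the literal it replaces in `test_llama3_function_formatter`, it is consistent with:

FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}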
tests/data/test_mm_plugin.py  (view file @ 0722acf1)
@@ -135,8 +135,7 @@ def _check_plugin(
     expected_mm_inputs: dict[str, Any] = {},
     expected_no_mm_inputs: dict[str, Any] = {},
 ) -> None:
-    # test omni_messages
-    if plugin.__class__.__name__ == "Qwen2OmniPlugin":
+    if plugin.__class__.__name__ == "Qwen2OmniPlugin":
+        # test omni_messages
         assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
         assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
             expected_input_ids,

@@ -146,8 +145,7 @@ def _check_plugin(
             plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
             expected_mm_inputs,
         )
-    # test mm_messages
-    if plugin.__class__.__name__ != "BasePlugin":
+    elif plugin.__class__.__name__ != "BasePlugin":
+        # test mm_messages
         assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
         assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
             expected_input_ids,

@@ -201,7 +199,7 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_internvl_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="OpenGVLab/InternVL3-1B-hf")

@@ -219,7 +217,7 @@ def test_internvl_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
 def test_llama4_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
     processor = tokenizer_module["processor"]

@@ -321,10 +319,9 @@ def test_pixtral_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_qwen2_omni_plugin():
-    image_seqlen = 4
-    audio_seqlen = 2
+    image_seqlen, audio_seqlen = 4, 2
     tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
     qwen2_omni_plugin = get_mm_plugin(
         name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
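The diff swaps blanket `xfail(reason="Unknown error.")` marks for version-gated skips. A minimal sketch of that pattern, assuming `is_transformers_version_greater_than` compares a floor version against the installed package (the real helper lives in `llamafactory.extras.packages` and is not shown in this diff):

import pytest
import transformers
from packaging import version

def is_transformers_version_greater_than(target: str) -> bool:
    # Assumed shape of the helper: compare the installed version to a floor.
    return version.parse(transformers.__version__) >= version.parse(target)

@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_requires_new_transformers():
    ...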
tests/data/test_template.py  (view file @ 0722acf1)
@@ -125,6 +125,61 @@ def test_encode_multiturn(use_fast: bool):
     )


+@pytest.mark.parametrize("use_fast", [True, False])
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False, None])
+def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
+    prompt_str = (
+        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        f"{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    )
+    if not cot_messages or enable_thinking is False:
+        answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
+        if enable_thinking:
+            answer_str = "<think>\n\n</think>\n\n" + answer_str
+        else:
+            prompt_str = prompt_str + "<think>\n\n</think>\n\n"
+    else:
+        answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
+
+    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
+
+
+@pytest.mark.parametrize("use_fast", [True, False])
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False, None])
+def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
+    messages = MESSAGES if not cot_messages or enable_thinking is False else MESSAGES_WITH_THOUGHT
+    prompt_str_1 = f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
+    prompt_str_2 = f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
+    if not cot_messages or enable_thinking is False:
+        if enable_thinking:
+            answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
+            answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
+        else:
+            prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
+            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
+
+    _check_tokenization(
+        tokenizer,
+        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
+        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
+    )
+
+
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_jinja_template(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)

@@ -162,12 +217,12 @@ def test_get_stop_token_ids():
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_gemma_template(use_fast: bool):
     prompt_str = (
-        "<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
-        "<start_of_turn>model\nI am fine!<end_of_turn>\n"
-        "<start_of_turn>user\n你好<end_of_turn>\n"
+        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
+        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
+        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
         "<start_of_turn>model\n"
     )
-    answer_str = "很高兴认识你!<end_of_turn>\n"
+    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
     _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)

@@ -175,12 +230,12 @@ def test_gemma_template(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_llama3_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
-        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
-        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
+        f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
+        f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot_id|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)

@@ -189,52 +244,64 @@ def test_llama3_template(use_fast: bool):
 )
 def test_llama4_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
-        "<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
-        "<|header_start|>user<|header_end|>\n\n你好<|eot|>"
+        f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
+        f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
+        f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
         "<|header_start|>assistant<|header_end|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot|>"
     _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)


 @pytest.mark.parametrize(
-    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
+    "use_fast",
+    [
+        pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
+        pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
+    ],
 )
 def test_phi4_template(use_fast: bool):
     prompt_str = (
-        "<|im_start|>user<|im_sep|>How are you<|im_end|>"
-        "<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
-        "<|im_start|>user<|im_sep|>你好<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
+        f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
         "<|im_start|>assistant<|im_sep|>"
     )
-    answer_str = "很高兴认识你!<|im_end|>"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)


 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_qwen2_5_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>\n"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
     _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)


 @pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen3_template(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+def test_qwen3_template(use_fast: bool, cot_messages: bool):
     prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
+    if not cot_messages:
+        answer_str = f"<think>\n\n</think>\n\n{MESSAGES[3]['content']}<|im_end|>\n"
+        messages = MESSAGES
+    else:
+        answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
+        messages = MESSAGES_WITH_THOUGHT
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)


 def test_parse_llama3_template():

@@ -250,9 +317,11 @@ def test_parse_llama3_template():
     assert template.default_system == ""


 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 def test_parse_qwen_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
     template = parse_template(tokenizer)
     assert template.__class__.__name__ == "Template"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]

@@ -260,9 +329,11 @@ def test_parse_qwen_template():
     assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."


 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 def test_parse_qwen3_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
     template = parse_template(tokenizer)
     assert template.__class__.__name__ == "ReasoningTemplate"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
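Distilling the branch logic the new reasoning tests assert: when the sample carries no chain of thought, or thinking is explicitly disabled, an empty `<think>` block is injected, and where it lands depends on `enable_thinking`. A standalone restatement of that rule (illustrative names, not the project's API):

EMPTY_THINK = "<think>\n\n</think>\n\n"

def place_empty_think(prompt: str, answer: str, enable_thinking, has_thought: bool):
    # Mirrors the test expectations: enable_thinking=True puts the empty
    # think block at the start of the answer, while False/None appends it
    # to the prompt; samples with a real thought are left untouched.
    if not has_thought or enable_thinking is False:
        if enable_thinking:
            answer = EMPTY_THINK + answer
        else:
            prompt = prompt + EMPTY_THINK
    return prompt, answer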
tests/e2e/test_sglang.py  (view file @ 0722acf1)
@@ -20,7 +20,7 @@ from llamafactory.chat import ChatModel
 from llamafactory.extras.packages import is_sglang_available


-MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+MODEL_NAME = "Qwen/Qwen2.5-0.5B"
 INFER_ARGS = {
tests/model/model_utils/test_visual.py  (view file @ 0722acf1)
@@ -16,6 +16,7 @@ import pytest
 import torch
 from transformers import AutoConfig, AutoModelForVision2Seq

+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import FinetuningArguments, ModelArguments
 from llamafactory.model.adapter import init_adapter

@@ -45,10 +46,12 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
     assert param.requires_grad != freeze_language_model


-@pytest.mark.parametrize("freeze_vision_tower", (False, True))
-def test_visual_lora(freeze_vision_tower: bool):
+@pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False)))
+def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
     model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
-    finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower)
+    finetuning_args = FinetuningArguments(
+        finetuning_type="lora", freeze_vision_tower=freeze_vision_tower, freeze_language_model=freeze_language_model
+    )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
         model = AutoModelForVision2Seq.from_config(config)

@@ -61,10 +64,15 @@ def test_visual_lora(freeze_vision_tower: bool):
         else:
             frozen_params.add(name)

-    if freeze_vision_tower:
-        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" not in trainable_params
+    if is_transformers_version_greater_than("4.52.0"):
+        visual_param_name = "base_model.model.model.visual.blocks.0.attn.qkv.lora_A.default.weight"
+        language_param_name = "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.default.weight"
+        merger_param_name = "base_model.model.model.visual.merger.lora_A.default.weight"
     else:
-        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" in trainable_params
+        visual_param_name = "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight"
+        language_param_name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
+        merger_param_name = "base_model.model.visual.merger.lora_A.default.weight"

-    assert "merger" not in trainable_params
-    assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight" in trainable_params
+    assert (visual_param_name in trainable_params) != freeze_vision_tower
+    assert (language_param_name in trainable_params) != freeze_language_model
+    assert (merger_param_name in trainable_params) is False
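For reference, the `trainable_params` / `frozen_params` sets the assertions rely on are built from each parameter's `requires_grad` flag; a sketch of that collection loop, consistent with the `else: frozen_params.add(name)` context visible above (the function wrapper is illustrative):

import torch

def split_params(model: "torch.nn.Module"):
    # Collect trainable vs. frozen parameter names by requires_grad flag,
    # matching how the test builds trainable_params / frozen_params.
    trainable_params, frozen_params = set(), set()
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.add(name)
        else:
            frozen_params.add(name)
    return trainable_params, frozen_params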
tests/version.txt  (view file @ 0722acf1)

 # change if test fails or cache is outdated
-0.9.3.106
+0.9.3.107