Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
LLaMA-Factory
Commits
ca625f43
Commit
ca625f43
authored
Mar 30, 2026
by
shihm
Browse files
uodata
parent
7164651d
Changes
327
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1127 additions
and
15 deletions
+1127
-15
tests/data/test_mm_plugin.py
tests/data/test_mm_plugin.py
+51
-1
tests/data/test_template.py
tests/data/test_template.py
+17
-0
tests/e2e/test_chat.py
tests/e2e/test_chat.py
+4
-0
tests/e2e/test_sglang.py
tests/e2e/test_sglang.py
+2
-0
tests/e2e/test_train.py
tests/e2e/test_train.py
+2
-0
tests/eval/test_eval_template.py
tests/eval/test_eval_template.py
+4
-0
tests/model/model_utils/test_attention.py
tests/model/model_utils/test_attention.py
+11
-1
tests/model/test_base.py
tests/model/test_base.py
+1
-6
tests/model/test_lora.py
tests/model/test_lora.py
+0
-6
tests/version.txt
tests/version.txt
+1
-1
tests_v1/accelerator/test_interface.py
tests_v1/accelerator/test_interface.py
+59
-0
tests_v1/config/test_args_parser.py
tests_v1/config/test_args_parser.py
+71
-0
tests_v1/conftest.py
tests_v1/conftest.py
+161
-0
tests_v1/core/test_data_engine.py
tests_v1/core/test_data_engine.py
+36
-0
tests_v1/core/test_data_loader.py
tests_v1/core/test_data_loader.py
+173
-0
tests_v1/core/test_model_loader.py
tests_v1/core/test_model_loader.py
+51
-0
tests_v1/core/utils/test_batching.py
tests_v1/core/utils/test_batching.py
+52
-0
tests_v1/core/utils/test_rendering.py
tests_v1/core/utils/test_rendering.py
+243
-0
tests_v1/plugins/data_plugins/test_converter.py
tests_v1/plugins/data_plugins/test_converter.py
+125
-0
tests_v1/plugins/model_plugins/test_init_plugin.py
tests_v1/plugins/model_plugins/test_init_plugin.py
+63
-0
No files found.
tests/data/test_mm_plugin.py
View file @
ca625f43
...
@@ -56,10 +56,17 @@ TEXT_MESSAGES = [
...
@@ -56,10 +56,17 @@ TEXT_MESSAGES = [
{
"role"
:
"assistant"
,
"content"
:
"I am fine!"
},
{
"role"
:
"assistant"
,
"content"
:
"I am fine!"
},
]
]
VIDEO_MESSAGES
=
[
{
"role"
:
"user"
,
"content"
:
"<video>What is in this viode?"
},
{
"role"
:
"assistant"
,
"content"
:
"A cat."
},
]
AUDIOS
=
[
np
.
zeros
(
1600
)]
AUDIOS
=
[
np
.
zeros
(
1600
)]
IMAGES
=
[
Image
.
new
(
"RGB"
,
(
32
,
32
),
(
255
,
255
,
255
))]
IMAGES
=
[
Image
.
new
(
"RGB"
,
(
32
,
32
),
(
255
,
255
,
255
))]
VIDEOS
=
[[
Image
.
new
(
"RGB"
,
(
32
,
32
),
(
255
,
255
,
255
))]
*
4
]
NO_IMAGES
=
[]
NO_IMAGES
=
[]
NO_VIDEOS
=
[]
NO_VIDEOS
=
[]
...
@@ -145,6 +152,8 @@ def _check_plugin(
...
@@ -145,6 +152,8 @@ def _check_plugin(
plugin
.
get_mm_inputs
(
IMAGES
,
NO_VIDEOS
,
AUDIOS
,
IMGLENS
,
NO_VIDLENS
,
AUDLENS
,
BATCH_IDS
,
processor
),
plugin
.
get_mm_inputs
(
IMAGES
,
NO_VIDEOS
,
AUDIOS
,
IMGLENS
,
NO_VIDLENS
,
AUDLENS
,
BATCH_IDS
,
processor
),
expected_mm_inputs
,
expected_mm_inputs
,
)
)
elif
plugin
.
__class__
.
__name__
==
"Qwen3VLPlugin"
:
# only check replacement
assert
plugin
.
process_messages
(
VIDEO_MESSAGES
,
NO_IMAGES
,
VIDEOS
,
NO_AUDIOS
,
processor
)
==
expected_mm_messages
elif
plugin
.
__class__
.
__name__
!=
"BasePlugin"
:
# test mm_messages
elif
plugin
.
__class__
.
__name__
!=
"BasePlugin"
:
# test mm_messages
assert
plugin
.
process_messages
(
MM_MESSAGES
,
IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
processor
)
==
expected_mm_messages
assert
plugin
.
process_messages
(
MM_MESSAGES
,
IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
processor
)
==
expected_mm_messages
assert
plugin
.
process_token_ids
(
INPUT_IDS
,
LABELS
,
IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
tokenizer
,
processor
)
==
(
assert
plugin
.
process_token_ids
(
INPUT_IDS
,
LABELS
,
IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
tokenizer
,
processor
)
==
(
...
@@ -170,6 +179,7 @@ def _check_plugin(
...
@@ -170,6 +179,7 @@ def _check_plugin(
)
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_base_plugin
():
def
test_base_plugin
():
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
TINY_LLAMA3
)
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
TINY_LLAMA3
)
base_plugin
=
get_mm_plugin
(
name
=
"base"
)
base_plugin
=
get_mm_plugin
(
name
=
"base"
)
...
@@ -177,6 +187,7 @@ def test_base_plugin():
...
@@ -177,6 +187,7 @@ def test_base_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.50.0"
),
reason
=
"Requires transformers>=4.50.0"
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.50.0"
),
reason
=
"Requires transformers>=4.50.0"
)
def
test_gemma3_plugin
():
def
test_gemma3_plugin
():
...
@@ -199,6 +210,7 @@ def test_gemma3_plugin():
...
@@ -199,6 +210,7 @@ def test_gemma3_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.52.0"
),
reason
=
"Requires transformers>=4.52.0"
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.52.0"
),
reason
=
"Requires transformers>=4.52.0"
)
def
test_internvl_plugin
():
def
test_internvl_plugin
():
image_seqlen
=
256
image_seqlen
=
256
...
@@ -217,6 +229,7 @@ def test_internvl_plugin():
...
@@ -217,6 +229,7 @@ def test_internvl_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.51.0"
),
reason
=
"Requires transformers>=4.51.0"
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.51.0"
),
reason
=
"Requires transformers>=4.51.0"
)
def
test_llama4_plugin
():
def
test_llama4_plugin
():
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
TINY_LLAMA4
)
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
TINY_LLAMA4
)
...
@@ -238,6 +251,7 @@ def test_llama4_plugin():
...
@@ -238,6 +251,7 @@ def test_llama4_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_llava_plugin
():
def
test_llava_plugin
():
image_seqlen
=
576
image_seqlen
=
576
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"llava-hf/llava-1.5-7b-hf"
)
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"llava-hf/llava-1.5-7b-hf"
)
...
@@ -251,6 +265,7 @@ def test_llava_plugin():
...
@@ -251,6 +265,7 @@ def test_llava_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_llava_next_plugin
():
def
test_llava_next_plugin
():
image_seqlen
=
1176
image_seqlen
=
1176
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"llava-hf/llava-v1.6-vicuna-7b-hf"
)
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"llava-hf/llava-v1.6-vicuna-7b-hf"
)
...
@@ -264,6 +279,7 @@ def test_llava_next_plugin():
...
@@ -264,6 +279,7 @@ def test_llava_next_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_llava_next_video_plugin
():
def
test_llava_next_video_plugin
():
image_seqlen
=
1176
image_seqlen
=
1176
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"llava-hf/LLaVA-NeXT-Video-7B-hf"
)
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"llava-hf/LLaVA-NeXT-Video-7B-hf"
)
...
@@ -277,6 +293,7 @@ def test_llava_next_video_plugin():
...
@@ -277,6 +293,7 @@ def test_llava_next_video_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
def
test_paligemma_plugin
():
def
test_paligemma_plugin
():
image_seqlen
=
256
image_seqlen
=
256
...
@@ -296,6 +313,7 @@ def test_paligemma_plugin():
...
@@ -296,6 +313,7 @@ def test_paligemma_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.50.0"
),
reason
=
"Requires transformers>=4.50.0"
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.50.0"
),
reason
=
"Requires transformers>=4.50.0"
)
def
test_pixtral_plugin
():
def
test_pixtral_plugin
():
image_slice_height
,
image_slice_width
=
2
,
2
image_slice_height
,
image_slice_width
=
2
,
2
...
@@ -318,12 +336,20 @@ def test_pixtral_plugin():
...
@@ -318,12 +336,20 @@ def test_pixtral_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.52.0"
),
reason
=
"Requires transformers>=4.52.0"
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.52.0"
),
reason
=
"Requires transformers>=4.52.0"
)
def
test_qwen2_omni_plugin
():
def
test_qwen2_omni_plugin
():
image_seqlen
,
audio_seqlen
=
4
,
2
image_seqlen
,
audio_seqlen
=
4
,
2
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"Qwen/Qwen2.5-Omni-7B"
)
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"Qwen/Qwen2.5-Omni-7B"
)
qwen2_omni_plugin
=
get_mm_plugin
(
qwen2_omni_plugin
=
get_mm_plugin
(
name
=
"qwen2_omni"
,
audio_token
=
"<|AUDIO|>"
,
image_token
=
"<|IMAGE|>"
,
video_token
=
"<|VIDEO|>"
name
=
"qwen2_omni"
,
image_token
=
"<|IMAGE|>"
,
video_token
=
"<|VIDEO|>"
,
audio_token
=
"<|AUDIO|>"
,
vision_bos_token
=
"<|vision_bos|>"
,
vision_eos_token
=
"<|vision_eos|>"
,
audio_bos_token
=
"<|audio_bos|>"
,
audio_eos_token
=
"<|audio_eos|>"
,
)
)
check_inputs
=
{
"plugin"
:
qwen2_omni_plugin
,
**
tokenizer_module
}
check_inputs
=
{
"plugin"
:
qwen2_omni_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
check_inputs
[
"expected_mm_messages"
]
=
[
...
@@ -341,6 +367,7 @@ def test_qwen2_omni_plugin():
...
@@ -341,6 +367,7 @@ def test_qwen2_omni_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_qwen2_vl_plugin
():
def
test_qwen2_vl_plugin
():
image_seqlen
=
4
image_seqlen
=
4
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"Qwen/Qwen2-VL-7B-Instruct"
)
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"Qwen/Qwen2-VL-7B-Instruct"
)
...
@@ -357,6 +384,29 @@ def test_qwen2_vl_plugin():
...
@@ -357,6 +384,29 @@ def test_qwen2_vl_plugin():
_check_plugin
(
**
check_inputs
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.57.0"
),
reason
=
"Requires transformers>=4.57.0"
)
def
test_qwen3_vl_plugin
():
frame_seqlen
=
1
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"Qwen/Qwen3-VL-30B-A3B-Instruct"
)
qwen3_vl_plugin
=
get_mm_plugin
(
name
=
"qwen3_vl"
,
video_token
=
"<|video_pad|>"
)
check_inputs
=
{
"plugin"
:
qwen3_vl_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<video>"
,
# little different with original processor for default `fps=2` in our repo
"<0.2 seconds><|vision_start|>{}<|vision_end|><1.2 seconds><|vision_start|>{}<|vision_end|>"
.
format
(
"<|video_pad|>"
*
frame_seqlen
,
"<|video_pad|>"
*
frame_seqlen
),
)
for
key
,
value
in
message
.
items
()
}
for
message
in
VIDEO_MESSAGES
]
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.47.0"
),
reason
=
"Requires transformers>=4.47.0"
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.47.0"
),
reason
=
"Requires transformers>=4.47.0"
)
def
test_video_llava_plugin
():
def
test_video_llava_plugin
():
image_seqlen
=
256
image_seqlen
=
256
...
...
tests/data/test_template.py
View file @
ca625f43
...
@@ -89,6 +89,7 @@ def _check_template(
...
@@ -89,6 +89,7 @@ def _check_template(
_check_tokenization
(
tokenizer
,
(
prompt_ids
,
answer_ids
),
(
prompt_str
,
answer_str
))
_check_tokenization
(
tokenizer
,
(
prompt_ids
,
answer_ids
),
(
prompt_str
,
answer_str
))
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_encode_oneturn
(
use_fast
:
bool
):
def
test_encode_oneturn
(
use_fast
:
bool
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
...
@@ -104,6 +105,7 @@ def test_encode_oneturn(use_fast: bool):
...
@@ -104,6 +105,7 @@ def test_encode_oneturn(use_fast: bool):
_check_tokenization
(
tokenizer
,
(
prompt_ids
,
answer_ids
),
(
prompt_str
,
answer_str
))
_check_tokenization
(
tokenizer
,
(
prompt_ids
,
answer_ids
),
(
prompt_str
,
answer_str
))
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_encode_multiturn
(
use_fast
:
bool
):
def
test_encode_multiturn
(
use_fast
:
bool
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
...
@@ -125,6 +127,7 @@ def test_encode_multiturn(use_fast: bool):
...
@@ -125,6 +127,7 @@ def test_encode_multiturn(use_fast: bool):
)
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"cot_messages"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"cot_messages"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_thinking"
,
[
True
,
False
,
None
])
@
pytest
.
mark
.
parametrize
(
"enable_thinking"
,
[
True
,
False
,
None
])
...
@@ -151,6 +154,7 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi
...
@@ -151,6 +154,7 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi
_check_tokenization
(
tokenizer
,
(
prompt_ids
,
answer_ids
),
(
prompt_str
,
answer_str
))
_check_tokenization
(
tokenizer
,
(
prompt_ids
,
answer_ids
),
(
prompt_str
,
answer_str
))
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"cot_messages"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"cot_messages"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_thinking"
,
[
True
,
False
,
None
])
@
pytest
.
mark
.
parametrize
(
"enable_thinking"
,
[
True
,
False
,
None
])
...
@@ -180,6 +184,7 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t
...
@@ -180,6 +184,7 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t
)
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_jinja_template
(
use_fast
:
bool
):
def
test_jinja_template
(
use_fast
:
bool
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
...
@@ -190,6 +195,7 @@ def test_jinja_template(use_fast: bool):
...
@@ -190,6 +195,7 @@ def test_jinja_template(use_fast: bool):
assert
tokenizer
.
apply_chat_template
(
MESSAGES
)
==
ref_tokenizer
.
apply_chat_template
(
MESSAGES
)
assert
tokenizer
.
apply_chat_template
(
MESSAGES
)
==
ref_tokenizer
.
apply_chat_template
(
MESSAGES
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_ollama_modelfile
():
def
test_ollama_modelfile
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
"llama3"
))
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
"llama3"
))
...
@@ -207,12 +213,14 @@ def test_ollama_modelfile():
...
@@ -207,12 +213,14 @@ def test_ollama_modelfile():
)
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_get_stop_token_ids
():
def
test_get_stop_token_ids
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
"llama3"
))
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
"llama3"
))
assert
set
(
template
.
get_stop_token_ids
(
tokenizer
))
==
{
128008
,
128009
}
assert
set
(
template
.
get_stop_token_ids
(
tokenizer
))
==
{
128008
,
128009
}
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_gemma_template
(
use_fast
:
bool
):
def
test_gemma_template
(
use_fast
:
bool
):
...
@@ -226,6 +234,7 @@ def test_gemma_template(use_fast: bool):
...
@@ -226,6 +234,7 @@ def test_gemma_template(use_fast: bool):
_check_template
(
"google/gemma-3-4b-it"
,
"gemma"
,
prompt_str
,
answer_str
,
use_fast
)
_check_template
(
"google/gemma-3-4b-it"
,
"gemma"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_gemma2_template
(
use_fast
:
bool
):
def
test_gemma2_template
(
use_fast
:
bool
):
...
@@ -239,6 +248,7 @@ def test_gemma2_template(use_fast: bool):
...
@@ -239,6 +248,7 @@ def test_gemma2_template(use_fast: bool):
_check_template
(
"google/gemma-2-2b-it"
,
"gemma2"
,
prompt_str
,
answer_str
,
use_fast
)
_check_template
(
"google/gemma-2-2b-it"
,
"gemma2"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_llama3_template
(
use_fast
:
bool
):
def
test_llama3_template
(
use_fast
:
bool
):
...
@@ -252,6 +262,7 @@ def test_llama3_template(use_fast: bool):
...
@@ -252,6 +262,7 @@ def test_llama3_template(use_fast: bool):
_check_template
(
"meta-llama/Meta-Llama-3-8B-Instruct"
,
"llama3"
,
prompt_str
,
answer_str
,
use_fast
)
_check_template
(
"meta-llama/Meta-Llama-3-8B-Instruct"
,
"llama3"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
pytest
.
param
(
False
,
marks
=
pytest
.
mark
.
xfail
(
reason
=
"Llama 4 has no slow tokenizer."
))]
"use_fast"
,
[
True
,
pytest
.
param
(
False
,
marks
=
pytest
.
mark
.
xfail
(
reason
=
"Llama 4 has no slow tokenizer."
))]
)
)
...
@@ -273,6 +284,7 @@ def test_llama4_template(use_fast: bool):
...
@@ -273,6 +284,7 @@ def test_llama4_template(use_fast: bool):
pytest
.
param
(
False
,
marks
=
pytest
.
mark
.
xfail
(
reason
=
"Phi-4 slow tokenizer is broken."
)),
pytest
.
param
(
False
,
marks
=
pytest
.
mark
.
xfail
(
reason
=
"Phi-4 slow tokenizer is broken."
)),
],
],
)
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_phi4_template
(
use_fast
:
bool
):
def
test_phi4_template
(
use_fast
:
bool
):
prompt_str
=
(
prompt_str
=
(
f
"<|im_start|>user<|im_sep|>
{
MESSAGES
[
0
][
'content'
]
}
<|im_end|>"
f
"<|im_start|>user<|im_sep|>
{
MESSAGES
[
0
][
'content'
]
}
<|im_end|>"
...
@@ -284,6 +296,7 @@ def test_phi4_template(use_fast: bool):
...
@@ -284,6 +296,7 @@ def test_phi4_template(use_fast: bool):
_check_template
(
"microsoft/phi-4"
,
"phi4"
,
prompt_str
,
answer_str
,
use_fast
)
_check_template
(
"microsoft/phi-4"
,
"phi4"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)
@
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_qwen2_5_template
(
use_fast
:
bool
):
def
test_qwen2_5_template
(
use_fast
:
bool
):
...
@@ -298,6 +311,7 @@ def test_qwen2_5_template(use_fast: bool):
...
@@ -298,6 +311,7 @@ def test_qwen2_5_template(use_fast: bool):
_check_template
(
"Qwen/Qwen2.5-7B-Instruct"
,
"qwen"
,
prompt_str
,
answer_str
,
use_fast
)
_check_template
(
"Qwen/Qwen2.5-7B-Instruct"
,
"qwen"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"cot_messages"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"cot_messages"
,
[
True
,
False
])
def
test_qwen3_template
(
use_fast
:
bool
,
cot_messages
:
bool
):
def
test_qwen3_template
(
use_fast
:
bool
,
cot_messages
:
bool
):
...
@@ -317,6 +331,7 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
...
@@ -317,6 +331,7 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
_check_template
(
"Qwen/Qwen3-8B"
,
"qwen3"
,
prompt_str
,
answer_str
,
use_fast
,
messages
=
messages
)
_check_template
(
"Qwen/Qwen3-8B"
,
"qwen3"
,
prompt_str
,
answer_str
,
use_fast
,
messages
=
messages
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_parse_llama3_template
():
def
test_parse_llama3_template
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
token
=
HF_TOKEN
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
token
=
HF_TOKEN
)
template
=
parse_template
(
tokenizer
)
template
=
parse_template
(
tokenizer
)
...
@@ -330,6 +345,7 @@ def test_parse_llama3_template():
...
@@ -330,6 +345,7 @@ def test_parse_llama3_template():
assert
template
.
default_system
==
""
assert
template
.
default_system
==
""
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)
@
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)
def
test_parse_qwen_template
():
def
test_parse_qwen_template
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"Qwen/Qwen2.5-7B-Instruct"
,
token
=
HF_TOKEN
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"Qwen/Qwen2.5-7B-Instruct"
,
token
=
HF_TOKEN
)
...
@@ -342,6 +358,7 @@ def test_parse_qwen_template():
...
@@ -342,6 +358,7 @@ def test_parse_qwen_template():
assert
template
.
default_system
==
"You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
assert
template
.
default_system
==
"You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)
@
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)
def
test_parse_qwen3_template
():
def
test_parse_qwen3_template
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"Qwen/Qwen3-8B"
,
token
=
HF_TOKEN
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"Qwen/Qwen3-8B"
,
token
=
HF_TOKEN
)
...
...
tests/e2e/test_chat.py
View file @
ca625f43
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
import
os
import
os
import
pytest
from
llamafactory.chat
import
ChatModel
from
llamafactory.chat
import
ChatModel
...
@@ -35,11 +37,13 @@ MESSAGES = [
...
@@ -35,11 +37,13 @@ MESSAGES = [
EXPECTED_RESPONSE
=
"_rho"
EXPECTED_RESPONSE
=
"_rho"
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_chat
():
def
test_chat
():
chat_model
=
ChatModel
(
INFER_ARGS
)
chat_model
=
ChatModel
(
INFER_ARGS
)
assert
chat_model
.
chat
(
MESSAGES
)[
0
].
response_text
==
EXPECTED_RESPONSE
assert
chat_model
.
chat
(
MESSAGES
)[
0
].
response_text
==
EXPECTED_RESPONSE
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_stream_chat
():
def
test_stream_chat
():
chat_model
=
ChatModel
(
INFER_ARGS
)
chat_model
=
ChatModel
(
INFER_ARGS
)
response
=
""
response
=
""
...
...
tests/e2e/test_sglang.py
View file @
ca625f43
...
@@ -39,6 +39,7 @@ MESSAGES = [
...
@@ -39,6 +39,7 @@ MESSAGES = [
]
]
@
pytest
.
mark
.
runs_on
([
"cuda"
])
@
pytest
.
mark
.
skipif
(
not
is_sglang_available
(),
reason
=
"SGLang is not installed"
)
@
pytest
.
mark
.
skipif
(
not
is_sglang_available
(),
reason
=
"SGLang is not installed"
)
def
test_chat
():
def
test_chat
():
r
"""Test the SGLang engine's basic chat functionality."""
r
"""Test the SGLang engine's basic chat functionality."""
...
@@ -48,6 +49,7 @@ def test_chat():
...
@@ -48,6 +49,7 @@ def test_chat():
print
(
response
.
response_text
)
print
(
response
.
response_text
)
@
pytest
.
mark
.
runs_on
([
"cuda"
])
@
pytest
.
mark
.
skipif
(
not
is_sglang_available
(),
reason
=
"SGLang is not installed"
)
@
pytest
.
mark
.
skipif
(
not
is_sglang_available
(),
reason
=
"SGLang is not installed"
)
def
test_stream_chat
():
def
test_stream_chat
():
r
"""Test the SGLang engine's streaming chat functionality."""
r
"""Test the SGLang engine's streaming chat functionality."""
...
...
tests/e2e/test_train.py
View file @
ca625f43
...
@@ -49,6 +49,7 @@ INFER_ARGS = {
...
@@ -49,6 +49,7 @@ INFER_ARGS = {
OS_NAME
=
os
.
getenv
(
"OS_NAME"
,
""
)
OS_NAME
=
os
.
getenv
(
"OS_NAME"
,
""
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"stage,dataset"
,
"stage,dataset"
,
[
[
...
@@ -65,6 +66,7 @@ def test_run_exp(stage: str, dataset: str):
...
@@ -65,6 +66,7 @@ def test_run_exp(stage: str, dataset: str):
assert
os
.
path
.
exists
(
output_dir
)
assert
os
.
path
.
exists
(
output_dir
)
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_export
():
def
test_export
():
export_dir
=
os
.
path
.
join
(
"output"
,
"llama3_export"
)
export_dir
=
os
.
path
.
join
(
"output"
,
"llama3_export"
)
export_model
({
"export_dir"
:
export_dir
,
**
INFER_ARGS
})
export_model
({
"export_dir"
:
export_dir
,
**
INFER_ARGS
})
...
...
tests/eval/test_eval_template.py
View file @
ca625f43
...
@@ -12,9 +12,12 @@
...
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
pytest
from
llamafactory.eval.template
import
get_eval_template
from
llamafactory.eval.template
import
get_eval_template
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_eval_template_en
():
def
test_eval_template_en
():
support_set
=
[
support_set
=
[
{
{
...
@@ -53,6 +56,7 @@ def test_eval_template_en():
...
@@ -53,6 +56,7 @@ def test_eval_template_en():
]
]
@
pytest
.
mark
.
runs_on
([
"cpu"
,
"mps"
])
def
test_eval_template_zh
():
def
test_eval_template_zh
():
support_set
=
[
support_set
=
[
{
{
...
...
tests/model/model_utils/test_attention.py
View file @
ca625f43
...
@@ -15,7 +15,17 @@
...
@@ -15,7 +15,17 @@
import
os
import
os
import
pytest
import
pytest
from
transformers.utils
import
is_flash_attn_2_available
,
is_torch_sdpa_available
from
transformers.utils
import
is_flash_attn_2_available
# Compatible with Transformers v4 and Transformers v5
try
:
from
transformers.utils
import
is_torch_sdpa_available
except
ImportError
:
def
is_torch_sdpa_available
():
return
True
from
llamafactory.extras.packages
import
is_transformers_version_greater_than
from
llamafactory.extras.packages
import
is_transformers_version_greater_than
from
llamafactory.train.test_utils
import
load_infer_model
from
llamafactory.train.test_utils
import
load_infer_model
...
...
tests/model/test_base.py
View file @
ca625f43
...
@@ -16,7 +16,7 @@ import os
...
@@ -16,7 +16,7 @@ import os
import
pytest
import
pytest
from
llamafactory.train.test_utils
import
compare_model
,
load_infer_model
,
load_reference_model
,
patch_valuehead_model
from
llamafactory.train.test_utils
import
compare_model
,
load_infer_model
,
load_reference_model
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
...
@@ -30,11 +30,6 @@ INFER_ARGS = {
...
@@ -30,11 +30,6 @@ INFER_ARGS = {
}
}
@
pytest
.
fixture
def
fix_valuehead_cpu_loading
():
patch_valuehead_model
()
def
test_base
():
def
test_base
():
model
=
load_infer_model
(
**
INFER_ARGS
)
model
=
load_infer_model
(
**
INFER_ARGS
)
ref_model
=
load_reference_model
(
TINY_LLAMA3
)
ref_model
=
load_reference_model
(
TINY_LLAMA3
)
...
...
tests/model/test_lora.py
View file @
ca625f43
...
@@ -23,7 +23,6 @@ from llamafactory.train.test_utils import (
...
@@ -23,7 +23,6 @@ from llamafactory.train.test_utils import (
load_infer_model
,
load_infer_model
,
load_reference_model
,
load_reference_model
,
load_train_model
,
load_train_model
,
patch_valuehead_model
,
)
)
...
@@ -56,11 +55,6 @@ INFER_ARGS = {
...
@@ -56,11 +55,6 @@ INFER_ARGS = {
}
}
@
pytest
.
fixture
def
fix_valuehead_cpu_loading
():
patch_valuehead_model
()
def
test_lora_train_qv_modules
():
def
test_lora_train_qv_modules
():
model
=
load_train_model
(
lora_target
=
"q_proj,v_proj"
,
**
TRAIN_ARGS
)
model
=
load_train_model
(
lora_target
=
"q_proj,v_proj"
,
**
TRAIN_ARGS
)
linear_modules
,
_
=
check_lora_model
(
model
)
linear_modules
,
_
=
check_lora_model
(
model
)
...
...
tests/version.txt
View file @
ca625f43
# change if test fails or cache is outdated
# change if test fails or cache is outdated
0.9.4.10
0
0.9.4.10
5
tests_v1/accelerator/test_interface.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
import
torch.multiprocessing
as
mp
from
llamafactory.v1.accelerator.helper
import
ReduceOp
from
llamafactory.v1.accelerator.interface
import
DistributedInterface
from
llamafactory.v1.utils.env
import
find_available_port
from
llamafactory.v1.utils.pytest
import
dist_env
def
_all_reduce_tests
(
local_rank
:
int
,
world_size
:
int
,
master_port
:
int
):
with
dist_env
(
local_rank
,
world_size
,
master_port
):
rank
=
DistributedInterface
().
get_rank
()
world_size
=
DistributedInterface
().
get_world_size
()
assert
world_size
==
2
y_sum
=
DistributedInterface
().
all_reduce
(
rank
+
1.0
,
op
=
ReduceOp
.
SUM
)
assert
y_sum
==
pytest
.
approx
(
3.0
)
y_mean
=
DistributedInterface
().
all_reduce
(
rank
+
1.0
,
op
=
ReduceOp
.
MEAN
)
assert
y_mean
==
pytest
.
approx
(
1.5
)
y_max
=
DistributedInterface
().
all_reduce
(
rank
+
1.0
,
op
=
ReduceOp
.
MAX
)
assert
y_max
==
pytest
.
approx
(
2.0
)
z
=
DistributedInterface
().
all_gather
(
rank
+
1.0
)
assert
z
==
pytest
.
approx
([
1.0
,
2.0
])
z
=
DistributedInterface
().
broadcast
(
rank
+
1.0
)
assert
z
==
pytest
.
approx
(
1.0
)
def
test_all_device
():
assert
DistributedInterface
().
get_rank
()
==
int
(
os
.
getenv
(
"RANK"
,
"0"
))
assert
DistributedInterface
().
get_world_size
()
==
int
(
os
.
getenv
(
"WORLD_SIZE"
,
"1"
))
assert
DistributedInterface
().
get_local_rank
()
==
int
(
os
.
getenv
(
"LOCAL_RANK"
,
"0"
))
assert
DistributedInterface
().
get_local_world_size
()
==
int
(
os
.
getenv
(
"LOCAL_WORLD_SIZE"
,
"1"
))
@pytest.mark.runs_on(["cuda", "npu"])
@pytest.mark.require_distributed(2)
def test_multi_device():
    """Spawn two worker processes and run the collective-op checks in each."""
    world_size = 2
    master_port = find_available_port()
    # mp.spawn passes the process index as the first positional argument,
    # so `args` supplies only (world_size, master_port).
    mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
tests_v1/config/test_args_parser.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
pathlib
import
sys
from
unittest.mock
import
patch
from
llamafactory.v1.config.arg_parser
import
get_args
def test_get_args_from_yaml(tmp_path: pathlib.Path):
    """End-to-end check of get_args(): a YAML config path passed as argv[1]
    must be parsed into the data/model/training/sample argument groups.
    """
    # NOTE(review): the nested indentation under kernel_config/peft_config was
    # lost in the diff rendering and is reconstructed here with 2 spaces --
    # confirm against the original file.
    config_yaml = """
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
use_fast_processor: true
model_class: "llm"
kernel_config:
  name: "auto"
  include_kernels: "auto" # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
peft_config:
  name: "lora"
  lora_rank: 0.8
quant_config: null
### data
dataset: "llamafactory/tiny-supervised-dataset"
cutoff_len: 2048
### training
output_dir: "outputs/test_run"
micro_batch_size: 1
global_batch_size: 1
learning_rate: 1.0e-4
bf16: false
dist_config: null
### sample
sample_backend: "hf"
max_new_tokens: 128
"""
    # Write the config into pytest's per-test temporary directory.
    config_file = tmp_path / "config.yaml"
    config_file.write_text(config_yaml, encoding="utf-8")
    # get_args() reads the config path from sys.argv, so patch argv for the call.
    test_argv = ["test_args_parser.py", str(config_file)]
    with patch.object(sys, "argv", test_argv):
        data_args, model_args, training_args, sample_args = get_args()
    # Training section round-trips with correct types (int, float, bool, null).
    assert training_args.output_dir == "outputs/test_run"
    assert training_args.micro_batch_size == 1
    assert training_args.global_batch_size == 1
    assert training_args.learning_rate == 1.0e-4
    assert training_args.bf16 is False
    assert training_args.dist_config is None
    # Model section: nested plugin configs expose both attribute and .get() access.
    assert model_args.model == "llamafactory/tiny-random-qwen2.5"
    assert model_args.kernel_config.name == "auto"
    assert model_args.kernel_config.get("include_kernels") == "auto"
    assert model_args.peft_config.name == "lora"
    assert model_args.peft_config.get("lora_rank") == 0.8
tests_v1/conftest.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA-Factory test configuration.
Contains shared fixtures, pytest configuration, and custom markers.
"""
import
os
import
sys
import
pytest
import
torch
from
pytest
import
Config
,
FixtureRequest
,
Item
,
MonkeyPatch
from
llamafactory.v1.accelerator.helper
import
get_current_accelerator
,
get_device_count
from
llamafactory.v1.utils.env
import
is_env_enabled
from
llamafactory.v1.utils.packages
import
is_transformers_version_greater_than
# Device type string ("cpu"/"cuda"/"npu"/...) of the accelerator backing this
# test session; evaluated once at collection time and used by the marker hooks.
CURRENT_DEVICE = get_current_accelerator().type
def pytest_configure(config: Config):
    """Register custom pytest markers."""
    marker_specs = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"' or set RUN_SLOW=1 to run)",
        "runs_on: test requires specific device type, e.g., @pytest.mark.runs_on(['cuda'])",
        "require_distributed(num_devices): allow multi-device execution (default: 2)",
    )
    for spec in marker_specs:
        config.addinivalue_line("markers", spec)
def _handle_runs_on(items: list[Item]):
    """Skip collected tests whose `runs_on` marker excludes the current device type."""
    for item in items:
        marker = item.get_closest_marker("runs_on")
        if marker is None:
            continue
        # The marker accepts a single device string or a list of them.
        devices = marker.args[0]
        devices = [devices] if isinstance(devices, str) else devices
        if CURRENT_DEVICE not in devices:
            reason = f"test requires one of {devices} (current: {CURRENT_DEVICE})"
            item.add_marker(pytest.mark.skip(reason=reason))
def _handle_slow_tests(items: list[Item]):
    """Attach a skip marker to every `slow`-marked test unless RUN_SLOW is set."""
    if is_env_enabled("RUN_SLOW"):
        return
    skip_marker = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_marker)
def _get_visible_devices_env() -> str | None:
    """Return the device-visibility env var for the current accelerator, or None."""
    env_keys = {
        "cuda": "CUDA_VISIBLE_DEVICES",
        "npu": "ASCEND_RT_VISIBLE_DEVICES",
    }
    return env_keys.get(CURRENT_DEVICE)
def _handle_device_visibility(items: list[Item]):
    """Skip `require_distributed` tests when fewer devices are visible than required."""
    env_key = _get_visible_devices_env()
    if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
        return

    # Count usable devices: the visibility env var wins when set, otherwise
    # fall back to the physically detected device count.
    visible_devices_env = os.environ.get(env_key)
    if visible_devices_env is None:
        available = get_device_count()
    else:
        available = len([entry for entry in visible_devices_env.split(",") if entry != ""])

    for item in items:
        marker = item.get_closest_marker("require_distributed")
        if not marker:
            continue
        required = marker.args[0] if marker.args else 2
        if available < required:
            reason = f"test requires {required} devices, but only {available} visible"
            item.add_marker(pytest.mark.skip(reason=reason))
def pytest_collection_modifyitems(config: Config, items: list[Item]):
    """Modify test collection based on markers and environment."""
    # v1 tests require a recent transformers; skip the whole tests_v1 tree otherwise.
    if not is_transformers_version_greater_than("4.57.0"):
        skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
        for v1_item in (item for item in items if "tests_v1" in str(item.fspath)):
            v1_item.add_marker(skip_bc)
    for handler in (_handle_slow_tests, _handle_runs_on, _handle_device_visibility):
        handler(items)
@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
    """Set environment variables for distributed tests if specific devices are requested.

    For `require_distributed` tests, exposes the requested device ids through the
    visibility env var and puts the project root on the import path so spawned
    worker processes can import the test module. For plain tests, pins visibility
    to a single device and patches `device_count` accordingly.
    """
    env_key = _get_visible_devices_env()
    if not env_key:
        return
    # Save old environment for logic checks, monkeypatch handles restoration
    old_value = os.environ.get(env_key)
    marker = request.node.get_closest_marker("require_distributed")
    if marker:
        # distributed test
        required = marker.args[0] if marker.args else 2
        # Optional second marker arg names explicit device ids, e.g. (2, [4, 5]).
        specific_devices = marker.args[1] if len(marker.args) > 1 else None
        if specific_devices:
            devices_str = ",".join(map(str, specific_devices))
        else:
            devices_str = ",".join(str(i) for i in range(required))
        monkeypatch.setenv(env_key, devices_str)
        # add project root dir to path for mp run
        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        # Fix: route PYTHONPATH through monkeypatch too, so it is restored after
        # the test instead of accumulating a duplicated project root per
        # distributed test in this session. Spawned children still inherit it.
        monkeypatch.setenv(env_key="PYTHONPATH" and "PYTHONPATH", value=project_root + os.pathsep + os.environ.get("PYTHONPATH", ""))
    else:
        # non-distributed test: restrict visibility to a single device.
        if old_value:
            visible_devices = [v for v in old_value.split(",") if v != ""]
            monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
        else:
            monkeypatch.setenv(env_key, "0")
        # Visibility env vars are read at backend init, so also patch the
        # reported device count for frameworks that already initialized.
        if CURRENT_DEVICE == "cuda":
            monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
        elif CURRENT_DEVICE == "npu":
            monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
tests_v1/core/test_data_engine.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
pytest
from
datasets
import
load_dataset
from
llamafactory.v1.config.data_args
import
DataArguments
from
llamafactory.v1.core.data_engine
import
DataEngine
@pytest.mark.parametrize("num_samples", [16])
def test_map_dataset(num_samples: int):
    """Samples served by DataEngine must equal the raw HF rows plus a `_dataset_name` tag."""
    data_engine = DataEngine(DataArguments(dataset="llamafactory/v1-sft-demo"))
    original_data = load_dataset("llamafactory/v1-sft-demo", split="train")
    # Spot-check a random selection of indices (with replacement).
    for index in random.choices(range(len(data_engine)), k=num_samples):
        print(data_engine[index])
        assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
# Allow running this module directly (outside pytest) with a single sample.
if __name__ == "__main__":
    test_map_dataset(1)
tests_v1/core/test_data_loader.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
import
torch
from
torch.utils.data
import
DataLoader
as
TorchDataLoader
from
torch.utils.data
import
Dataset
from
transformers
import
AutoTokenizer
from
llamafactory.v1.config.data_args
import
DataArguments
from
llamafactory.v1.core.data_engine
import
DataEngine
from
llamafactory.v1.core.trainer_utils.data_collator
import
(
DefaultCollator
,
)
from
llamafactory.v1.core.trainer_utils.data_loader
import
DataLoader
from
llamafactory.v1.plugins.data_plugins.template
import
QwenTemplate
from
llamafactory.v1.utils.batching_queue
import
TextBatchingQueue
class TensorDataset(Dataset):
    """Wrapper dataset that converts DataEngine samples to tensor format.

    Each item is rendered through the chat template and returned as a dict of
    1-D long tensors (input_ids / attention_mask / labels).
    """

    def __init__(self, data_engine: DataEngine, processor, template, max_samples: int | None = None):
        # Source of raw samples; `processor` may be a tokenizer or a multimodal
        # processor wrapping one, `template` renders messages into token ids.
        self.data_engine = data_engine
        self.processor = processor
        self.template = template
        # None / 0 falls back to the full engine length.
        self.max_samples = max_samples or len(data_engine)
        # Multimodal processors expose `.tokenizer`; plain tokenizers are used as-is.
        self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor

    def __len__(self):
        # Never report more samples than the underlying engine actually holds.
        return min(self.max_samples, len(self.data_engine))

    def __getitem__(self, idx):
        # Get sample from DataEngine
        sample = self.data_engine[idx]
        # Extract messages from sample
        # DataEngine returns samples with format like {"messages": [...], ...}
        # For llamafactory/v1-sft-demo, the format should have "messages" field
        messages = None
        if "messages" in sample:
            messages = sample["messages"]
        elif "conversations" in sample:
            messages = sample["conversations"]
        elif "conversation" in sample:
            messages = sample["conversation"]
        else:
            # Try to find message-like fields (skip _dataset_name)
            for key, value in sample.items():
                if key.startswith("_"):
                    continue
                if isinstance(value, list) and len(value) > 0:
                    # Check if it looks like a message list (dicts carrying "role")
                    if isinstance(value[0], dict) and "role" in value[0]:
                        messages = value
                        break
        if messages is None:
            raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
        # Encode messages using template.
        # NOTE(review): assumes encode_messages returns a dict with list-valued
        # "input_ids"/"attention_mask"/"labels" -- confirm against the template API.
        encoded = self.template.encode_messages(self.tokenizer, messages)
        # Convert to tensors
        return {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(encoded["labels"], dtype=torch.long),
        }
def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
    """Build a torch DataLoader over the v1 SFT demo set; return it with processor and template."""
    data_engine = DataEngine(DataArguments(dataset="llamafactory/v1-sft-demo"))
    processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
    template = QwenTemplate()
    tensor_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
    # Identity collate_fn: downstream collators expect the raw list of samples.
    torch_dataloader = TorchDataLoader(
        tensor_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda x: x,
    )
    return torch_dataloader, processor, template
class TestDataLoaderNonPackNonDynamic:
    """Test case a) non pack + non dynamic."""

    def test_basic_functionality(self):
        """Test DataLoader without packing and without dynamic batching."""
        torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
        collator = DefaultCollator(processor=processor, template=template)
        # batching_queue=None -> static (non-dynamic) batching.
        data_loader = DataLoader(
            dataloader=torch_dataloader,
            collate_fn=collator,
            num_micro_batch=1,
            batching_queue=None,
        )
        batches = list(iter(data_loader))
        assert len(batches) > 0
        # The first micro-batch of the first global batch carries the model keys.
        first_micro_batch = batches[0][0]
        for key in ("input_ids", "attention_mask", "labels"):
            assert key in first_micro_batch
        assert first_micro_batch["input_ids"].shape[0] == 1  # batch_size=1
        assert first_micro_batch["input_ids"].ndim == 2  # [batch_size, seq_len]
class TestDataLoaderNonPackDynamic:
    """Test case b) non pack + dynamic."""

    def test_basic_functionality(self):
        """Test DataLoader without packing but with dynamic batching."""
        torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
        collator = DefaultCollator(processor=processor, template=template)
        # Token-budget queue drives dynamically sized micro-batches.
        batching_queue = TextBatchingQueue(token_micro_bsz=120, buffer_size=8)
        data_loader = DataLoader(
            dataloader=torch_dataloader,
            collate_fn=collator,
            num_micro_batch=4,
            batching_queue=batching_queue,
        )
        batches = list(iter(data_loader))
        # Every micro-batch of the first global batch must stay within the 120-token budget.
        token_counts = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
        assert all(count <= 120 for count in token_counts)
        assert len(batches) > 0
tests_v1/core/test_model_loader.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
from
llamafactory.v1.config.model_args
import
ModelArguments
,
PluginConfig
from
llamafactory.v1.core.model_loader
import
ModelLoader
def test_tiny_qwen():
    """Loading the tiny Qwen2.5 checkpoint must yield the expected HF classes in bf16."""
    from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast

    loader = ModelLoader(ModelArguments(model="llamafactory/tiny-random-qwen2.5"))
    assert isinstance(loader.processor, Qwen2TokenizerFast)
    assert isinstance(loader.model.config, Qwen2Config)
    assert isinstance(loader.model, Qwen2ForCausalLM)
    # The default load dtype is bfloat16.
    assert loader.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
    """With kernel_config name="auto", the NPU RMSNorm kernel is patched in on NPU hosts only."""
    from transformers import Qwen2ForCausalLM

    from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward

    model_args = ModelArguments(
        model="llamafactory/tiny-random-qwen2.5",
        kernel_config=PluginConfig(name="auto", include_kernels="auto"),
    )
    model_loader = ModelLoader(model_args)
    # test enable apply kernel plugin: compare the patched forward's code object.
    layernorm_code = model_loader.model.model.layers[0].input_layernorm.forward.__code__
    if hasattr(torch, "npu"):
        assert layernorm_code == npu_rms_norm_forward.__code__
    else:
        assert layernorm_code != npu_rms_norm_forward.__code__
    assert isinstance(model_loader.model, Qwen2ForCausalLM)
# Allow running this module directly (outside pytest).
if __name__ == "__main__":
    test_tiny_qwen()
    test_tiny_qwen_with_kernel_plugin()
tests_v1/core/utils/test_batching.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
llamafactory.v1.config
import
DataArguments
,
ModelArguments
,
TrainingArguments
from
llamafactory.v1.core.data_engine
import
DataEngine
from
llamafactory.v1.core.model_engine
import
ModelEngine
from
llamafactory.v1.core.utils.batching
import
BatchGenerator
def test_normal_batching():
    """A "normal" BatchGenerator must yield global batches of fixed-shape micro-batches."""
    data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
    data_engine = DataEngine(data_args.train_dataset)
    model_engine = ModelEngine(model_args=ModelArguments(model="llamafactory/tiny-random-qwen3"))
    training_args = TrainingArguments(
        micro_batch_size=4,
        global_batch_size=8,
        cutoff_len=10,
        batching_workers=0,
        batching_strategy="normal",
    )
    batch_generator = BatchGenerator(
        data_engine,
        model_engine.renderer,
        micro_batch_size=training_args.micro_batch_size,
        global_batch_size=training_args.global_batch_size,
        cutoff_len=training_args.cutoff_len,
        batching_workers=training_args.batching_workers,
        batching_strategy=training_args.batching_strategy,
    )
    # One global batch per `global_batch_size` samples.
    assert len(batch_generator) == len(data_engine) // training_args.global_batch_size
    first_batch = next(iter(batch_generator))
    # 8 samples / micro_batch_size 4 -> 2 micro-batches of shape (4, cutoff_len).
    assert len(first_batch) == 2
    assert first_batch[0]["input_ids"].shape == (4, 10)
if __name__ == "__main__":
    """
    python -m tests_v1.core.utils.test_batching
    """
    # Direct-execution entry point for quick local runs.
    test_normal_batching()
tests_v1/core/utils/test_rendering.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
pytest
from
transformers
import
AutoTokenizer
from
llamafactory.v1.config
import
DataArguments
from
llamafactory.v1.core.data_engine
import
DataEngine
from
llamafactory.v1.core.utils.rendering
import
Renderer
from
llamafactory.v1.utils.types
import
Processor
def
_get_input_ids
(
inputs
:
list
|
dict
)
->
list
:
if
not
isinstance
(
inputs
,
list
):
return
inputs
[
"input_ids"
]
else
:
return
inputs
# Plain-string conversation in the HuggingFace `apply_chat_template` format.
HF_MESSAGES = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is LLM?"},
    {"role": "assistant", "content": "LLM stands for Large Language Model."},
]
# The same conversation in the v1 structured format (typed content parts).
V1_MESSAGES = [
    {"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
    {"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
]
# Tool-calling conversation, HF format: the first assistant turn issues a function call.
HF_MESSAGES_WITH_TOOLS = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 6*8?"},
    {
        "role": "assistant",
        "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 6, "b": 8}}}],
    },
    {"role": "tool", "content": "48."},
    {"role": "assistant", "content": "The result of 6*8 is 48."},
]
# Same tool-calling conversation in v1 format; the tool_call turn carries loss_weight 0.0.
V1_MESSAGES_WITH_TOOLS = [
    {"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "text", "value": "What is 6*8?"}]},
    {
        "role": "assistant",
        "content": [{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})}],
        "loss_weight": 0.0,
    },
    {"role": "tool", "content": [{"type": "text", "value": "48."}]},
    {"role": "assistant", "content": [{"type": "text", "value": "The result of 6*8 is 48."}]},
]
# OpenAI-style tool schema for the `multiply` function used above.
V1_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "multiply",
            "description": "A function that multiplies two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "number", "description": "The first number to multiply"},
                    "b": {"type": "number", "description": "The second number to multiply"},
                },
                "required": ["a", "b"],
            },
        },
    }
]
def test_chatml_rendering():
    """The chatml Renderer must reproduce HF apply_chat_template token-for-token,
    for both generation prompts and full (loss-bearing) conversations.
    """
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    # Generation case: drop the final assistant turn and add the generation prompt.
    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True))
    v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
    assert v1_inputs["input_ids"] == hf_inputs
    assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
    # Prompt-only input: no supervised positions at all.
    assert v1_inputs["labels"] == [-100] * len(hf_inputs)
    assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
    # Training case: the prompt prefix is masked, the assistant reply is supervised.
    hf_inputs_part = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False))
    hf_inputs_full = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False))
    v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
    assert v1_inputs_full["input_ids"] == hf_inputs_full
    assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
    assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part):]
    assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
        len(hf_inputs_full) - len(hf_inputs_part)
    )
def test_chatml_parse():
    """A plain generated string must parse back into a v1 assistant text message."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    parsed_message = renderer.parse_message("LLM stands for Large Language Model.")
    # Must equal the final assistant turn of the reference conversation.
    assert parsed_message == V1_MESSAGES[-1]
@pytest.mark.parametrize("num_samples", [16])
def test_chatml_rendering_remote(num_samples: int):
    """Rendered remote SFT samples must start with the chatml user header."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
    data_engine = DataEngine(data_args.train_dataset)
    # The expected prefix is loop-invariant, so encode it once.
    prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
    for index in range(num_samples):
        v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
        print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
        assert v1_inputs["input_ids"][: len(prefix)] == prefix
def test_qwen3_nothink_rendering():
    """The qwen3_nothink Renderer must match HF apply_chat_template for a
    tool-calling conversation, in both generation and training modes.
    """
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
    # Generation case: drop the final assistant turn, pass the tool schema.
    # Note the v1 API takes tools as a JSON string, HF as a Python list.
    hf_inputs = _get_input_ids(
        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
    )
    v1_inputs = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True)
    assert v1_inputs["input_ids"] == hf_inputs
    assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
    # Prompt-only input: no supervised positions.
    assert v1_inputs["labels"] == [-100] * len(hf_inputs)
    assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
    # Training case: only the final assistant reply is supervised.
    hf_inputs_part = _get_input_ids(
        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False)
    )
    hf_inputs_full = _get_input_ids(
        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
    )
    v1_inputs_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False)
    assert v1_inputs_full["input_ids"] == hf_inputs_full
    assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
    assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part):]
    assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
        len(hf_inputs_full) - len(hf_inputs_part)
    )
def test_qwen3_nothink_parse():
    """Generated text mixing <thinking> and <tool_call> segments must parse into typed parts."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
    generated_text = (
        "<thinking>I need to use the multiply function to calculate 6*8.</thinking>"
        "Let me call the multiply function."
        '<tool_call>{"name": "multiply", "arguments": {"a": 6, "b": 8}}</tool_call>'
    )
    expected_message = {
        "role": "assistant",
        "content": [
            {"type": "reasoning", "value": "I need to use the multiply function to calculate 6*8."},
            {"type": "text", "value": "Let me call the multiply function."},
            {"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})},
        ],
    }
    assert renderer.parse_message(generated_text) == expected_message
@pytest.mark.parametrize("num_samples", [8])
def test_qwen3_nothink_rendering_remote(num_samples: int):
    """Samples from the remote tool-use dataset must render with the expected
    qwen3 system/tools prompt prefix.
    """
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
    data_args = DataArguments(train_dataset="llamafactory/reason-tool-use-demo-1500")
    data_engine = DataEngine(data_args.train_dataset)
    for index in range(num_samples):
        v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
        # Expected start of every rendered sample: the dataset's fixed system
        # prompt followed by the tool-signature preamble.
        prefix_text = (
            "<|im_start|>system\nYou are a methodical and expert assistant. "
            "Your primary goal is to solve user requests by leveraging a set of available tools. "
            "You must reason for the best course of action in a structured manner before responding.\n\n"
            "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n<tools>\n"
            '{"type": "function", "function": {"name":'
        )
        prefix = tokenizer.encode(prefix_text, add_special_tokens=False)
        print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
        assert v1_inputs["input_ids"][: len(prefix)] == prefix
def test_process_sft_samples():
    """process_samples on an SFT sample must tokenize the messages and keep extra fields."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
    samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}]
    model_inputs = renderer.process_samples(samples)
    assert len(model_inputs) == 1
    first_input = model_inputs[0]
    assert first_input["input_ids"] == hf_inputs
    # Non-message fields pass through untouched.
    assert first_input["extra_info"] == "test"
    assert first_input["_dataset_name"] == "default"
def test_process_dpo_samples():
    """DPO samples must concatenate chosen+rejected ids with token_type_ids 1/2."""
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
    samples = [
        {
            "chosen_messages": V1_MESSAGES,
            "rejected_messages": V1_MESSAGES,
            "extra_info": "test",
            "_dataset_name": "default",
        }
    ]
    model_inputs = renderer.process_samples(samples)
    assert len(model_inputs) == 1
    first_input = model_inputs[0]
    # Chosen and rejected are identical here, so ids are simply doubled.
    assert first_input["input_ids"] == hf_inputs * 2
    # token_type_ids: 1 marks the chosen span, 2 the rejected span.
    assert first_input["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
    assert first_input["extra_info"] == "test"
    assert first_input["_dataset_name"] == "default"
if __name__ == "__main__":
    """
    python -m tests_v1.core.utils.test_rendering
    """
    # Direct-execution entry point running every test in this module.
    # NOTE(review): the remote qwen3 test is called with 16 samples although its
    # pytest parametrization uses 8 -- presumably intentional for manual runs.
    test_chatml_rendering()
    test_chatml_parse()
    test_chatml_rendering_remote(16)
    test_qwen3_nothink_rendering()
    test_qwen3_nothink_parse()
    test_qwen3_nothink_rendering_remote(16)
    test_process_sft_samples()
    test_process_dpo_samples()
tests_v1/plugins/data_plugins/test_converter.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
pytest
from
datasets
import
load_dataset
from
llamafactory.v1.config.data_args
import
DataArguments
from
llamafactory.v1.core.data_engine
import
DataEngine
from
llamafactory.v1.plugins.data_plugins.converter
import
DataConverterPlugin
@pytest.mark.parametrize("num_samples", [16])
def test_alpaca_converter(num_samples: int):
    """Alpaca rows must convert to v1 user/assistant messages with loss weights 0/1."""
    data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
    data_engine = DataEngine(data_args)
    original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
    for index in random.choices(range(len(data_engine)), k=num_samples):
        print(data_engine[index])
        row = original_data[index]
        expected_data = {
            "messages": [
                {
                    "role": "user",
                    # instruction and input are concatenated into one user turn.
                    "content": [{"type": "text", "value": row["instruction"] + row["input"]}],
                    "loss_weight": 0.0,
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "value": row["output"]}],
                    "loss_weight": 1.0,
                },
            ]
        }
        assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
def test_sharegpt_converter():
    """ShareGPT role tags must map onto v1 roles with the proper loss weights."""
    example = {
        "conversations": [
            {"from": "system", "value": "System"},
            {"from": "human", "value": "User"},
            {"from": "function_call", "value": "Tool"},
            {"from": "observation", "value": "Observation"},
            {"from": "gpt", "value": "Assistant"},
        ]
    }

    def build_message(role, part_type, value, loss_weight):
        # Mirrors the converter's output shape for a single-part message.
        return {"content": [{"type": part_type, "value": value}], "loss_weight": loss_weight, "role": role}

    expected_data = {
        "messages": [
            build_message("system", "text", "System", 0.0),
            build_message("user", "text", "User", 0.0),
            # function_call turns become trainable assistant tool_calls parts.
            build_message("assistant", "tool_calls", "Tool", 1.0),
            # observation turns become non-trainable tool responses.
            build_message("tool", "text", "Observation", 0.0),
            build_message("assistant", "text", "Assistant", 1.0),
        ]
    }
    assert DataConverterPlugin("sharegpt")(example) == expected_data
@pytest.mark.parametrize("num_samples", [16])
def test_pair_converter(num_samples: int):
    """Preference-pair rows convert into chosen/rejected message lists.

    Each side of the pair is a system/user/assistant triple where only the
    assistant turn carries loss weight.
    """
    data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
    data_engine = DataEngine(data_args)
    raw_dataset = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")

    def _to_messages(turns):
        # system and user turns carry no loss; the final assistant turn does.
        roles_and_weights = (("system", 0.0), ("user", 0.0), ("assistant", 1.0))
        return [
            {
                "role": role,
                "content": [{"type": "text", "value": turns[i]["content"]}],
                "loss_weight": weight,
            }
            for i, (role, weight) in enumerate(roles_and_weights)
        ]

    for idx in random.choices(range(len(data_engine)), k=num_samples):
        print(data_engine[idx])
        print(raw_dataset[idx])
        expected = {
            "chosen_messages": _to_messages(raw_dataset[idx]["chosen"]),
            "rejected_messages": _to_messages(raw_dataset[idx]["rejected"]),
        }
        # NOTE(review): "_dataset_name" is "tiny_dataset" although this test loads
        # orca-dpo-pairs — confirm the yaml really declares that name (possible
        # copy-paste from test_alpaca_converter).
        assert data_engine[idx] == {"_dataset_name": "tiny_dataset", **expected}
if __name__ == "__main__":
    # Allow running this module directly (one sampled row per test) without pytest.
    test_alpaca_converter(1)
    test_sharegpt_converter()
    test_pair_converter(1)
tests_v1/plugins/model_plugins/test_init_plugin.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
llamafactory.v1.accelerator.interface
import
DistributedInterface
from
llamafactory.v1.config.arg_parser
import
get_args
from
llamafactory.v1.core.model_engine
import
ModelEngine
def test_init_on_meta():
    """A model built with the ``init_on_meta`` plugin lives on the meta device."""
    model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen3",
            init_config={"name": "init_on_meta"},
        )
    )
    engine = ModelEngine(model_args=model_args)
    assert engine.model.device.type == "meta"
def test_init_on_rank0():
    """``init_on_rank0`` materializes weights on CPU for rank 0, meta elsewhere."""
    model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen3",
            init_config={"name": "init_on_rank0"},
        )
    )
    engine = ModelEngine(model_args=model_args)
    expected_device = "cpu" if DistributedInterface().get_rank() == 0 else "meta"
    assert engine.model.device.type == expected_device
def test_init_on_default():
    """``init_on_default`` places the model on the accelerator's current device."""
    model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen3",
            init_config={"name": "init_on_default"},
        )
    )
    engine = ModelEngine(model_args=model_args)
    assert engine.model.device == DistributedInterface().current_device
if __name__ == "__main__":
    """
    python tests_v1/plugins/model_plugins/test_init_plugin.py
    """
    # Run the three init-mode checks directly without pytest.
    test_init_on_meta()
    test_init_on_rank0()
    test_init_on_default()
Prev
1
…
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment