OpenDAS / LLaMA-Factory, commit c7d1b209
Authored Apr 29, 2025 by chenych
Parent: c8d12c06

Update 0429

Showing 5 changed files with 183 additions and 7 deletions (+183, -7)
tests/data/test_mm_plugin.py                (+89, -0)
tests/data/test_template.py                 (+42, -5)
tests/model/model_utils/test_add_tokens.py  (+46, -0)
tests/train/test_sft_trainer.py             (+4, -0)
tests/version.txt                           (+2, -2)
tests/data/test_mm_plugin.py (view file @ c7d1b209)

```diff
@@ -15,11 +15,13 @@
 import os
 from typing import TYPE_CHECKING, Any
 
+import numpy as np
 import pytest
 import torch
 from PIL import Image
 
 from llamafactory.data.mm_plugin import get_mm_plugin
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
@@ -42,11 +44,20 @@ MM_MESSAGES = [
     {"role": "assistant", "content": "A cat."},
 ]
 
+OMNI_MESSAGES = [
+    {"role": "user", "content": "<image>What is in this image?"},
+    {"role": "assistant", "content": "A cat."},
+    {"role": "user", "content": "<audio>What is in this audio?"},
+    {"role": "assistant", "content": "Nothing."},
+]
+
 TEXT_MESSAGES = [
     {"role": "user", "content": "How are you"},
     {"role": "assistant", "content": "I am fine!"},
 ]
 
+AUDIOS = [np.zeros(1600)]
 IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
 NO_IMAGES = []
@@ -57,6 +68,8 @@ NO_AUDIOS = []
 IMGLENS = [1]
+AUDLENS = [1]
 NO_IMGLENS = [0]
 NO_VIDLENS = [0]
@@ -75,6 +88,25 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
     return image_processor(images=IMAGES, return_tensors="pt")
 
 
+def _get_omni_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
+    mm_inputs = {}
+    image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+    feature_extractor = getattr(processor, "feature_extractor", None)
+    mm_inputs.update(image_processor(IMAGES, return_tensors="pt"))
+    mm_inputs.update(
+        feature_extractor(
+            AUDIOS,
+            sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+            return_attention_mask=True,
+            padding="max_length",
+            return_tensors="pt",
+        )
+    )
+    mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")
+    return mm_inputs
+
+
 def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
     assert batch_a.keys() == batch_b.keys()
     for key in batch_a.keys():
@@ -103,6 +135,17 @@ def _check_plugin(
     expected_mm_inputs: dict[str, Any] = {},
     expected_no_mm_inputs: dict[str, Any] = {},
 ) -> None:
+    # test omni_messages
+    if plugin.__class__.__name__ == "Qwen2OmniPlugin":
+        assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
+        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
+            expected_input_ids,
+            expected_labels,
+        )
+        _is_close(
+            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
+            expected_mm_inputs,
+        )
+
     # test mm_messages
     if plugin.__class__.__name__ != "BasePlugin":
         assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
@@ -137,6 +180,7 @@ def test_base_plugin():
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
 def test_gemma3_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
@@ -157,6 +201,24 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)
 
 
+@pytest.mark.xfail(reason="Unknown error.")
+def test_internvl_plugin():
+    image_seqlen = 256
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="OpenGVLab/InternVL3-1B-hf")
+    internvl_plugin = get_mm_plugin("intern_vl", image_token="<image>", video_token="<video>")
+    check_inputs = {"plugin": internvl_plugin, **tokenizer_module}
+    check_inputs["expected_mm_messages"] = [
+        {
+            key: value.replace("<image>", f"<img>{'<IMG_CONTEXT>' * image_seqlen * 1}</img>")
+            for key, value in message.items()
+        }
+        for message in MM_MESSAGES
+    ]
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
+    check_inputs["expected_mm_inputs"].pop("num_patches", None)
+    _check_plugin(**check_inputs)
+
+
 @pytest.mark.xfail(reason="Unknown error.")
 def test_llama4_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
@@ -178,6 +240,7 @@ def test_llama4_plugin():
     _check_plugin(**check_inputs)
 
 
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
 def test_llava_plugin():
     image_seqlen = 576
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
@@ -236,6 +299,7 @@ def test_paligemma_plugin():
     _check_plugin(**check_inputs)
 
 
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
 def test_pixtral_plugin():
     image_slice_height, image_slice_width = 2, 2
     tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
@@ -257,6 +321,30 @@ def test_pixtral_plugin():
     _check_plugin(**check_inputs)
 
 
+@pytest.mark.xfail(reason="Unknown error.")
+def test_qwen2_omni_plugin():
+    image_seqlen = 4
+    audio_seqlen = 2
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
+    qwen2_omni_plugin = get_mm_plugin(
+        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
+    )
+    check_inputs = {"plugin": qwen2_omni_plugin, **tokenizer_module}
+    check_inputs["expected_mm_messages"] = [
+        {
+            key: (
+                value.replace("<image>", f"<|vision_bos|>{'<|IMAGE|>' * image_seqlen}<|vision_eos|>")
+                .replace("<audio>", f"<|audio_bos|>{'<|AUDIO|>' * audio_seqlen}<|audio_eos|>")
+            )
+            for key, value in message.items()
+        }
+        for message in OMNI_MESSAGES
+    ]
+    check_inputs["expected_mm_inputs"] = _get_omni_inputs(tokenizer_module["processor"])
+    _check_plugin(**check_inputs)
+
+
 def test_qwen2_vl_plugin():
     image_seqlen = 4
     tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
@@ -273,6 +361,7 @@ def test_qwen2_vl_plugin():
     _check_plugin(**check_inputs)
 
 
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
 def test_video_llava_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
```
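For orientation, the call pattern these plugin tests exercise can be sketched outside the test harness. This is a minimal sketch, not part of the commit: it reuses fixture names from the diff above (`OMNI_MESSAGES`, `IMAGES`, `AUDIOS`, `_load_tokenizer_module`) and assumes local access to the Qwen/Qwen2.5-Omni-7B checkpoint.

```python
# Minimal sketch (not in the commit): how the qwen2_omni plugin expands
# placeholder tokens, using the fixtures defined in test_mm_plugin.py.
from llamafactory.data.mm_plugin import get_mm_plugin

tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
plugin = get_mm_plugin(
    name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
)
# process_messages() rewrites the <image>/<audio> placeholders into the model's
# special-token spans, e.g. <|vision_bos|><|IMAGE|>...<|vision_eos|>.
expanded = plugin.process_messages(
    OMNI_MESSAGES, IMAGES, [], AUDIOS, tokenizer_module["processor"]  # [] is NO_VIDEOS
)
print(expanded[0]["content"])
```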
tests/data/test_template.py (view file @ c7d1b209)

```diff
@@ -39,6 +39,13 @@ MESSAGES = [
     {"role": "assistant", "content": "很高兴认识你!"},
 ]
 
+MESSAGES_WITH_THOUGHT = [
+    {"role": "user", "content": "How are you"},
+    {"role": "assistant", "content": "<think>\nModel thought here\n</think>\n\nI am fine!"},
+    {"role": "user", "content": "你好"},
+    {"role": "assistant", "content": "<think>\n模型思考内容\n</think>\n\n很高兴认识你!"},
+]
+
 
 def _check_tokenization(
     tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
@@ -53,7 +60,14 @@ def _check_tokenization(
         assert tokenizer.decode(input_ids) == text
 
 
-def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
+def _check_template(
+    model_id: str,
+    template_name: str,
+    prompt_str: str,
+    answer_str: str,
+    use_fast: bool,
+    messages: list[dict[str, str]] = MESSAGES,
+) -> None:
     r"""Check template.
 
     Args:
@@ -62,13 +76,14 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
         prompt_str: the string corresponding to the prompt part.
         answer_str: the string corresponding to the answer part.
         use_fast: whether to use fast tokenizer.
+        messages: the list of messages.
     """
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
-    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
-    content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
+    content_str = tokenizer.apply_chat_template(messages, tokenize=False)
+    content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
     assert content_str == prompt_str + answer_str
     assert content_ids == prompt_ids + answer_ids
     _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@@ -198,7 +213,7 @@ def test_phi4_template(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen_template(use_fast: bool):
+def test_qwen2_5_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
         "<|im_start|>user\nHow are you<|im_end|>\n"
@@ -210,6 +225,18 @@ def test_qwen_template(use_fast: bool):
     _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
 
 
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_qwen3_template(use_fast: bool):
+    prompt_str = (
+        "<|im_start|>user\nHow are you<|im_end|>\n"
+        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+        "<|im_start|>user\n你好<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+    answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
+
+
 def test_parse_llama3_template():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
     template = parse_template(tokenizer)
@@ -231,3 +258,13 @@ def test_parse_qwen_template():
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
     assert template.format_prefix.slots == []
     assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
+
+
+def test_parse_qwen3_template():
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
+    template = parse_template(tokenizer)
+    assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
+    assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
+    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
+    assert template.format_prefix.slots == []
+    assert template.default_system == ""
```
tests/model/model_utils/test_add_tokens.py (new file, 0 → 100644, view file @ c7d1b209)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer


TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

UNUSED_TOKEN = "<|UNUSED_TOKEN|>"


@pytest.mark.parametrize("special_tokens", [False, True])
def test_add_tokens(special_tokens: bool):
    if special_tokens:
        model_args = ModelArguments(model_name_or_path=TINY_LLAMA3, add_special_tokens=UNUSED_TOKEN)
    else:
        model_args = ModelArguments(model_name_or_path=TINY_LLAMA3, add_tokens=UNUSED_TOKEN)

    tokenizer = load_tokenizer(model_args)["tokenizer"]
    encoded_ids = tokenizer.encode(UNUSED_TOKEN, add_special_tokens=False)
    assert len(encoded_ids) == 1
    decoded_str = tokenizer.decode(encoded_ids, skip_special_tokens=True)
    if special_tokens:
        assert decoded_str == ""
    else:
        assert decoded_str == UNUSED_TOKEN


if __name__ == "__main__":
    pytest.main([__file__])
```
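The behavior this new test pins down can be reproduced with plain transformers, independent of LLaMA-Factory's `ModelArguments` plumbing. A minimal sketch (the `add_tokens`/`add_special_tokens` calls below are the standard tokenizers API, not the `ModelArguments` fields of the same name; the tiny model is the test's default):

```python
# Minimal sketch (not in the commit): regular added tokens round-trip through
# decode, while special tokens are dropped by skip_special_tokens=True.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("llamafactory/tiny-random-Llama-3")
tok.add_tokens(["<|UNUSED_TOKEN|>"])  # regular vocabulary entry
tok.add_special_tokens({"additional_special_tokens": ["<|UNUSED_SPECIAL|>"]})

ids = tok.encode("<|UNUSED_TOKEN|>", add_special_tokens=False)
assert len(ids) == 1  # added tokens are matched atomically
assert tok.decode(ids, skip_special_tokens=True) == "<|UNUSED_TOKEN|>"

ids = tok.encode("<|UNUSED_SPECIAL|>", add_special_tokens=False)
assert len(ids) == 1
assert tok.decode(ids, skip_special_tokens=True) == ""  # special token is skipped
```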
tests/train/test_sft_trainer.py (view file @ c7d1b209)

```diff
@@ -50,6 +50,10 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
     verbose_list: list[dict[str, Any]] = field(default_factory=list)
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
+        features = [
+            {k: v for k, v in feature.items() if k in ["input_ids", "attention_mask", "labels"]}
+            for feature in features
+        ]
         self.verbose_list.extend(features)
         batch = super().__call__(features)
         return {k: v[:, :1] for k, v in batch.items()}  # truncate input length
```
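The `# truncate input length` return expression is what keeps this smoke test fast: after padding, every tensor in the batch is sliced to its first column. A toy illustration with made-up values:

```python
# Toy illustration (not in the commit) of {k: v[:, :1] for k, v in batch.items()}.
import torch

batch = {
    "input_ids": torch.tensor([[1, 2, 3], [4, 5, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 0]]),
}
truncated = {k: v[:, :1] for k, v in batch.items()}  # keep only the first token
print(truncated["input_ids"])  # tensor([[1], [4]])
```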
tests/version.txt (view file @ c7d1b209)

```diff
-# change if test fails
-0.9.3.102
+# change if test fails or cache is outdated
+0.9.3.106
```