Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
TS-MODELS-OPT
training
Autonomous-Driving-models
Commits
5ed76316
Commit
5ed76316
authored
Apr 08, 2026
by
雍大凯
Browse files
models add
parent
b2379236
Changes
290
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2181 additions
and
0 deletions
+2181
-0
docker-hub/qwen2.5-vl/llama-factory/tests/data/processor/test_processor_utils.py
...lama-factory/tests/data/processor/test_processor_utils.py
+34
-0
docker-hub/qwen2.5-vl/llama-factory/tests/data/processor/test_supervised.py
...-vl/llama-factory/tests/data/processor/test_supervised.py
+105
-0
docker-hub/qwen2.5-vl/llama-factory/tests/data/processor/test_unsupervised.py
...l/llama-factory/tests/data/processor/test_unsupervised.py
+60
-0
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_collator.py
...-hub/qwen2.5-vl/llama-factory/tests/data/test_collator.py
+169
-0
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_converter.py
...hub/qwen2.5-vl/llama-factory/tests/data/test_converter.py
+60
-0
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_formatter.py
...hub/qwen2.5-vl/llama-factory/tests/data/test_formatter.py
+267
-0
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_loader.py
...er-hub/qwen2.5-vl/llama-factory/tests/data/test_loader.py
+56
-0
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_mm_plugin.py
...hub/qwen2.5-vl/llama-factory/tests/data/test_mm_plugin.py
+371
-0
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_template.py
...-hub/qwen2.5-vl/llama-factory/tests/data/test_template.py
+354
-0
docker-hub/qwen2.5-vl/llama-factory/tests/e2e/test_chat.py
docker-hub/qwen2.5-vl/llama-factory/tests/e2e/test_chat.py
+49
-0
docker-hub/qwen2.5-vl/llama-factory/tests/e2e/test_sglang.py
docker-hub/qwen2.5-vl/llama-factory/tests/e2e/test_sglang.py
+71
-0
docker-hub/qwen2.5-vl/llama-factory/tests/e2e/test_train.py
docker-hub/qwen2.5-vl/llama-factory/tests/e2e/test_train.py
+71
-0
docker-hub/qwen2.5-vl/llama-factory/tests/eval/test_eval_template.py
...qwen2.5-vl/llama-factory/tests/eval/test_eval_template.py
+91
-0
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_add_tokens.py
.../llama-factory/tests/model/model_utils/test_add_tokens.py
+46
-0
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_attention.py
...l/llama-factory/tests/model/model_utils/test_attention.py
+50
-0
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_checkpointing.py
...ama-factory/tests/model/model_utils/test_checkpointing.py
+66
-0
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_misc.py
...2.5-vl/llama-factory/tests/model/model_utils/test_misc.py
+43
-0
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_packing.py
...-vl/llama-factory/tests/model/model_utils/test_packing.py
+68
-0
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_visual.py
...5-vl/llama-factory/tests/model/model_utils/test_visual.py
+102
-0
docker-hub/qwen2.5-vl/llama-factory/tests/model/test_base.py
docker-hub/qwen2.5-vl/llama-factory/tests/model/test_base.py
+48
-0
No files found.
docker-hub/qwen2.5-vl/llama-factory/tests/data/processor/test_processor_utils.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
pytest
from
llamafactory.data.processor.processor_utils
import
infer_seqlen
@
pytest
.
mark
.
parametrize
(
"test_input,test_output"
,
[
((
3000
,
2000
,
1000
),
(
600
,
400
)),
((
2000
,
3000
,
1000
),
(
400
,
600
)),
((
1000
,
100
,
1000
),
(
900
,
100
)),
((
100
,
1000
,
1000
),
(
100
,
900
)),
((
100
,
500
,
1000
),
(
100
,
500
)),
((
500
,
100
,
1000
),
(
500
,
100
)),
((
10
,
10
,
1000
),
(
10
,
10
)),
],
)
def
test_infer_seqlen
(
test_input
:
tuple
[
int
,
int
,
int
],
test_output
:
tuple
[
int
,
int
]):
assert
test_output
==
infer_seqlen
(
*
test_input
)
docker-hub/qwen2.5-vl/llama-factory/tests/data/processor/test_supervised.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
random
import
pytest
from
datasets
import
load_dataset
from
transformers
import
AutoTokenizer
from
llamafactory.extras.constants
import
IGNORE_INDEX
from
llamafactory.train.test_utils
import
load_dataset_module
DEMO_DATA
=
os
.
getenv
(
"DEMO_DATA"
,
"llamafactory/demo_data"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_DATA
=
os
.
getenv
(
"TINY_DATA"
,
"llamafactory/tiny-supervised-dataset"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA3
,
"stage"
:
"sft"
,
"do_train"
:
True
,
"finetuning_type"
:
"full"
,
"template"
:
"llama3"
,
"cutoff_len"
:
8192
,
"output_dir"
:
"dummy_dir"
,
"overwrite_output_dir"
:
True
,
"fp16"
:
True
,
}
@
pytest
.
mark
.
parametrize
(
"num_samples"
,
[
16
])
def
test_supervised_single_turn
(
num_samples
:
int
):
train_dataset
=
load_dataset_module
(
dataset_dir
=
"ONLINE"
,
dataset
=
TINY_DATA
,
**
TRAIN_ARGS
)[
"train_dataset"
]
ref_tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
original_data
=
load_dataset
(
TINY_DATA
,
split
=
"train"
)
indexes
=
random
.
choices
(
range
(
len
(
original_data
)),
k
=
num_samples
)
for
index
in
indexes
:
prompt
=
original_data
[
"instruction"
][
index
]
if
original_data
[
"input"
][
index
]:
prompt
+=
"
\n
"
+
original_data
[
"input"
][
index
]
messages
=
[
{
"role"
:
"user"
,
"content"
:
prompt
},
{
"role"
:
"assistant"
,
"content"
:
original_data
[
"output"
][
index
]},
]
ref_input_ids
=
ref_tokenizer
.
apply_chat_template
(
messages
)
assert
train_dataset
[
"input_ids"
][
index
]
==
ref_input_ids
@
pytest
.
mark
.
parametrize
(
"num_samples"
,
[
8
])
def
test_supervised_multi_turn
(
num_samples
:
int
):
train_dataset
=
load_dataset_module
(
dataset_dir
=
"REMOTE:"
+
DEMO_DATA
,
dataset
=
"system_chat"
,
**
TRAIN_ARGS
)[
"train_dataset"
]
ref_tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
original_data
=
load_dataset
(
DEMO_DATA
,
name
=
"system_chat"
,
split
=
"train"
)
indexes
=
random
.
choices
(
range
(
len
(
original_data
)),
k
=
num_samples
)
for
index
in
indexes
:
ref_input_ids
=
ref_tokenizer
.
apply_chat_template
(
original_data
[
"messages"
][
index
])
assert
train_dataset
[
"input_ids"
][
index
]
==
ref_input_ids
@
pytest
.
mark
.
parametrize
(
"num_samples"
,
[
4
])
def
test_supervised_train_on_prompt
(
num_samples
:
int
):
train_dataset
=
load_dataset_module
(
dataset_dir
=
"REMOTE:"
+
DEMO_DATA
,
dataset
=
"system_chat"
,
train_on_prompt
=
True
,
**
TRAIN_ARGS
)[
"train_dataset"
]
ref_tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
original_data
=
load_dataset
(
DEMO_DATA
,
name
=
"system_chat"
,
split
=
"train"
)
indexes
=
random
.
choices
(
range
(
len
(
original_data
)),
k
=
num_samples
)
for
index
in
indexes
:
ref_ids
=
ref_tokenizer
.
apply_chat_template
(
original_data
[
"messages"
][
index
])
assert
train_dataset
[
"input_ids"
][
index
]
==
ref_ids
assert
train_dataset
[
"labels"
][
index
]
==
ref_ids
@
pytest
.
mark
.
parametrize
(
"num_samples"
,
[
4
])
def
test_supervised_mask_history
(
num_samples
:
int
):
train_dataset
=
load_dataset_module
(
dataset_dir
=
"REMOTE:"
+
DEMO_DATA
,
dataset
=
"system_chat"
,
mask_history
=
True
,
**
TRAIN_ARGS
)[
"train_dataset"
]
ref_tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
original_data
=
load_dataset
(
DEMO_DATA
,
name
=
"system_chat"
,
split
=
"train"
)
indexes
=
random
.
choices
(
range
(
len
(
original_data
)),
k
=
num_samples
)
for
index
in
indexes
:
messages
=
original_data
[
"messages"
][
index
]
ref_input_ids
=
ref_tokenizer
.
apply_chat_template
(
messages
)
prompt_len
=
len
(
ref_tokenizer
.
apply_chat_template
(
messages
[:
-
1
],
add_generation_prompt
=
True
))
ref_label_ids
=
[
IGNORE_INDEX
]
*
prompt_len
+
ref_input_ids
[
prompt_len
:]
assert
train_dataset
[
"input_ids"
][
index
]
==
ref_input_ids
assert
train_dataset
[
"labels"
][
index
]
==
ref_label_ids
docker-hub/qwen2.5-vl/llama-factory/tests/data/processor/test_unsupervised.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
random
import
pytest
from
datasets
import
load_dataset
from
transformers
import
AutoTokenizer
from
llamafactory.train.test_utils
import
load_dataset_module
DEMO_DATA
=
os
.
getenv
(
"DEMO_DATA"
,
"llamafactory/demo_data"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_DATA
=
os
.
getenv
(
"TINY_DATA"
,
"llamafactory/tiny-supervised-dataset"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA3
,
"stage"
:
"ppo"
,
"do_train"
:
True
,
"finetuning_type"
:
"full"
,
"reward_model"
:
""
,
"reward_model_type"
:
"full"
,
"dataset"
:
"system_chat"
,
"dataset_dir"
:
"REMOTE:"
+
DEMO_DATA
,
"template"
:
"llama3"
,
"cutoff_len"
:
8192
,
"output_dir"
:
"dummy_dir"
,
"overwrite_output_dir"
:
True
,
"fp16"
:
True
,
}
@
pytest
.
mark
.
parametrize
(
"num_samples"
,
[
16
])
def
test_unsupervised_data
(
num_samples
:
int
):
train_dataset
=
load_dataset_module
(
**
TRAIN_ARGS
)[
"train_dataset"
]
ref_tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
original_data
=
load_dataset
(
DEMO_DATA
,
name
=
"system_chat"
,
split
=
"train"
)
indexes
=
random
.
choices
(
range
(
len
(
original_data
)),
k
=
num_samples
)
for
index
in
indexes
:
messages
=
original_data
[
"messages"
][
index
]
ref_ids
=
ref_tokenizer
.
apply_chat_template
(
messages
)
ref_input_ids
=
ref_tokenizer
.
apply_chat_template
(
messages
[:
-
1
],
add_generation_prompt
=
True
)
ref_labels
=
ref_ids
[
len
(
ref_input_ids
)
:]
assert
train_dataset
[
"input_ids"
][
index
]
==
ref_input_ids
assert
train_dataset
[
"labels"
][
index
]
==
ref_labels
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_collator.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
torch
from
PIL
import
Image
from
transformers
import
AutoConfig
,
AutoModelForVision2Seq
from
llamafactory.data
import
get_template_and_fix_tokenizer
from
llamafactory.data.collator
import
MultiModalDataCollatorForSeq2Seq
,
prepare_4d_attention_mask
from
llamafactory.extras.constants
import
IGNORE_INDEX
from
llamafactory.hparams
import
get_infer_args
from
llamafactory.model
import
load_tokenizer
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
def
test_base_collator
():
model_args
,
data_args
,
*
_
=
get_infer_args
({
"model_name_or_path"
:
TINY_LLAMA3
,
"template"
:
"default"
})
tokenizer_module
=
load_tokenizer
(
model_args
)
template
=
get_template_and_fix_tokenizer
(
tokenizer_module
[
"tokenizer"
],
data_args
)
data_collator
=
MultiModalDataCollatorForSeq2Seq
(
template
=
template
,
pad_to_multiple_of
=
8
,
label_pad_token_id
=
IGNORE_INDEX
,
**
tokenizer_module
,
)
p
=
tokenizer_module
[
"tokenizer"
].
pad_token_id
q
=
IGNORE_INDEX
features
=
[
{
"input_ids"
:
[
0
,
1
,
2
,
3
,
4
,
5
],
"attention_mask"
:
[
1
,
1
,
1
,
1
,
1
,
1
],
"labels"
:
[
q
,
q
,
2
,
3
,
4
,
5
],
},
{
"input_ids"
:
[
6
,
7
],
"attention_mask"
:
[
1
,
1
],
"labels"
:
[
q
,
7
],
},
]
batch_input
=
data_collator
(
features
)
expected_input
=
{
"input_ids"
:
[
[
0
,
1
,
2
,
3
,
4
,
5
,
p
,
p
],
[
6
,
7
,
p
,
p
,
p
,
p
,
p
,
p
],
],
"attention_mask"
:
[
[
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
],
[
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
],
],
"labels"
:
[
[
q
,
q
,
2
,
3
,
4
,
5
,
q
,
q
],
[
q
,
7
,
q
,
q
,
q
,
q
,
q
,
q
],
],
}
for
k
in
batch_input
.
keys
():
assert
batch_input
[
k
].
eq
(
torch
.
tensor
(
expected_input
[
k
])).
all
()
def
test_multimodal_collator
():
model_args
,
data_args
,
*
_
=
get_infer_args
(
{
"model_name_or_path"
:
"Qwen/Qwen2-VL-2B-Instruct"
,
"template"
:
"qwen2_vl"
}
)
tokenizer_module
=
load_tokenizer
(
model_args
)
template
=
get_template_and_fix_tokenizer
(
tokenizer_module
[
"tokenizer"
],
data_args
)
config
=
AutoConfig
.
from_pretrained
(
model_args
.
model_name_or_path
)
with
torch
.
device
(
"meta"
):
model
=
AutoModelForVision2Seq
.
from_config
(
config
)
data_collator
=
MultiModalDataCollatorForSeq2Seq
(
template
=
template
,
model
=
model
,
pad_to_multiple_of
=
4
,
label_pad_token_id
=
IGNORE_INDEX
,
**
tokenizer_module
,
)
p
=
tokenizer_module
[
"tokenizer"
].
pad_token_id
q
=
IGNORE_INDEX
s
=
tokenizer_module
[
"tokenizer"
].
convert_tokens_to_ids
(
"<|vision_start|>"
)
e
=
tokenizer_module
[
"tokenizer"
].
convert_tokens_to_ids
(
"<|vision_end|>"
)
m
=
tokenizer_module
[
"tokenizer"
].
convert_tokens_to_ids
(
"<|image_pad|>"
)
fake_image
=
Image
.
new
(
"RGB"
,
(
64
,
64
),
(
255
,
255
,
255
))
features
=
[
{
"input_ids"
:
[
0
,
1
,
2
,
3
],
"attention_mask"
:
[
1
,
1
,
1
,
1
],
"labels"
:
[
0
,
1
,
2
,
3
],
},
]
batch_input
=
data_collator
(
features
)
expected_input
=
{
"input_ids"
:
[
[
0
,
1
,
2
,
3
,
s
,
m
,
m
,
m
,
m
,
e
,
p
,
p
],
],
"attention_mask"
:
[
[
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
],
"labels"
:
[
[
0
,
1
,
2
,
3
,
q
,
q
,
q
,
q
,
q
,
q
,
q
,
q
],
],
"position_ids"
:
[
[[
0
,
1
,
2
,
3
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
]],
[[
0
,
1
,
2
,
3
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
]],
[[
0
,
1
,
2
,
3
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
]],
],
"rope_deltas"
:
[[
-
8
]],
**
tokenizer_module
[
"processor"
].
image_processor
(
fake_image
),
}
assert
batch_input
.
keys
()
==
expected_input
.
keys
()
for
k
in
batch_input
.
keys
():
assert
batch_input
[
k
].
eq
(
torch
.
tensor
(
expected_input
[
k
])).
all
()
def
test_4d_attention_mask
():
o
=
0.0
x
=
torch
.
finfo
(
torch
.
float16
).
min
attention_mask_with_indices
=
torch
.
tensor
(
[
[
1
,
1
,
2
,
2
,
2
,
0
],
[
1
,
2
,
2
,
3
,
3
,
3
],
]
)
attention_mask_computed
=
prepare_4d_attention_mask
(
attention_mask_with_indices
,
torch
.
float16
)
attention_mask_expected
=
torch
.
tensor
(
[
[
[
[
o
,
x
,
x
,
x
,
x
,
x
],
[
o
,
o
,
x
,
x
,
x
,
x
],
[
x
,
x
,
o
,
x
,
x
,
x
],
[
x
,
x
,
o
,
o
,
x
,
x
],
[
x
,
x
,
o
,
o
,
o
,
x
],
[
x
,
x
,
x
,
x
,
x
,
x
],
]
],
[
[
[
o
,
x
,
x
,
x
,
x
,
x
],
[
x
,
o
,
x
,
x
,
x
,
x
],
[
x
,
o
,
o
,
x
,
x
,
x
],
[
x
,
x
,
x
,
o
,
x
,
x
],
[
x
,
x
,
x
,
o
,
o
,
x
],
[
x
,
x
,
x
,
o
,
o
,
o
],
]
],
],
dtype
=
torch
.
float16
,
)
assert
list
(
attention_mask_computed
.
size
())
==
[
2
,
1
,
6
,
6
]
assert
torch
.
all
(
attention_mask_computed
==
attention_mask_expected
)
if
__name__
==
"__main__"
:
test_multimodal_collator
()
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_converter.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
llamafactory.data
import
Role
from
llamafactory.data.converter
import
get_dataset_converter
from
llamafactory.data.parser
import
DatasetAttr
from
llamafactory.hparams
import
DataArguments
def
test_alpaca_converter
():
dataset_attr
=
DatasetAttr
(
"hf_hub"
,
"llamafactory/tiny-supervised-dataset"
)
data_args
=
DataArguments
()
example
=
{
"instruction"
:
"Solve the math problem."
,
"input"
:
"3 + 4"
,
"output"
:
"The answer is 7."
,
}
dataset_converter
=
get_dataset_converter
(
"alpaca"
,
dataset_attr
,
data_args
)
assert
dataset_converter
(
example
)
==
{
"_prompt"
:
[{
"role"
:
Role
.
USER
.
value
,
"content"
:
"Solve the math problem.
\n
3 + 4"
}],
"_response"
:
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
"The answer is 7."
}],
"_system"
:
""
,
"_tools"
:
""
,
"_images"
:
None
,
"_videos"
:
None
,
"_audios"
:
None
,
}
def
test_sharegpt_converter
():
dataset_attr
=
DatasetAttr
(
"hf_hub"
,
"llamafactory/tiny-supervised-dataset"
)
data_args
=
DataArguments
()
example
=
{
"conversations"
:
[
{
"from"
:
"system"
,
"value"
:
"You are a helpful assistant."
},
{
"from"
:
"human"
,
"value"
:
"Solve the math problem.
\n
3 + 4"
},
{
"from"
:
"gpt"
,
"value"
:
"The answer is 7."
},
]
}
dataset_converter
=
get_dataset_converter
(
"sharegpt"
,
dataset_attr
,
data_args
)
assert
dataset_converter
(
example
)
==
{
"_prompt"
:
[{
"role"
:
Role
.
USER
.
value
,
"content"
:
"Solve the math problem.
\n
3 + 4"
}],
"_response"
:
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
"The answer is 7."
}],
"_system"
:
"You are a helpful assistant."
,
"_tools"
:
""
,
"_images"
:
None
,
"_videos"
:
None
,
"_audios"
:
None
,
}
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_formatter.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
from
datetime
import
datetime
from
llamafactory.data.formatter
import
EmptyFormatter
,
FunctionFormatter
,
StringFormatter
,
ToolFormatter
FUNCTION
=
{
"name"
:
"tool_name"
,
"arguments"
:
{
"foo"
:
"bar"
,
"size"
:
10
}}
TOOLS
=
[
{
"name"
:
"test_tool"
,
"description"
:
"tool_desc"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"foo"
:
{
"type"
:
"string"
,
"description"
:
"foo_desc"
},
"bar"
:
{
"type"
:
"number"
,
"description"
:
"bar_desc"
},
},
"required"
:
[
"foo"
],
},
}
]
def
test_empty_formatter
():
formatter
=
EmptyFormatter
(
slots
=
[
"
\n
"
])
assert
formatter
.
apply
()
==
[
"
\n
"
]
def
test_string_formatter
():
formatter
=
StringFormatter
(
slots
=
[
"<s>"
,
"Human: {{content}}
\n
Assistant:"
])
assert
formatter
.
apply
(
content
=
"Hi"
)
==
[
"<s>"
,
"Human: Hi
\n
Assistant:"
]
def
test_function_formatter
():
formatter
=
FunctionFormatter
(
slots
=
[
"{{content}}"
,
"</s>"
],
tool_format
=
"default"
)
tool_calls
=
json
.
dumps
(
FUNCTION
)
assert
formatter
.
apply
(
content
=
tool_calls
)
==
[
"""Action: tool_name
\n
Action Input: {"foo": "bar", "size": 10}"""
,
"</s>"
,
]
def
test_multi_function_formatter
():
formatter
=
FunctionFormatter
(
slots
=
[
"{{content}}"
,
"</s>"
],
tool_format
=
"default"
)
tool_calls
=
json
.
dumps
([
FUNCTION
]
*
2
)
assert
formatter
.
apply
(
content
=
tool_calls
)
==
[
"""Action: tool_name
\n
Action Input: {"foo": "bar", "size": 10}
\n
"""
"""Action: tool_name
\n
Action Input: {"foo": "bar", "size": 10}"""
,
"</s>"
,
]
def
test_default_tool_formatter
():
formatter
=
ToolFormatter
(
tool_format
=
"default"
)
assert
formatter
.
apply
(
content
=
json
.
dumps
(
TOOLS
))
==
[
"You have access to the following tools:
\n
"
"> Tool Name: test_tool
\n
"
"Tool Description: tool_desc
\n
"
"Tool Args:
\n
"
" - foo (string, required): foo_desc
\n
"
" - bar (number): bar_desc
\n\n
"
"Use the following format if using a tool:
\n
"
"```
\n
"
"Action: tool name (one of [test_tool])
\n
"
"Action Input: the input to the tool, in a JSON format representing the kwargs "
"""(e.g. ```{"input": "hello world", "num_beams": 5}```)
\n
"""
"```
\n
"
]
def
test_default_tool_extractor
():
formatter
=
ToolFormatter
(
tool_format
=
"default"
)
result
=
"""Action: test_tool
\n
Action Input: {"foo": "bar", "size": 10}"""
assert
formatter
.
extract
(
result
)
==
[(
"test_tool"
,
"""{"foo": "bar", "size": 10}"""
)]
def
test_default_multi_tool_extractor
():
formatter
=
ToolFormatter
(
tool_format
=
"default"
)
result
=
(
"""Action: test_tool
\n
Action Input: {"foo": "bar", "size": 10}
\n
"""
"""Action: another_tool
\n
Action Input: {"foo": "job", "size": 2}"""
)
assert
formatter
.
extract
(
result
)
==
[
(
"test_tool"
,
"""{"foo": "bar", "size": 10}"""
),
(
"another_tool"
,
"""{"foo": "job", "size": 2}"""
),
]
def
test_glm4_function_formatter
():
formatter
=
FunctionFormatter
(
slots
=
[
"{{content}}"
],
tool_format
=
"glm4"
)
tool_calls
=
json
.
dumps
(
FUNCTION
)
assert
formatter
.
apply
(
content
=
tool_calls
)
==
[
"""tool_name
\n
{"foo": "bar", "size": 10}"""
]
def
test_glm4_tool_formatter
():
formatter
=
ToolFormatter
(
tool_format
=
"glm4"
)
assert
formatter
.
apply
(
content
=
json
.
dumps
(
TOOLS
))
==
[
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。
\n\n
# 可用工具
\n\n
"
f
"## test_tool
\n\n
{
json
.
dumps
(
TOOLS
[
0
],
indent
=
4
,
ensure_ascii
=
False
)
}
\n
"
"在调用上述函数时,请使用 Json 格式表示调用的参数。"
]
def
test_glm4_tool_extractor
():
formatter
=
ToolFormatter
(
tool_format
=
"glm4"
)
result
=
"""test_tool
\n
{"foo": "bar", "size": 10}
\n
"""
assert
formatter
.
extract
(
result
)
==
[(
"test_tool"
,
"""{"foo": "bar", "size": 10}"""
)]
def
test_llama3_function_formatter
():
formatter
=
FunctionFormatter
(
slots
=
[
"{{content}}<|eot_id|>"
],
tool_format
=
"llama3"
)
tool_calls
=
json
.
dumps
(
FUNCTION
)
assert
formatter
.
apply
(
content
=
tool_calls
)
==
[
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
]
def
test_llama3_multi_function_formatter
():
formatter
=
FunctionFormatter
(
slots
=
[
"{{content}}<|eot_id|>"
],
tool_format
=
"llama3"
)
tool_calls
=
json
.
dumps
([
FUNCTION
]
*
2
)
assert
formatter
.
apply
(
content
=
tool_calls
)
==
[
"""[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
"""<|eot_id|>"""
]
def
test_llama3_tool_formatter
():
formatter
=
ToolFormatter
(
tool_format
=
"llama3"
)
date
=
datetime
.
now
().
strftime
(
"%d %b %Y"
)
wrapped_tool
=
{
"type"
:
"function"
,
"function"
:
TOOLS
[
0
]}
assert
formatter
.
apply
(
content
=
json
.
dumps
(
TOOLS
))
==
[
f
"Cutting Knowledge Date: December 2023
\n
Today Date:
{
date
}
\n\n
"
"You have access to the following functions. "
"To call a function, please respond with JSON for a function call. "
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
f
"Do not use variables.
\n\n
{
json
.
dumps
(
wrapped_tool
,
indent
=
4
,
ensure_ascii
=
False
)
}
\n\n
"
]
def
test_llama3_tool_extractor
():
formatter
=
ToolFormatter
(
tool_format
=
"llama3"
)
result
=
"""{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}
\n
"""
assert
formatter
.
extract
(
result
)
==
[(
"test_tool"
,
"""{"foo": "bar", "size": 10}"""
)]
def
test_llama3_multi_tool_extractor
():
formatter
=
ToolFormatter
(
tool_format
=
"llama3"
)
result
=
(
"""[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
"""{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
)
assert
formatter
.
extract
(
result
)
==
[
(
"test_tool"
,
"""{"foo": "bar", "size": 10}"""
),
(
"another_tool"
,
"""{"foo": "job", "size": 2}"""
),
]
def
test_mistral_function_formatter
():
formatter
=
FunctionFormatter
(
slots
=
[
"[TOOL_CALLS] {{content}}"
,
"</s>"
],
tool_format
=
"mistral"
)
tool_calls
=
json
.
dumps
(
FUNCTION
)
assert
formatter
.
apply
(
content
=
tool_calls
)
==
[
"[TOOL_CALLS] "
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]"""
,
"</s>"
,
]
def
test_mistral_multi_function_formatter
():
formatter
=
FunctionFormatter
(
slots
=
[
"[TOOL_CALLS] {{content}}"
,
"</s>"
],
tool_format
=
"mistral"
)
tool_calls
=
json
.
dumps
([
FUNCTION
]
*
2
)
assert
formatter
.
apply
(
content
=
tool_calls
)
==
[
"[TOOL_CALLS] "
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]"""
,
"</s>"
,
]
def
test_mistral_tool_formatter
():
formatter
=
ToolFormatter
(
tool_format
=
"mistral"
)
wrapped_tool
=
{
"type"
:
"function"
,
"function"
:
TOOLS
[
0
]}
assert
formatter
.
apply
(
content
=
json
.
dumps
(
TOOLS
))
==
[
"[AVAILABLE_TOOLS] "
+
json
.
dumps
([
wrapped_tool
],
ensure_ascii
=
False
)
+
"[/AVAILABLE_TOOLS]"
]
def
test_mistral_tool_extractor
():
formatter
=
ToolFormatter
(
tool_format
=
"mistral"
)
result
=
"""{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
assert
formatter
.
extract
(
result
)
==
[(
"test_tool"
,
"""{"foo": "bar", "size": 10}"""
)]
def
test_mistral_multi_tool_extractor
():
formatter
=
ToolFormatter
(
tool_format
=
"mistral"
)
result
=
(
"""[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
)
assert
formatter
.
extract
(
result
)
==
[
(
"test_tool"
,
"""{"foo": "bar", "size": 10}"""
),
(
"another_tool"
,
"""{"foo": "job", "size": 2}"""
),
]
def
test_qwen_function_formatter
():
formatter
=
FunctionFormatter
(
slots
=
[
"{{content}}<|im_end|>
\n
"
],
tool_format
=
"qwen"
)
tool_calls
=
json
.
dumps
(
FUNCTION
)
assert
formatter
.
apply
(
content
=
tool_calls
)
==
[
"""<tool_call>
\n
{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
\n
</tool_call><|im_end|>
\n
"""
]
def
test_qwen_multi_function_formatter
():
formatter
=
FunctionFormatter
(
slots
=
[
"{{content}}<|im_end|>
\n
"
],
tool_format
=
"qwen"
)
tool_calls
=
json
.
dumps
([
FUNCTION
]
*
2
)
assert
formatter
.
apply
(
content
=
tool_calls
)
==
[
"""<tool_call>
\n
{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
\n
</tool_call>
\n
"""
"""<tool_call>
\n
{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
\n
</tool_call>"""
"<|im_end|>
\n
"
]
def
test_qwen_tool_formatter
():
formatter
=
ToolFormatter
(
tool_format
=
"qwen"
)
wrapped_tool
=
{
"type"
:
"function"
,
"function"
:
TOOLS
[
0
]}
assert
formatter
.
apply
(
content
=
json
.
dumps
(
TOOLS
))
==
[
"
\n\n
# Tools
\n\n
You may call one or more functions to assist with the user query.
\n\n
"
"You are provided with function signatures within <tools></tools> XML tags:
\n
<tools>"
f
"
\n
{
json
.
dumps
(
wrapped_tool
,
ensure_ascii
=
False
)
}
"
"
\n
</tools>
\n\n
For each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:
\n
<tool_call>
\n
{"name": <function-name>, """
""""arguments": <args-json-object>}
\n
</tool_call>"""
]
def
test_qwen_tool_extractor
():
formatter
=
ToolFormatter
(
tool_format
=
"qwen"
)
result
=
"""<tool_call>
\n
{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}
\n
</tool_call>"""
assert
formatter
.
extract
(
result
)
==
[(
"test_tool"
,
"""{"foo": "bar", "size": 10}"""
)]
def
test_qwen_multi_tool_extractor
():
formatter
=
ToolFormatter
(
tool_format
=
"qwen"
)
result
=
(
"""<tool_call>
\n
{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}
\n
</tool_call>
\n
"""
"""<tool_call>
\n
{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}
\n
</tool_call>"""
)
assert
formatter
.
extract
(
result
)
==
[
(
"test_tool"
,
"""{"foo": "bar", "size": 10}"""
),
(
"another_tool"
,
"""{"foo": "job", "size": 2}"""
),
]
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_loader.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
llamafactory.train.test_utils
import
load_dataset_module
DEMO_DATA
=
os
.
getenv
(
"DEMO_DATA"
,
"llamafactory/demo_data"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_DATA
=
os
.
getenv
(
"TINY_DATA"
,
"llamafactory/tiny-supervised-dataset"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA3
,
"stage"
:
"sft"
,
"do_train"
:
True
,
"finetuning_type"
:
"full"
,
"template"
:
"llama3"
,
"dataset"
:
TINY_DATA
,
"dataset_dir"
:
"ONLINE"
,
"cutoff_len"
:
8192
,
"output_dir"
:
"dummy_dir"
,
"overwrite_output_dir"
:
True
,
"fp16"
:
True
,
}
def
test_load_train_only
():
dataset_module
=
load_dataset_module
(
**
TRAIN_ARGS
)
assert
dataset_module
.
get
(
"train_dataset"
)
is
not
None
assert
dataset_module
.
get
(
"eval_dataset"
)
is
None
def
test_load_val_size
():
dataset_module
=
load_dataset_module
(
val_size
=
0.1
,
**
TRAIN_ARGS
)
assert
dataset_module
.
get
(
"train_dataset"
)
is
not
None
assert
dataset_module
.
get
(
"eval_dataset"
)
is
not
None
def
test_load_eval_data
():
dataset_module
=
load_dataset_module
(
eval_dataset
=
TINY_DATA
,
**
TRAIN_ARGS
)
assert
dataset_module
.
get
(
"train_dataset"
)
is
not
None
assert
dataset_module
.
get
(
"eval_dataset"
)
is
not
None
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_mm_plugin.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
typing
import
TYPE_CHECKING
,
Any
import
numpy
as
np
import
pytest
import
torch
from
PIL
import
Image
from
llamafactory.data.mm_plugin
import
get_mm_plugin
from
llamafactory.extras.packages
import
is_transformers_version_greater_than
from
llamafactory.hparams
import
get_infer_args
from
llamafactory.model
import
load_tokenizer
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
,
ProcessorMixin
from
transformers.image_processing_utils
import
BaseImageProcessor
from
llamafactory.data.mm_plugin
import
BasePlugin
from
llamafactory.model.loader
import
TokenizerModule
HF_TOKEN
=
os
.
getenv
(
"HF_TOKEN"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA4
=
os
.
getenv
(
"TINY_LLAMA4"
,
"llamafactory/tiny-random-Llama-4"
)
MM_MESSAGES
=
[
{
"role"
:
"user"
,
"content"
:
"<image>What is in this image?"
},
{
"role"
:
"assistant"
,
"content"
:
"A cat."
},
]
OMNI_MESSAGES
=
[
{
"role"
:
"user"
,
"content"
:
"<image>What is in this image?"
},
{
"role"
:
"assistant"
,
"content"
:
"A cat."
},
{
"role"
:
"user"
,
"content"
:
"<audio>What is in this audio?"
},
{
"role"
:
"assistant"
,
"content"
:
"Nothing."
},
]
TEXT_MESSAGES
=
[
{
"role"
:
"user"
,
"content"
:
"How are you"
},
{
"role"
:
"assistant"
,
"content"
:
"I am fine!"
},
]
AUDIOS
=
[
np
.
zeros
(
1600
)]
IMAGES
=
[
Image
.
new
(
"RGB"
,
(
32
,
32
),
(
255
,
255
,
255
))]
NO_IMAGES
=
[]
NO_VIDEOS
=
[]
NO_AUDIOS
=
[]
IMGLENS
=
[
1
]
AUDLENS
=
[
1
]
NO_IMGLENS
=
[
0
]
NO_VIDLENS
=
[
0
]
NO_AUDLENS
=
[
0
]
INPUT_IDS
=
[
0
,
1
,
2
,
3
,
4
]
LABELS
=
[
0
,
1
,
2
,
3
,
4
]
BATCH_IDS
=
[[
1
]
*
1024
]
def
_get_mm_inputs
(
processor
:
"ProcessorMixin"
)
->
dict
[
str
,
"torch.Tensor"
]:
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
)
return
image_processor
(
images
=
IMAGES
,
return_tensors
=
"pt"
)
def
_get_omni_inputs
(
processor
:
"ProcessorMixin"
)
->
dict
[
str
,
"torch.Tensor"
]:
mm_inputs
=
{}
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
,
None
)
feature_extractor
=
getattr
(
processor
,
"feature_extractor"
,
None
)
mm_inputs
.
update
(
image_processor
(
IMAGES
,
return_tensors
=
"pt"
))
mm_inputs
.
update
(
feature_extractor
(
AUDIOS
,
sampling_rate
=
getattr
(
processor
,
"audio_sampling_rate"
,
16000
),
return_attention_mask
=
True
,
padding
=
"max_length"
,
return_tensors
=
"pt"
,
)
)
mm_inputs
[
"feature_attention_mask"
]
=
mm_inputs
.
pop
(
"attention_mask"
)
return
mm_inputs
def
_is_close
(
batch_a
:
dict
[
str
,
Any
],
batch_b
:
dict
[
str
,
Any
])
->
None
:
assert
batch_a
.
keys
()
==
batch_b
.
keys
()
for
key
in
batch_a
.
keys
():
if
isinstance
(
batch_a
[
key
],
torch
.
Tensor
):
assert
torch
.
allclose
(
batch_a
[
key
],
batch_b
[
key
],
rtol
=
1e-4
,
atol
=
1e-5
)
elif
isinstance
(
batch_a
[
key
],
list
)
and
all
(
isinstance
(
item
,
torch
.
Tensor
)
for
item
in
batch_a
[
key
]):
assert
len
(
batch_a
[
key
])
==
len
(
batch_b
[
key
])
for
tensor_a
,
tensor_b
in
zip
(
batch_a
[
key
],
batch_b
[
key
]):
assert
torch
.
allclose
(
tensor_a
,
tensor_b
,
rtol
=
1e-4
,
atol
=
1e-5
)
else
:
assert
batch_a
[
key
]
==
batch_b
[
key
]
def
_load_tokenizer_module
(
model_name_or_path
:
str
)
->
"TokenizerModule"
:
model_args
,
*
_
=
get_infer_args
({
"model_name_or_path"
:
model_name_or_path
,
"template"
:
"default"
})
return
load_tokenizer
(
model_args
)
def
_check_plugin
(
plugin
:
"BasePlugin"
,
tokenizer
:
"PreTrainedTokenizer"
,
processor
:
"ProcessorMixin"
,
expected_mm_messages
:
list
[
dict
[
str
,
str
]]
=
MM_MESSAGES
,
expected_input_ids
:
list
[
int
]
=
INPUT_IDS
,
expected_labels
:
list
[
int
]
=
LABELS
,
expected_mm_inputs
:
dict
[
str
,
Any
]
=
{},
expected_no_mm_inputs
:
dict
[
str
,
Any
]
=
{},
)
->
None
:
if
plugin
.
__class__
.
__name__
==
"Qwen2OmniPlugin"
:
# test omni_messages
assert
plugin
.
process_messages
(
OMNI_MESSAGES
,
IMAGES
,
NO_VIDEOS
,
AUDIOS
,
processor
)
==
expected_mm_messages
assert
plugin
.
process_token_ids
(
INPUT_IDS
,
LABELS
,
IMAGES
,
NO_VIDEOS
,
AUDIOS
,
tokenizer
,
processor
)
==
(
expected_input_ids
,
expected_labels
,
)
_is_close
(
plugin
.
get_mm_inputs
(
IMAGES
,
NO_VIDEOS
,
AUDIOS
,
IMGLENS
,
NO_VIDLENS
,
AUDLENS
,
BATCH_IDS
,
processor
),
expected_mm_inputs
,
)
elif
plugin
.
__class__
.
__name__
!=
"BasePlugin"
:
# test mm_messages
assert
plugin
.
process_messages
(
MM_MESSAGES
,
IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
processor
)
==
expected_mm_messages
assert
plugin
.
process_token_ids
(
INPUT_IDS
,
LABELS
,
IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
tokenizer
,
processor
)
==
(
expected_input_ids
,
expected_labels
,
)
_is_close
(
plugin
.
get_mm_inputs
(
IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
IMGLENS
,
NO_VIDLENS
,
NO_AUDLENS
,
BATCH_IDS
,
processor
),
expected_mm_inputs
,
)
# test text_messages
assert
plugin
.
process_messages
(
TEXT_MESSAGES
,
NO_IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
processor
)
==
TEXT_MESSAGES
assert
plugin
.
process_token_ids
(
INPUT_IDS
,
LABELS
,
NO_IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
tokenizer
,
processor
)
==
(
INPUT_IDS
,
LABELS
,
)
_is_close
(
plugin
.
get_mm_inputs
(
NO_IMAGES
,
NO_VIDEOS
,
NO_AUDIOS
,
NO_IMGLENS
,
NO_VIDLENS
,
NO_AUDLENS
,
BATCH_IDS
,
processor
),
expected_no_mm_inputs
,
)
def
test_base_plugin
():
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
TINY_LLAMA3
)
base_plugin
=
get_mm_plugin
(
name
=
"base"
)
check_inputs
=
{
"plugin"
:
base_plugin
,
**
tokenizer_module
}
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.50.0"
),
reason
=
"Requires transformers>=4.50.0"
)
def
test_gemma3_plugin
():
image_seqlen
=
256
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"google/gemma-3-4b-it"
)
gemma3_plugin
=
get_mm_plugin
(
name
=
"gemma3"
,
image_token
=
"<image_soft_token>"
)
image_tokens_expanded
=
"<image_soft_token>"
*
image_seqlen
check_inputs
=
{
"plugin"
:
gemma3_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
f
"
\n\n
<start_of_image>
{
image_tokens_expanded
}
<end_of_image>
\n\n
"
)
for
key
,
value
in
message
.
items
()
}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
check_inputs
[
"expected_mm_inputs"
].
pop
(
"num_crops"
)
check_inputs
[
"expected_mm_inputs"
][
"token_type_ids"
]
=
[[
0
]
*
1024
]
check_inputs
[
"expected_no_mm_inputs"
]
=
{
"token_type_ids"
:
[[
0
]
*
1024
]}
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.52.0"
),
reason
=
"Requires transformers>=4.52.0"
)
def
test_internvl_plugin
():
image_seqlen
=
256
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"OpenGVLab/InternVL3-1B-hf"
)
internvl_plugin
=
get_mm_plugin
(
"intern_vl"
,
image_token
=
"<image>"
,
video_token
=
"<video>"
)
check_inputs
=
{
"plugin"
:
internvl_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
f
"<img>
{
'<IMG_CONTEXT>'
*
image_seqlen
*
1
}
</img>"
)
for
key
,
value
in
message
.
items
()
}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
check_inputs
[
"expected_mm_inputs"
].
pop
(
"num_patches"
,
None
)
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.51.0"
),
reason
=
"Requires transformers>=4.51.0"
)
def
test_llama4_plugin
():
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
TINY_LLAMA4
)
processor
=
tokenizer_module
[
"processor"
]
llama4_plugin
=
get_mm_plugin
(
name
=
"llama4"
,
image_token
=
"<|image|>"
)
check_inputs
=
{
"plugin"
:
llama4_plugin
,
**
tokenizer_module
}
mm_inputs
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
image_height
,
image_width
=
mm_inputs
[
"pixel_values"
][
0
].
shape
[
-
2
:]
num_patches_per_chunk
=
int
(
(
image_height
//
processor
.
patch_size
)
*
(
image_width
//
processor
.
patch_size
)
//
processor
.
downsample_ratio
)
aspect_ratios
=
mm_inputs
.
pop
(
"aspect_ratios"
)
tokens_for_this_image
=
processor
.
_prompt_split_image
(
aspect_ratios
[
0
],
num_patches_per_chunk
)
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
tokens_for_this_image
)
for
key
,
value
in
message
.
items
()}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
mm_inputs
_check_plugin
(
**
check_inputs
)
def
test_llava_plugin
():
image_seqlen
=
576
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"llava-hf/llava-1.5-7b-hf"
)
llava_plugin
=
get_mm_plugin
(
name
=
"llava"
,
image_token
=
"<image>"
)
check_inputs
=
{
"plugin"
:
llava_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
"<image>"
*
image_seqlen
)
for
key
,
value
in
message
.
items
()}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
_check_plugin
(
**
check_inputs
)
def
test_llava_next_plugin
():
image_seqlen
=
1176
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"llava-hf/llava-v1.6-vicuna-7b-hf"
)
llava_next_plugin
=
get_mm_plugin
(
name
=
"llava_next"
,
image_token
=
"<image>"
)
check_inputs
=
{
"plugin"
:
llava_next_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
"<image>"
*
image_seqlen
)
for
key
,
value
in
message
.
items
()}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
_check_plugin
(
**
check_inputs
)
def
test_llava_next_video_plugin
():
image_seqlen
=
1176
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"llava-hf/LLaVA-NeXT-Video-7B-hf"
)
llava_next_video_plugin
=
get_mm_plugin
(
name
=
"llava_next_video"
,
image_token
=
"<image>"
,
video_token
=
"<video>"
)
check_inputs
=
{
"plugin"
:
llava_next_video_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
"<image>"
*
image_seqlen
)
for
key
,
value
in
message
.
items
()}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
def
test_paligemma_plugin
():
image_seqlen
=
256
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"google/paligemma-3b-pt-224"
)
paligemma_plugin
=
get_mm_plugin
(
name
=
"paligemma"
,
image_token
=
"<image>"
)
check_inputs
=
{
"plugin"
:
paligemma_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
""
)
for
key
,
value
in
message
.
items
()}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_input_ids"
]
=
[
tokenizer_module
[
"tokenizer"
].
convert_tokens_to_ids
(
paligemma_plugin
.
image_token
)
]
*
image_seqlen
+
INPUT_IDS
check_inputs
[
"expected_labels"
]
=
[
-
100
]
*
image_seqlen
+
LABELS
check_inputs
[
"expected_mm_inputs"
]
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
check_inputs
[
"expected_mm_inputs"
][
"token_type_ids"
]
=
[[
0
]
*
image_seqlen
+
[
1
]
*
(
1024
-
image_seqlen
)]
check_inputs
[
"expected_no_mm_inputs"
]
=
{
"token_type_ids"
:
[[
1
]
*
1024
]}
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.50.0"
),
reason
=
"Requires transformers>=4.50.0"
)
def
test_pixtral_plugin
():
image_slice_height
,
image_slice_width
=
2
,
2
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"mistral-community/pixtral-12b"
)
pixtral_plugin
=
get_mm_plugin
(
name
=
"pixtral"
,
image_token
=
"[IMG]"
)
check_inputs
=
{
"plugin"
:
pixtral_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
(
"{}[IMG_BREAK]"
.
format
(
"[IMG]"
*
image_slice_width
)
*
image_slice_height
).
rsplit
(
"[IMG_BREAK]"
,
1
)[
0
]
+
"[IMG_END]"
,
)
for
key
,
value
in
message
.
items
()
}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
check_inputs
[
"expected_mm_inputs"
][
"pixel_values"
]
=
check_inputs
[
"expected_mm_inputs"
][
"pixel_values"
][
0
]
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.52.0"
),
reason
=
"Requires transformers>=4.52.0"
)
def
test_qwen2_omni_plugin
():
image_seqlen
,
audio_seqlen
=
4
,
2
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"Qwen/Qwen2.5-Omni-7B"
)
qwen2_omni_plugin
=
get_mm_plugin
(
name
=
"qwen2_omni"
,
audio_token
=
"<|AUDIO|>"
,
image_token
=
"<|IMAGE|>"
,
video_token
=
"<|VIDEO|>"
)
check_inputs
=
{
"plugin"
:
qwen2_omni_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
(
value
.
replace
(
"<image>"
,
f
"<|vision_bos|>
{
'<|IMAGE|>'
*
image_seqlen
}
<|vision_eos|>"
).
replace
(
"<audio>"
,
f
"<|audio_bos|>
{
'<|AUDIO|>'
*
audio_seqlen
}
<|audio_eos|>"
)
)
for
key
,
value
in
message
.
items
()
}
for
message
in
OMNI_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
_get_omni_inputs
(
tokenizer_module
[
"processor"
])
_check_plugin
(
**
check_inputs
)
def
test_qwen2_vl_plugin
():
image_seqlen
=
4
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"Qwen/Qwen2-VL-7B-Instruct"
)
qwen2_vl_plugin
=
get_mm_plugin
(
name
=
"qwen2_vl"
,
image_token
=
"<|image_pad|>"
)
check_inputs
=
{
"plugin"
:
qwen2_vl_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
"<|vision_start|>{}<|vision_end|>"
.
format
(
"<|image_pad|>"
*
image_seqlen
))
for
key
,
value
in
message
.
items
()
}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
_check_plugin
(
**
check_inputs
)
@
pytest
.
mark
.
skipif
(
not
is_transformers_version_greater_than
(
"4.47.0"
),
reason
=
"Requires transformers>=4.47.0"
)
def
test_video_llava_plugin
():
image_seqlen
=
256
tokenizer_module
=
_load_tokenizer_module
(
model_name_or_path
=
"LanguageBind/Video-LLaVA-7B-hf"
)
video_llava_plugin
=
get_mm_plugin
(
name
=
"video_llava"
,
image_token
=
"<image>"
,
video_token
=
"<video>"
)
check_inputs
=
{
"plugin"
:
video_llava_plugin
,
**
tokenizer_module
}
check_inputs
[
"expected_mm_messages"
]
=
[
{
key
:
value
.
replace
(
"<image>"
,
"<image>"
*
image_seqlen
)
for
key
,
value
in
message
.
items
()}
for
message
in
MM_MESSAGES
]
check_inputs
[
"expected_mm_inputs"
]
=
_get_mm_inputs
(
tokenizer_module
[
"processor"
])
_check_plugin
(
**
check_inputs
)
docker-hub/qwen2.5-vl/llama-factory/tests/data/test_template.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
typing
import
TYPE_CHECKING
import
pytest
from
transformers
import
AutoTokenizer
from
llamafactory.data
import
get_template_and_fix_tokenizer
from
llamafactory.data.template
import
parse_template
from
llamafactory.hparams
import
DataArguments
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
HF_TOKEN
=
os
.
getenv
(
"HF_TOKEN"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA4
=
os
.
getenv
(
"TINY_LLAMA4"
,
"llamafactory/tiny-random-Llama-4"
)
MESSAGES
=
[
{
"role"
:
"user"
,
"content"
:
"How are you"
},
{
"role"
:
"assistant"
,
"content"
:
"I am fine!"
},
{
"role"
:
"user"
,
"content"
:
"你好"
},
{
"role"
:
"assistant"
,
"content"
:
"很高兴认识你!"
},
]
MESSAGES_WITH_THOUGHT
=
[
{
"role"
:
"user"
,
"content"
:
"How are you"
},
{
"role"
:
"assistant"
,
"content"
:
"<think>
\n
Model thought here
\n
</think>
\n\n
I am fine!"
},
{
"role"
:
"user"
,
"content"
:
"你好"
},
{
"role"
:
"assistant"
,
"content"
:
"<think>
\n
模型思考内容
\n
</think>
\n\n
很高兴认识你!"
},
]
def
_check_tokenization
(
tokenizer
:
"PreTrainedTokenizer"
,
batch_input_ids
:
list
[
list
[
int
]],
batch_text
:
list
[
str
]
)
->
None
:
r
"""Check token ids and texts.
encode(text) == token_ids
decode(token_ids) == text
"""
for
input_ids
,
text
in
zip
(
batch_input_ids
,
batch_text
):
assert
tokenizer
.
encode
(
text
,
add_special_tokens
=
False
)
==
input_ids
assert
tokenizer
.
decode
(
input_ids
)
==
text
def
_check_template
(
model_id
:
str
,
template_name
:
str
,
prompt_str
:
str
,
answer_str
:
str
,
use_fast
:
bool
,
messages
:
list
[
dict
[
str
,
str
]]
=
MESSAGES
,
)
->
None
:
r
"""Check template.
Args:
model_id: the model id on hugging face hub.
template_name: the template name.
prompt_str: the string corresponding to the prompt part.
answer_str: the string corresponding to the answer part.
use_fast: whether to use fast tokenizer.
messages: the list of messages.
"""
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_id
,
use_fast
=
use_fast
,
token
=
HF_TOKEN
)
content_str
=
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
False
)
content_ids
=
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
True
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
template_name
))
prompt_ids
,
answer_ids
=
template
.
encode_oneturn
(
tokenizer
,
messages
)
assert
content_str
==
prompt_str
+
answer_str
assert
content_ids
==
prompt_ids
+
answer_ids
_check_tokenization
(
tokenizer
,
(
prompt_ids
,
answer_ids
),
(
prompt_str
,
answer_str
))
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_encode_oneturn
(
use_fast
:
bool
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
"llama3"
))
prompt_ids
,
answer_ids
=
template
.
encode_oneturn
(
tokenizer
,
MESSAGES
)
prompt_str
=
(
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
How are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>
\n\n
I am fine!<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>
\n\n
你好<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
)
answer_str
=
"很高兴认识你!<|eot_id|>"
_check_tokenization
(
tokenizer
,
(
prompt_ids
,
answer_ids
),
(
prompt_str
,
answer_str
))
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_encode_multiturn
(
use_fast
:
bool
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
"llama3"
))
encoded_pairs
=
template
.
encode_multiturn
(
tokenizer
,
MESSAGES
)
prompt_str_1
=
(
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
How are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
)
answer_str_1
=
"I am fine!<|eot_id|>"
prompt_str_2
=
(
"<|start_header_id|>user<|end_header_id|>
\n\n
你好<|eot_id|><|start_header_id|>assistant<|end_header_id|>
\n\n
"
)
answer_str_2
=
"很高兴认识你!<|eot_id|>"
_check_tokenization
(
tokenizer
,
(
encoded_pairs
[
0
][
0
],
encoded_pairs
[
0
][
1
],
encoded_pairs
[
1
][
0
],
encoded_pairs
[
1
][
1
]),
(
prompt_str_1
,
answer_str_1
,
prompt_str_2
,
answer_str_2
),
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"cot_messages"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_thinking"
,
[
True
,
False
,
None
])
def
test_reasoning_encode_oneturn
(
use_fast
:
bool
,
cot_messages
:
bool
,
enable_thinking
:
bool
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"Qwen/Qwen3-8B"
,
use_fast
=
use_fast
)
data_args
=
DataArguments
(
template
=
"qwen3"
,
enable_thinking
=
enable_thinking
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
data_args
)
prompt_ids
,
answer_ids
=
template
.
encode_oneturn
(
tokenizer
,
MESSAGES_WITH_THOUGHT
if
cot_messages
else
MESSAGES
)
prompt_str
=
(
f
"<|im_start|>user
\n
{
MESSAGES
[
0
][
'content'
]
}
<|im_end|>
\n
<|im_start|>assistant
\n
"
f
"
{
MESSAGES
[
1
][
'content'
]
}
<|im_end|>
\n
"
f
"<|im_start|>user
\n
{
MESSAGES
[
2
][
'content'
]
}
<|im_end|>
\n
<|im_start|>assistant
\n
"
)
if
not
cot_messages
or
enable_thinking
is
False
:
answer_str
=
f
"
{
MESSAGES
[
3
][
'content'
]
}
<|im_end|>
\n
"
if
enable_thinking
:
answer_str
=
"<think>
\n\n
</think>
\n\n
"
+
answer_str
else
:
prompt_str
=
prompt_str
+
"<think>
\n\n
</think>
\n\n
"
else
:
answer_str
=
f
"
{
MESSAGES_WITH_THOUGHT
[
3
][
'content'
]
}
<|im_end|>
\n
"
_check_tokenization
(
tokenizer
,
(
prompt_ids
,
answer_ids
),
(
prompt_str
,
answer_str
))
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"cot_messages"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_thinking"
,
[
True
,
False
,
None
])
def
test_reasoning_encode_multiturn
(
use_fast
:
bool
,
cot_messages
:
bool
,
enable_thinking
:
bool
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"Qwen/Qwen3-8B"
,
use_fast
=
use_fast
)
data_args
=
DataArguments
(
template
=
"qwen3"
,
enable_thinking
=
enable_thinking
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
data_args
)
encoded_pairs
=
template
.
encode_multiturn
(
tokenizer
,
MESSAGES_WITH_THOUGHT
if
cot_messages
else
MESSAGES
)
messages
=
MESSAGES
if
not
cot_messages
or
enable_thinking
is
False
else
MESSAGES_WITH_THOUGHT
prompt_str_1
=
f
"<|im_start|>user
\n
{
MESSAGES
[
0
][
'content'
]
}
<|im_end|>
\n
<|im_start|>assistant
\n
"
answer_str_1
=
f
"
{
messages
[
1
][
'content'
]
}
<|im_end|>
\n
"
prompt_str_2
=
f
"<|im_start|>user
\n
{
MESSAGES
[
2
][
'content'
]
}
<|im_end|>
\n
<|im_start|>assistant
\n
"
answer_str_2
=
f
"
{
messages
[
3
][
'content'
]
}
<|im_end|>
\n
"
if
not
cot_messages
or
enable_thinking
is
False
:
if
enable_thinking
:
answer_str_1
=
"<think>
\n\n
</think>
\n\n
"
+
answer_str_1
answer_str_2
=
"<think>
\n\n
</think>
\n\n
"
+
answer_str_2
else
:
prompt_str_1
=
prompt_str_1
+
"<think>
\n\n
</think>
\n\n
"
prompt_str_2
=
prompt_str_2
+
"<think>
\n\n
</think>
\n\n
"
_check_tokenization
(
tokenizer
,
(
encoded_pairs
[
0
][
0
],
encoded_pairs
[
0
][
1
],
encoded_pairs
[
1
][
0
],
encoded_pairs
[
1
][
1
]),
(
prompt_str_1
,
answer_str_1
,
prompt_str_2
,
answer_str_2
),
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_jinja_template
(
use_fast
:
bool
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
ref_tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
use_fast
=
use_fast
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
"llama3"
))
tokenizer
.
chat_template
=
template
.
_get_jinja_template
(
tokenizer
)
# llama3 template no replace
assert
tokenizer
.
chat_template
!=
ref_tokenizer
.
chat_template
assert
tokenizer
.
apply_chat_template
(
MESSAGES
)
==
ref_tokenizer
.
apply_chat_template
(
MESSAGES
)
def
test_ollama_modelfile
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
"llama3"
))
assert
template
.
get_ollama_modelfile
(
tokenizer
)
==
(
"# ollama modelfile auto-generated by llamafactory
\n\n
"
"FROM .
\n\n
"
'TEMPLATE """<|begin_of_text|>'
"{{ if .System }}<|start_header_id|>system<|end_header_id|>
\n\n
{{ .System }}<|eot_id|>{{ end }}"
'{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
\n\n
{{ .Content }}'
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>
\n\n
"
'{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""
\n\n
'
'PARAMETER stop "<|eom_id|>"
\n
'
'PARAMETER stop "<|eot_id|>"
\n
'
"PARAMETER num_ctx 4096
\n
"
)
def
test_get_stop_token_ids
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
DataArguments
(
template
=
"llama3"
))
assert
set
(
template
.
get_stop_token_ids
(
tokenizer
))
==
{
128008
,
128009
}
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_gemma_template
(
use_fast
:
bool
):
prompt_str
=
(
f
"<bos><start_of_turn>user
\n
{
MESSAGES
[
0
][
'content'
]
}
<end_of_turn>
\n
"
f
"<start_of_turn>model
\n
{
MESSAGES
[
1
][
'content'
]
}
<end_of_turn>
\n
"
f
"<start_of_turn>user
\n
{
MESSAGES
[
2
][
'content'
]
}
<end_of_turn>
\n
"
"<start_of_turn>model
\n
"
)
answer_str
=
f
"
{
MESSAGES
[
3
][
'content'
]
}
<end_of_turn>
\n
"
_check_template
(
"google/gemma-3-4b-it"
,
"gemma"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_gemma2_template
(
use_fast
:
bool
):
prompt_str
=
(
f
"<bos><start_of_turn>user
\n
{
MESSAGES
[
0
][
'content'
]
}
<end_of_turn>
\n
"
f
"<start_of_turn>model
\n
{
MESSAGES
[
1
][
'content'
]
}
<end_of_turn>
\n
"
f
"<start_of_turn>user
\n
{
MESSAGES
[
2
][
'content'
]
}
<end_of_turn>
\n
"
"<start_of_turn>model
\n
"
)
answer_str
=
f
"
{
MESSAGES
[
3
][
'content'
]
}
<end_of_turn>
\n
"
_check_template
(
"google/gemma-2-2b-it"
,
"gemma2"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_llama3_template
(
use_fast
:
bool
):
prompt_str
=
(
f
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
{
MESSAGES
[
0
][
'content'
]
}
<|eot_id|>"
f
"<|start_header_id|>assistant<|end_header_id|>
\n\n
{
MESSAGES
[
1
][
'content'
]
}
<|eot_id|>"
f
"<|start_header_id|>user<|end_header_id|>
\n\n
{
MESSAGES
[
2
][
'content'
]
}
<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
)
answer_str
=
f
"
{
MESSAGES
[
3
][
'content'
]
}
<|eot_id|>"
_check_template
(
"meta-llama/Meta-Llama-3-8B-Instruct"
,
"llama3"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
pytest
.
param
(
False
,
marks
=
pytest
.
mark
.
xfail
(
reason
=
"Llama 4 has no slow tokenizer."
))]
)
def
test_llama4_template
(
use_fast
:
bool
):
prompt_str
=
(
f
"<|begin_of_text|><|header_start|>user<|header_end|>
\n\n
{
MESSAGES
[
0
][
'content'
]
}
<|eot|>"
f
"<|header_start|>assistant<|header_end|>
\n\n
{
MESSAGES
[
1
][
'content'
]
}
<|eot|>"
f
"<|header_start|>user<|header_end|>
\n\n
{
MESSAGES
[
2
][
'content'
]
}
<|eot|>"
"<|header_start|>assistant<|header_end|>
\n\n
"
)
answer_str
=
f
"
{
MESSAGES
[
3
][
'content'
]
}
<|eot|>"
_check_template
(
TINY_LLAMA4
,
"llama4"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
pytest
.
param
(
True
,
marks
=
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)),
pytest
.
param
(
False
,
marks
=
pytest
.
mark
.
xfail
(
reason
=
"Phi-4 slow tokenizer is broken."
)),
],
)
def
test_phi4_template
(
use_fast
:
bool
):
prompt_str
=
(
f
"<|im_start|>user<|im_sep|>
{
MESSAGES
[
0
][
'content'
]
}
<|im_end|>"
f
"<|im_start|>assistant<|im_sep|>
{
MESSAGES
[
1
][
'content'
]
}
<|im_end|>"
f
"<|im_start|>user<|im_sep|>
{
MESSAGES
[
2
][
'content'
]
}
<|im_end|>"
"<|im_start|>assistant<|im_sep|>"
)
answer_str
=
f
"
{
MESSAGES
[
3
][
'content'
]
}
<|im_end|>"
_check_template
(
"microsoft/phi-4"
,
"phi4"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
def
test_qwen2_5_template
(
use_fast
:
bool
):
prompt_str
=
(
"<|im_start|>system
\n
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
\n
"
f
"<|im_start|>user
\n
{
MESSAGES
[
0
][
'content'
]
}
<|im_end|>
\n
"
f
"<|im_start|>assistant
\n
{
MESSAGES
[
1
][
'content'
]
}
<|im_end|>
\n
"
f
"<|im_start|>user
\n
{
MESSAGES
[
2
][
'content'
]
}
<|im_end|>
\n
"
"<|im_start|>assistant
\n
"
)
answer_str
=
f
"
{
MESSAGES
[
3
][
'content'
]
}
<|im_end|>
\n
"
_check_template
(
"Qwen/Qwen2.5-7B-Instruct"
,
"qwen"
,
prompt_str
,
answer_str
,
use_fast
)
@
pytest
.
mark
.
parametrize
(
"use_fast"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"cot_messages"
,
[
True
,
False
])
def
test_qwen3_template
(
use_fast
:
bool
,
cot_messages
:
bool
):
prompt_str
=
(
f
"<|im_start|>user
\n
{
MESSAGES
[
0
][
'content'
]
}
<|im_end|>
\n
"
f
"<|im_start|>assistant
\n
{
MESSAGES
[
1
][
'content'
]
}
<|im_end|>
\n
"
f
"<|im_start|>user
\n
{
MESSAGES
[
2
][
'content'
]
}
<|im_end|>
\n
"
"<|im_start|>assistant
\n
"
)
if
not
cot_messages
:
answer_str
=
f
"<think>
\n\n
</think>
\n\n
{
MESSAGES
[
3
][
'content'
]
}
<|im_end|>
\n
"
messages
=
MESSAGES
else
:
answer_str
=
f
"
{
MESSAGES_WITH_THOUGHT
[
3
][
'content'
]
}
<|im_end|>
\n
"
messages
=
MESSAGES_WITH_THOUGHT
_check_template
(
"Qwen/Qwen3-8B"
,
"qwen3"
,
prompt_str
,
answer_str
,
use_fast
,
messages
=
messages
)
def
test_parse_llama3_template
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TINY_LLAMA3
,
token
=
HF_TOKEN
)
template
=
parse_template
(
tokenizer
)
assert
template
.
format_user
.
slots
==
[
"<|start_header_id|>user<|end_header_id|>
\n\n
{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
]
assert
template
.
format_assistant
.
slots
==
[
"{{content}}<|eot_id|>"
]
assert
template
.
format_system
.
slots
==
[
"<|start_header_id|>system<|end_header_id|>
\n\n
{{content}}<|eot_id|>"
]
assert
template
.
format_prefix
.
slots
==
[
"<|begin_of_text|>"
]
assert
template
.
default_system
==
""
@
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)
def
test_parse_qwen_template
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"Qwen/Qwen2.5-7B-Instruct"
,
token
=
HF_TOKEN
)
template
=
parse_template
(
tokenizer
)
assert
template
.
__class__
.
__name__
==
"Template"
assert
template
.
format_user
.
slots
==
[
"<|im_start|>user
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]
assert
template
.
format_assistant
.
slots
==
[
"{{content}}<|im_end|>
\n
"
]
assert
template
.
format_system
.
slots
==
[
"<|im_start|>system
\n
{{content}}<|im_end|>
\n
"
]
assert
template
.
format_prefix
.
slots
==
[]
assert
template
.
default_system
==
"You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
@
pytest
.
mark
.
xfail
(
not
HF_TOKEN
,
reason
=
"Authorization."
)
def
test_parse_qwen3_template
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"Qwen/Qwen3-8B"
,
token
=
HF_TOKEN
)
template
=
parse_template
(
tokenizer
)
assert
template
.
__class__
.
__name__
==
"ReasoningTemplate"
assert
template
.
format_user
.
slots
==
[
"<|im_start|>user
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]
assert
template
.
format_assistant
.
slots
==
[
"{{content}}<|im_end|>
\n
"
]
assert
template
.
format_system
.
slots
==
[
"<|im_start|>system
\n
{{content}}<|im_end|>
\n
"
]
assert
template
.
format_prefix
.
slots
==
[]
assert
template
.
default_system
==
""
docker-hub/qwen2.5-vl/llama-factory/tests/e2e/test_chat.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
llamafactory.chat
import
ChatModel
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
INFER_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA3
,
"finetuning_type"
:
"lora"
,
"template"
:
"llama3"
,
"infer_dtype"
:
"float16"
,
"do_sample"
:
False
,
"max_new_tokens"
:
1
,
}
MESSAGES
=
[
{
"role"
:
"user"
,
"content"
:
"Hi"
},
]
EXPECTED_RESPONSE
=
"_rho"
def
test_chat
():
chat_model
=
ChatModel
(
INFER_ARGS
)
assert
chat_model
.
chat
(
MESSAGES
)[
0
].
response_text
==
EXPECTED_RESPONSE
def
test_stream_chat
():
chat_model
=
ChatModel
(
INFER_ARGS
)
response
=
""
for
token
in
chat_model
.
stream_chat
(
MESSAGES
):
response
+=
token
assert
response
==
EXPECTED_RESPONSE
docker-hub/qwen2.5-vl/llama-factory/tests/e2e/test_sglang.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
pytest
from
llamafactory.chat
import
ChatModel
from
llamafactory.extras.packages
import
is_sglang_available
MODEL_NAME
=
"Qwen/Qwen2.5-0.5B"
INFER_ARGS
=
{
"model_name_or_path"
:
MODEL_NAME
,
"finetuning_type"
:
"lora"
,
"template"
:
"llama3"
,
"infer_dtype"
:
"float16"
,
"infer_backend"
:
"sglang"
,
"do_sample"
:
False
,
"max_new_tokens"
:
1
,
}
MESSAGES
=
[
{
"role"
:
"user"
,
"content"
:
"Hi"
},
]
@
pytest
.
mark
.
skipif
(
not
is_sglang_available
(),
reason
=
"SGLang is not installed"
)
def
test_chat
():
r
"""Test the SGLang engine's basic chat functionality."""
chat_model
=
ChatModel
(
INFER_ARGS
)
response
=
chat_model
.
chat
(
MESSAGES
)[
0
]
# TODO: Change to EXPECTED_RESPONSE
print
(
response
.
response_text
)
@
pytest
.
mark
.
skipif
(
not
is_sglang_available
(),
reason
=
"SGLang is not installed"
)
def
test_stream_chat
():
r
"""Test the SGLang engine's streaming chat functionality."""
chat_model
=
ChatModel
(
INFER_ARGS
)
response
=
""
for
token
in
chat_model
.
stream_chat
(
MESSAGES
):
response
+=
token
print
(
"Complete response:"
,
response
)
assert
response
,
"Should receive a non-empty response"
# Run tests if executed directly
if
__name__
==
"__main__"
:
if
not
is_sglang_available
():
print
(
"SGLang is not available. Please install it."
)
sys
.
exit
(
1
)
test_chat
()
test_stream_chat
()
docker-hub/qwen2.5-vl/llama-factory/tests/e2e/test_train.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
from
llamafactory.train.tuner
import
export_model
,
run_exp
DEMO_DATA
=
os
.
getenv
(
"DEMO_DATA"
,
"llamafactory/demo_data"
)
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA_ADAPTER
=
os
.
getenv
(
"TINY_LLAMA_ADAPTER"
,
"llamafactory/tiny-random-Llama-3-lora"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA3
,
"do_train"
:
True
,
"finetuning_type"
:
"lora"
,
"dataset_dir"
:
"REMOTE:"
+
DEMO_DATA
,
"template"
:
"llama3"
,
"cutoff_len"
:
1
,
"overwrite_output_dir"
:
True
,
"per_device_train_batch_size"
:
1
,
"max_steps"
:
1
,
"report_to"
:
"none"
,
}
INFER_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA3
,
"adapter_name_or_path"
:
TINY_LLAMA_ADAPTER
,
"finetuning_type"
:
"lora"
,
"template"
:
"llama3"
,
"infer_dtype"
:
"float16"
,
}
OS_NAME
=
os
.
getenv
(
"OS_NAME"
,
""
)
@
pytest
.
mark
.
parametrize
(
"stage,dataset"
,
[
(
"pt"
,
"c4_demo"
),
(
"sft"
,
"alpaca_en_demo"
),
(
"dpo"
,
"dpo_en_demo"
),
(
"kto"
,
"kto_en_demo"
),
pytest
.
param
(
"rm"
,
"dpo_en_demo"
,
marks
=
pytest
.
mark
.
xfail
(
OS_NAME
.
startswith
(
"windows"
),
reason
=
"OS error."
)),
],
)
def
test_run_exp
(
stage
:
str
,
dataset
:
str
):
output_dir
=
os
.
path
.
join
(
"output"
,
f
"train_
{
stage
}
"
)
run_exp
({
"stage"
:
stage
,
"dataset"
:
dataset
,
"output_dir"
:
output_dir
,
**
TRAIN_ARGS
})
assert
os
.
path
.
exists
(
output_dir
)
def
test_export
():
export_dir
=
os
.
path
.
join
(
"output"
,
"llama3_export"
)
export_model
({
"export_dir"
:
export_dir
,
**
INFER_ARGS
})
assert
os
.
path
.
exists
(
export_dir
)
docker-hub/qwen2.5-vl/llama-factory/tests/eval/test_eval_template.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
llamafactory.eval.template
import
get_eval_template
def
test_eval_template_en
():
support_set
=
[
{
"question"
:
"Fewshot question"
,
"A"
:
"Fewshot1"
,
"B"
:
"Fewshot2"
,
"C"
:
"Fewshot3"
,
"D"
:
"Fewshot4"
,
"answer"
:
"B"
,
}
]
example
=
{
"question"
:
"Target question"
,
"A"
:
"Target1"
,
"B"
:
"Target2"
,
"C"
:
"Target3"
,
"D"
:
"Target4"
,
"answer"
:
"C"
,
}
template
=
get_eval_template
(
name
=
"en"
)
messages
=
template
.
format_example
(
example
,
support_set
=
support_set
,
subject_name
=
"SubName"
)
assert
messages
==
[
{
"role"
:
"user"
,
"content"
:
(
"The following are multiple choice questions (with answers) about SubName.
\n\n
"
"Fewshot question
\n
A. Fewshot1
\n
B. Fewshot2
\n
C. Fewshot3
\n
D. Fewshot4
\n
Answer:"
),
},
{
"role"
:
"assistant"
,
"content"
:
"B"
},
{
"role"
:
"user"
,
"content"
:
"Target question
\n
A. Target1
\n
B. Target2
\n
C. Target3
\n
D. Target4
\n
Answer:"
,
},
{
"role"
:
"assistant"
,
"content"
:
"C"
},
]
def
test_eval_template_zh
():
support_set
=
[
{
"question"
:
"示例问题"
,
"A"
:
"示例答案1"
,
"B"
:
"示例答案2"
,
"C"
:
"示例答案3"
,
"D"
:
"示例答案4"
,
"answer"
:
"B"
,
}
]
example
=
{
"question"
:
"目标问题"
,
"A"
:
"目标答案1"
,
"B"
:
"目标答案2"
,
"C"
:
"目标答案3"
,
"D"
:
"目标答案4"
,
"answer"
:
"C"
,
}
template
=
get_eval_template
(
name
=
"zh"
)
messages
=
template
.
format_example
(
example
,
support_set
=
support_set
,
subject_name
=
"主题"
)
assert
messages
==
[
{
"role"
:
"user"
,
"content"
:
(
"以下是中国关于主题考试的单项选择题,请选出其中的正确答案。
\n\n
"
"示例问题
\n
A. 示例答案1
\n
B. 示例答案2
\n
C. 示例答案3
\n
D. 示例答案4
\n
答案:"
),
},
{
"role"
:
"assistant"
,
"content"
:
"B"
},
{
"role"
:
"user"
,
"content"
:
"目标问题
\n
A. 目标答案1
\n
B. 目标答案2
\n
C. 目标答案3
\n
D. 目标答案4
\n
答案:"
,
},
{
"role"
:
"assistant"
,
"content"
:
"C"
},
]
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_add_tokens.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
from
llamafactory.hparams
import
ModelArguments
from
llamafactory.model
import
load_tokenizer
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
UNUSED_TOKEN
=
"<|UNUSED_TOKEN|>"
@
pytest
.
mark
.
parametrize
(
"special_tokens"
,
[
False
,
True
])
def
test_add_tokens
(
special_tokens
:
bool
):
if
special_tokens
:
model_args
=
ModelArguments
(
model_name_or_path
=
TINY_LLAMA3
,
add_special_tokens
=
UNUSED_TOKEN
)
else
:
model_args
=
ModelArguments
(
model_name_or_path
=
TINY_LLAMA3
,
add_tokens
=
UNUSED_TOKEN
)
tokenizer
=
load_tokenizer
(
model_args
)[
"tokenizer"
]
encoded_ids
=
tokenizer
.
encode
(
UNUSED_TOKEN
,
add_special_tokens
=
False
)
assert
len
(
encoded_ids
)
==
1
decoded_str
=
tokenizer
.
decode
(
encoded_ids
,
skip_special_tokens
=
True
)
if
special_tokens
:
assert
decoded_str
==
""
else
:
assert
decoded_str
==
UNUSED_TOKEN
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_attention.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
from
transformers.utils
import
is_flash_attn_2_available
,
is_torch_sdpa_available
from
llamafactory.extras.packages
import
is_transformers_version_greater_than
from
llamafactory.train.test_utils
import
load_infer_model
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
INFER_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA3
,
"template"
:
"llama3"
,
}
@
pytest
.
mark
.
xfail
(
is_transformers_version_greater_than
(
"4.48"
),
reason
=
"Attention refactor."
)
def
test_attention
():
attention_available
=
[
"disabled"
]
if
is_torch_sdpa_available
():
attention_available
.
append
(
"sdpa"
)
if
is_flash_attn_2_available
():
attention_available
.
append
(
"fa2"
)
llama_attention_classes
=
{
"disabled"
:
"LlamaAttention"
,
"sdpa"
:
"LlamaSdpaAttention"
,
"fa2"
:
"LlamaFlashAttention2"
,
}
for
requested_attention
in
attention_available
:
model
=
load_infer_model
(
flash_attn
=
requested_attention
,
**
INFER_ARGS
)
for
module
in
model
.
modules
():
if
"Attention"
in
module
.
__class__
.
__name__
:
assert
module
.
__class__
.
__name__
==
llama_attention_classes
[
requested_attention
]
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_checkpointing.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
import
torch
from
llamafactory.extras.misc
import
get_current_device
from
llamafactory.train.test_utils
import
load_train_model
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TRAIN_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA3
,
"stage"
:
"sft"
,
"do_train"
:
True
,
"finetuning_type"
:
"lora"
,
"lora_target"
:
"all"
,
"dataset"
:
"llamafactory/tiny-supervised-dataset"
,
"dataset_dir"
:
"ONLINE"
,
"template"
:
"llama3"
,
"cutoff_len"
:
1024
,
"output_dir"
:
"dummy_dir"
,
"overwrite_output_dir"
:
True
,
"fp16"
:
True
,
}
@
pytest
.
mark
.
parametrize
(
"disable_gradient_checkpointing"
,
[
False
,
True
])
def
test_vanilla_checkpointing
(
disable_gradient_checkpointing
:
bool
):
model
=
load_train_model
(
disable_gradient_checkpointing
=
disable_gradient_checkpointing
,
**
TRAIN_ARGS
)
for
module
in
filter
(
lambda
m
:
hasattr
(
m
,
"gradient_checkpointing"
),
model
.
modules
()):
assert
getattr
(
module
,
"gradient_checkpointing"
)
!=
disable_gradient_checkpointing
def
test_unsloth_gradient_checkpointing
():
model
=
load_train_model
(
use_unsloth_gc
=
True
,
**
TRAIN_ARGS
)
for
module
in
filter
(
lambda
m
:
hasattr
(
m
,
"gradient_checkpointing"
),
model
.
modules
()):
assert
module
.
_gradient_checkpointing_func
.
__self__
.
__name__
==
"UnslothGradientCheckpointing"
def
test_upcast_layernorm
():
model
=
load_train_model
(
upcast_layernorm
=
True
,
**
TRAIN_ARGS
)
for
name
,
param
in
model
.
named_parameters
():
if
param
.
ndim
==
1
and
"norm"
in
name
:
assert
param
.
dtype
==
torch
.
float32
def
test_upcast_lmhead_output
():
model
=
load_train_model
(
upcast_lmhead_output
=
True
,
**
TRAIN_ARGS
)
inputs
=
torch
.
randn
((
1
,
16
),
dtype
=
torch
.
float16
,
device
=
get_current_device
())
outputs
:
torch
.
Tensor
=
model
.
get_output_embeddings
()(
inputs
)
assert
outputs
.
dtype
==
torch
.
float32
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_misc.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
import
torch
from
transformers
import
AutoConfig
,
AutoModelForCausalLM
from
llamafactory.model.model_utils.misc
import
find_expanded_modules
HF_TOKEN
=
os
.
getenv
(
"HF_TOKEN"
)
@
pytest
.
mark
.
skipif
(
not
HF_TOKEN
,
reason
=
"Gated model."
)
def
test_expanded_modules
():
config
=
AutoConfig
.
from_pretrained
(
"meta-llama/Meta-Llama-3-8B-Instruct"
)
with
torch
.
device
(
"meta"
):
model
=
AutoModelForCausalLM
.
from_config
(
config
)
expanded_modules
=
find_expanded_modules
(
model
,
[
"q_proj"
,
"v_proj"
],
num_layer_trainable
=
4
)
assert
expanded_modules
==
[
"model.layers.7.self_attn.q_proj"
,
"model.layers.7.self_attn.v_proj"
,
"model.layers.15.self_attn.q_proj"
,
"model.layers.15.self_attn.v_proj"
,
"model.layers.23.self_attn.q_proj"
,
"model.layers.23.self_attn.v_proj"
,
"model.layers.31.self_attn.q_proj"
,
"model.layers.31.self_attn.v_proj"
,
]
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_packing.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
pytest
import
torch
from
llamafactory.model.model_utils.packing
import
get_seqlens_in_batch
,
get_unpad_data
@
pytest
.
mark
.
parametrize
(
"attention_mask,golden_seq_lens"
,
[
(
[
[
1
,
1
,
2
,
2
,
2
,
0
],
[
1
,
2
,
2
,
3
,
3
,
3
],
],
[
2
,
3
,
1
,
2
,
3
],
),
(
[[
1
]],
[
1
],
),
],
)
def
test_get_seqlens_in_batch
(
attention_mask
,
golden_seq_lens
):
attention_mask_with_indices
=
torch
.
tensor
(
attention_mask
)
seqlens_in_batch
=
get_seqlens_in_batch
(
attention_mask_with_indices
)
assert
torch
.
all
(
seqlens_in_batch
==
torch
.
tensor
(
golden_seq_lens
))
@
pytest
.
mark
.
parametrize
(
"attention_mask,golden_indices,golden_cu_seqlens,golden_max_seqlen"
,
[
(
[
[
1
,
1
,
2
,
2
,
2
,
0
],
[
1
,
2
,
2
,
3
,
3
,
3
],
],
[
0
,
1
,
2
,
3
,
4
,
6
,
7
,
8
,
9
,
10
,
11
],
[
0
,
2
,
5
,
6
,
8
,
11
],
3
,
),
(
[[
1
]],
[
0
],
[
0
,
1
],
1
,
),
],
)
def
test_get_unpad_data
(
attention_mask
,
golden_indices
,
golden_cu_seqlens
,
golden_max_seqlen
):
attention_mask_with_indices
=
torch
.
tensor
(
attention_mask
)
indices
,
cu_seqlens
,
max_seqlen_in_batch
=
get_unpad_data
(
attention_mask_with_indices
)
assert
torch
.
all
(
indices
==
torch
.
tensor
(
golden_indices
))
assert
torch
.
all
(
cu_seqlens
==
torch
.
tensor
(
golden_cu_seqlens
,
dtype
=
torch
.
int32
))
assert
max_seqlen_in_batch
==
golden_max_seqlen
docker-hub/qwen2.5-vl/llama-factory/tests/model/model_utils/test_visual.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
import
torch
from
transformers
import
AutoConfig
,
AutoModelForVision2Seq
from
llamafactory.extras.packages
import
is_transformers_version_greater_than
from
llamafactory.hparams
import
FinetuningArguments
,
ModelArguments
from
llamafactory.model.adapter
import
init_adapter
@
pytest
.
mark
.
parametrize
(
"freeze_vision_tower"
,
(
False
,
True
))
@
pytest
.
mark
.
parametrize
(
"freeze_multi_modal_projector"
,
(
False
,
True
))
@
pytest
.
mark
.
parametrize
(
"freeze_language_model"
,
(
False
,
True
))
def
test_visual_full
(
freeze_vision_tower
:
bool
,
freeze_multi_modal_projector
:
bool
,
freeze_language_model
:
bool
):
model_args
=
ModelArguments
(
model_name_or_path
=
"Qwen/Qwen2-VL-2B-Instruct"
)
finetuning_args
=
FinetuningArguments
(
finetuning_type
=
"full"
,
freeze_vision_tower
=
freeze_vision_tower
,
freeze_multi_modal_projector
=
freeze_multi_modal_projector
,
freeze_language_model
=
freeze_language_model
,
)
config
=
AutoConfig
.
from_pretrained
(
model_args
.
model_name_or_path
)
with
torch
.
device
(
"meta"
):
model
=
AutoModelForVision2Seq
.
from_config
(
config
)
model
=
init_adapter
(
config
,
model
,
model_args
,
finetuning_args
,
is_trainable
=
True
)
for
name
,
param
in
model
.
named_parameters
():
if
any
(
key
in
name
for
key
in
[
"visual.patch_embed"
,
"visual.blocks"
]):
assert
param
.
requires_grad
!=
freeze_vision_tower
elif
"visual.merger"
in
name
:
assert
param
.
requires_grad
!=
freeze_multi_modal_projector
else
:
assert
param
.
requires_grad
!=
freeze_language_model
@
pytest
.
mark
.
parametrize
(
"freeze_vision_tower,freeze_language_model"
,
((
False
,
False
),
(
False
,
True
),
(
True
,
False
)))
def
test_visual_lora
(
freeze_vision_tower
:
bool
,
freeze_language_model
:
bool
):
model_args
=
ModelArguments
(
model_name_or_path
=
"Qwen/Qwen2-VL-2B-Instruct"
)
finetuning_args
=
FinetuningArguments
(
finetuning_type
=
"lora"
,
freeze_vision_tower
=
freeze_vision_tower
,
freeze_language_model
=
freeze_language_model
)
config
=
AutoConfig
.
from_pretrained
(
model_args
.
model_name_or_path
)
with
torch
.
device
(
"meta"
):
model
=
AutoModelForVision2Seq
.
from_config
(
config
)
model
=
init_adapter
(
config
,
model
,
model_args
,
finetuning_args
,
is_trainable
=
True
)
trainable_params
,
frozen_params
=
set
(),
set
()
for
name
,
param
in
model
.
named_parameters
():
if
param
.
requires_grad
:
trainable_params
.
add
(
name
)
else
:
frozen_params
.
add
(
name
)
if
is_transformers_version_greater_than
(
"4.52.0"
):
visual_param_name
=
"base_model.model.model.visual.blocks.0.attn.qkv.lora_A.default.weight"
language_param_name
=
"base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.default.weight"
merger_param_name
=
"base_model.model.model.visual.merger.lora_A.default.weight"
else
:
visual_param_name
=
"base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight"
language_param_name
=
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
merger_param_name
=
"base_model.model.visual.merger.lora_A.default.weight"
assert
(
visual_param_name
in
trainable_params
)
!=
freeze_vision_tower
assert
(
language_param_name
in
trainable_params
)
!=
freeze_language_model
assert
(
merger_param_name
in
trainable_params
)
is
False
def
test_visual_model_save_load
():
# check VLM's state dict: https://github.com/huggingface/transformers/pull/38385
model_args
=
ModelArguments
(
model_name_or_path
=
"Qwen/Qwen2-VL-2B-Instruct"
)
finetuning_args
=
FinetuningArguments
(
finetuning_type
=
"full"
)
config
=
AutoConfig
.
from_pretrained
(
model_args
.
model_name_or_path
)
with
torch
.
device
(
"meta"
):
model
=
AutoModelForVision2Seq
.
from_config
(
config
)
model
=
init_adapter
(
config
,
model
,
model_args
,
finetuning_args
,
is_trainable
=
False
)
loaded_model_weight
=
dict
(
model
.
named_parameters
())
model
.
save_pretrained
(
os
.
path
.
join
(
"output"
,
"qwen2_vl"
),
max_shard_size
=
"10GB"
,
safe_serialization
=
False
)
saved_model_weight
=
torch
.
load
(
os
.
path
.
join
(
"output"
,
"qwen2_vl"
,
"pytorch_model.bin"
),
weights_only
=
False
)
if
is_transformers_version_greater_than
(
"4.52.0"
):
assert
"model.language_model.layers.0.self_attn.q_proj.weight"
in
loaded_model_weight
else
:
assert
"model.layers.0.self_attn.q_proj.weight"
in
loaded_model_weight
assert
"model.layers.0.self_attn.q_proj.weight"
in
saved_model_weight
docker-hub/qwen2.5-vl/llama-factory/tests/model/test_base.py
0 → 100644
View file @
5ed76316
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
from
llamafactory.train.test_utils
import
compare_model
,
load_infer_model
,
load_reference_model
,
patch_valuehead_model
TINY_LLAMA3
=
os
.
getenv
(
"TINY_LLAMA3"
,
"llamafactory/tiny-random-Llama-3"
)
TINY_LLAMA_VALUEHEAD
=
os
.
getenv
(
"TINY_LLAMA_VALUEHEAD"
,
"llamafactory/tiny-random-Llama-3-valuehead"
)
INFER_ARGS
=
{
"model_name_or_path"
:
TINY_LLAMA3
,
"template"
:
"llama3"
,
"infer_dtype"
:
"float16"
,
}
@
pytest
.
fixture
def
fix_valuehead_cpu_loading
():
patch_valuehead_model
()
def
test_base
():
model
=
load_infer_model
(
**
INFER_ARGS
)
ref_model
=
load_reference_model
(
TINY_LLAMA3
)
compare_model
(
model
,
ref_model
)
@
pytest
.
mark
.
usefixtures
(
"fix_valuehead_cpu_loading"
)
def
test_valuehead
():
model
=
load_infer_model
(
add_valuehead
=
True
,
**
INFER_ARGS
)
ref_model
=
load_reference_model
(
TINY_LLAMA_VALUEHEAD
,
add_valuehead
=
True
)
compare_model
(
model
,
ref_model
)
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment