OpenDAS / LLaMA-Factory · Commits

Commit 8293100a
Authored Jan 16, 2025 by luopl
Parent: 2778a3d0

    update to 0.9.2.dev0

Changes: 124
Showing 4 changed files with 224 additions and 46 deletions (+224 −46)
  tests/data/test_template.py                     +57  −37
  tests/model/model_utils/test_checkpointing.py    +5   −9
  tests/model/model_utils/test_visual.py          +77   −0
  tests/train/test_sft_trainer.py                 +85   −0
tests/data/test_template.py

@@ -13,7 +13,7 @@
 # limitations under the License.

 import os
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, Sequence

 import pytest
 from transformers import AutoTokenizer

@@ -42,39 +42,36 @@ MESSAGES = [
 def _check_tokenization(
     tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
 ) -> None:
+    r"""
+    Checks token ids and texts.
+
+    encode(text) == token_ids
+    decode(token_ids) == text
+    """
     for input_ids, text in zip(batch_input_ids, batch_text):
-        assert input_ids == tokenizer.encode(text, add_special_tokens=False)
+        assert tokenizer.encode(text, add_special_tokens=False) == input_ids
         assert tokenizer.decode(input_ids) == text


-def _check_single_template(
-    model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
-) -> List[str]:
-    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
-    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
-    content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
-    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
-    assert content_str == prompt_str + answer_str + extra_str
-    assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
-    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
-    return content_ids
-
-
-def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = "") -> None:
-    """
-    Checks template for both the slow tokenizer and the fast tokenizer.
+def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
+    r"""
+    Checks template.

     Args:
         model_id: the model id on hugging face hub.
         template_name: the template name.
         prompt_str: the string corresponding to the prompt part.
         answer_str: the string corresponding to the answer part.
-        extra_str: the extra string in the jinja template of the original tokenizer.
+        use_fast: whether to use fast tokenizer.
     """
-    slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=False)
-    fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=True)
-    assert slow_ids == fast_ids
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
+    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
+    content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    assert content_str == prompt_str + answer_str
+    assert content_ids == prompt_ids + answer_ids
+    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))


 @pytest.mark.parametrize("use_fast", [True, False])

@@ -103,8 +100,7 @@ def test_encode_multiturn(use_fast: bool):
     )
     answer_str_1 = "I am fine!<|eot_id|>"
     prompt_str_2 = (
-        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
-        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
     )
     answer_str_2 = "很高兴认识你!<|eot_id|>"
     _check_tokenization(

@@ -124,20 +120,28 @@ def test_jinja_template(use_fast: bool):
     assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)


+def test_get_stop_token_ids():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
+    assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}
+
+
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-def test_gemma_template():
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_gemma_template(use_fast: bool):
     prompt_str = (
         "<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
         "<start_of_turn>model\nI am fine!<end_of_turn>\n"
         "<start_of_turn>user\n你好<end_of_turn>\n"
         "<start_of_turn>model\n"
     )
-    answer_str = "很高兴认识你!"
-    _check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
+    answer_str = "很高兴认识你!<end_of_turn>\n"
+    _check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, use_fast)


 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-def test_llama3_template():
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_llama3_template(use_fast: bool):
     prompt_str = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"

@@ -145,10 +149,25 @@ def test_llama3_template():
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
     answer_str = "很高兴认识你!<|eot_id|>"
-    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
+    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
+
+
+@pytest.mark.parametrize(
+    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
+)
+def test_phi4_template(use_fast: bool):
+    prompt_str = (
+        "<|im_start|>user<|im_sep|>How are you<|im_end|>"
+        "<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
+        "<|im_start|>user<|im_sep|>你好<|im_end|>"
+        "<|im_start|>assistant<|im_sep|>"
+    )
+    answer_str = "很高兴认识你!<|im_end|>"
+    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)


-def test_qwen_template():
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_qwen_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         "<|im_start|>user\nHow are you<|im_end|>\n"

@@ -156,17 +175,18 @@ def test_qwen_template():
         "<|im_start|>user\n你好<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>"
-    _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
+    answer_str = "很高兴认识你!<|im_end|>\n"
+    _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)


-@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
-def test_yi_template():
+@pytest.mark.parametrize("use_fast", [True, False])
+@pytest.mark.xfail(reason="Yi tokenizer is broken.")
+def test_yi_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>user\nHow are you<|im_end|>\n"
         "<|im_start|>assistant\nI am fine!<|im_end|>\n"
         "<|im_start|>user\n你好<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>"
-    _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
+    answer_str = "很高兴认识你!<|im_end|>\n"
+    _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)
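Note: the refactor replaces the slow/fast comparison helper with a single use_fast parameter and drops the extra_str plumbing; the expected answer_str now carries the template's trailing tokens itself. As a minimal sketch of the round-trip property that _check_tokenization asserts (the tiny Hub model below is the one these tests already use; assumes Hugging Face Hub access):

    from transformers import AutoTokenizer

    # Tiny test model used elsewhere in these tests; network access to the Hub is assumed.
    tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-Llama-3")
    text = "How are you"
    input_ids = tokenizer.encode(text, add_special_tokens=False)
    assert tokenizer.decode(input_ids) == text  # decode(encode(text)) round-trips for plain text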
tests/model/model_utils/test_checkpointing.py

@@ -14,6 +14,7 @@
 import os

+import pytest
 import torch

 from llamafactory.extras.misc import get_current_device

@@ -39,16 +40,11 @@ TRAIN_ARGS = {
 }


-def test_checkpointing_enable():
-    model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS)
+@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
+def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
+    model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS)
     for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
-        assert getattr(module, "gradient_checkpointing") is True
-
-
-def test_checkpointing_disable():
-    model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS)
-    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
-        assert getattr(module, "gradient_checkpointing") is False
+        assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing


 def test_unsloth_gradient_checkpointing():
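Note: the two near-identical tests collapse into one parametrized test, and the assertion flips from a hard-coded True/False to "!= disable_gradient_checkpointing". As a rough standalone illustration of the property being checked (plain Transformers API rather than the repo's load_train_model helper; the tiny Hub model is an assumption):

    from transformers import AutoModelForCausalLM

    # Hypothetical standalone check; the real test goes through load_train_model(**TRAIN_ARGS).
    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3")
    model.gradient_checkpointing_enable()
    assert all(m.gradient_checkpointing for m in model.modules() if hasattr(m, "gradient_checkpointing"))
    model.gradient_checkpointing_disable()
    assert not any(m.gradient_checkpointing for m in model.modules() if hasattr(m, "gradient_checkpointing"))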
tests/model/model_utils/test_visual.py (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from transformers import AutoConfig, AutoModelForVision2Seq

from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model.adapter import init_adapter


@pytest.mark.parametrize(
    "freeze_vision_tower,freeze_multi_modal_projector,train_mm_proj_only",
    [
        (False, False, False),
        (False, True, False),
        (True, False, False),
        (True, True, False),
        (True, False, True),
    ],
)
def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, train_mm_proj_only: bool):
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
    finetuning_args = FinetuningArguments(
        finetuning_type="full",
        freeze_vision_tower=freeze_vision_tower,
        freeze_multi_modal_projector=freeze_multi_modal_projector,
        train_mm_proj_only=train_mm_proj_only,
    )
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)

    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
    for name, param in model.named_parameters():
        if any(key in name for key in ["visual.patch_embed", "visual.blocks"]):
            assert param.requires_grad != freeze_vision_tower
        elif "visual.merger" in name:
            assert param.requires_grad != freeze_multi_modal_projector
        else:
            assert param.requires_grad != train_mm_proj_only


@pytest.mark.parametrize("freeze_vision_tower", [False, True])
def test_visual_lora(freeze_vision_tower: bool):
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
    finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower)
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)

    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
    trainable_params, frozen_params = set(), set()
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.add(name)
        else:
            frozen_params.add(name)

    if freeze_vision_tower:
        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" not in trainable_params
    else:
        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" in trainable_params

    assert "merger" not in trainable_params
    assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight" in trainable_params
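Note: both new tests build the Qwen2-VL model on the meta device, so no real weights are allocated before init_adapter decides which parameters are trainable. A minimal sketch of that instantiation trick (PyTorch ≥ 2.0; a causal-LM config is used here purely for illustration):

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("llamafactory/tiny-random-Llama-3")
    with torch.device("meta"):  # parameters are created as shape-only placeholders
        model = AutoModelForCausalLM.from_config(config)

    assert next(model.parameters()).is_meta  # no memory was allocated for the weights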
tests/train/test_sft_trainer.py (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List

import pytest
from transformers import DataCollatorWithPadding

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer


DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "overwrite_cache": False,
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "max_steps": 1,
}


@dataclass
class DataCollatorWithVerbose(DataCollatorWithPadding):
    verbose_list: List[Dict[str, Any]] = field(default_factory=list)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        self.verbose_list.extend(features)
        batch = super().__call__(features)
        return {k: v[:, :1] for k, v in batch.items()}  # truncate input length


@pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool):
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        {
            "output_dir": os.path.join("output", f"shuffle{str(disable_shuffling).lower()}"),
            "disable_shuffling": disable_shuffling,
            **TRAIN_ARGS,
        }
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    data_collator = DataCollatorWithVerbose(tokenizer=tokenizer)
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        **dataset_module,
        **tokenizer_module,
    )
    trainer.train()
    if disable_shuffling:
        assert data_collator.verbose_list[0]["input_ids"] == dataset_module["train_dataset"][0]["input_ids"]
    else:
        assert data_collator.verbose_list[0]["input_ids"] != dataset_module["train_dataset"][0]["input_ids"]
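Note: the shuffling test works by wrapping the collator, so every batch the trainer requests is recorded before padding, and the test can compare the first example the trainer actually saw against row 0 of the prepared dataset. A stripped-down sketch of that pattern (the RecordingCollator name is hypothetical; the version in this diff additionally truncates the batch to keep the training step cheap):

    from dataclasses import dataclass, field
    from typing import Any, Dict, List

    from transformers import DataCollatorWithPadding


    @dataclass
    class RecordingCollator(DataCollatorWithPadding):
        seen: List[Dict[str, Any]] = field(default_factory=list)

        def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
            self.seen.extend(features)         # keep the raw features for later inspection
            return super().__call__(features)  # then pad as usual

After trainer.train(), collator.seen[0]["input_ids"] matches the dataset's first row only when shuffling is disabled, which is exactly the assertion made above.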