OpenDAS / LLaMA-Factory

Commit 27a7ad86
Authored Oct 14, 2024 by luopl

    update to v0.9.1

Parent commit: 731cf9b8

Changes: 120 files in this commit. This page shows 20 changed files with 411 additions and 75 deletions (+411 / -75).
src/llamafactory/train/sft/workflow.py            (+6, -4)
src/llamafactory/train/test_utils.py              (+5, -4)
src/llamafactory/train/trainer_utils.py           (+3, -0)
src/llamafactory/train/tuner.py                   (+4, -4)
src/llamafactory/webui/chatter.py                 (+4, -6)
src/llamafactory/webui/common.py                  (+4, -14)
src/llamafactory/webui/components/chatbot.py      (+9, -4)
src/llamafactory/webui/components/export.py       (+0, -3)
src/llamafactory/webui/components/infer.py        (+5, -4)
src/llamafactory/webui/components/top.py          (+6, -8)
src/llamafactory/webui/engine.py                  (+1, -1)
src/llamafactory/webui/locales.py                 (+15, -15)
src/llamafactory/webui/manager.py                 (+0, -1)
src/llamafactory/webui/runner.py                  (+1, -2)
tests/data/test_mm_plugin.py                      (+208, -0)
tests/data/test_template.py                       (+8, -5)
tests/e2e/test_chat.py                            (+49, -0)
tests/e2e/test_train.py                           (+71, -0)
tests/model/model_utils/test_checkpointing.py     (+6, -0)
tests/model/test_pissa.py                         (+6, -0)
src/llamafactory/train/sft/workflow.py
@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional

-from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
+from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
@@ -43,25 +43,27 @@ def run_sft(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     if getattr(model, "is_quantized", False) and not training_args.do_train:
         setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction

     data_collator = SFTDataCollatorWith4DAttentionMask(
-        tokenizer=tokenizer,
+        template=template,
         pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         block_diag_attn=model_args.block_diag_attn,
         attn_implementation=getattr(model.config, "_attn_implementation", None),
         compute_dtype=model_args.compute_dtype,
+        **tokenizer_module,
     )

     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
+    training_args.remove_unused_columns = False  # important for multimodal dataset

     # Metric utils
     metric_module = {}
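The practical effect of this hunk is that the chat template is now resolved once, up front, and threaded through both the dataset builder and the data collator. A minimal sketch of the new calling convention, assuming the argument objects are already parsed and using only names visible in the diff above (an illustration, not the full run_sft body):

    # Sketch of the v0.9.1 calling convention, assuming model_args/data_args/training_args exist.
    tokenizer_module = load_tokenizer(model_args)          # {"tokenizer": ..., "processor": ...}
    tokenizer = tokenizer_module["tokenizer"]

    # The template is resolved first and passed explicitly from here on.
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)

    # The collator now receives the template plus the whole tokenizer module (tokenizer and
    # processor) instead of a bare tokenizer, so it also has access to the image processor.
    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        pad_to_multiple_of=8 if training_args.do_train else None,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        **tokenizer_module,
    )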
src/llamafactory/train/test_utils.py
@@ -19,7 +19,7 @@ from peft import PeftModel
 from transformers import AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead

-from ..data import get_dataset
+from ..data import get_dataset, get_template_and_fix_tokenizer
 from ..extras.misc import get_current_device
 from ..hparams import get_infer_args, get_train_args
 from ..model import load_model, load_tokenizer
@@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
     assert set(state_dict_a.keys()) == set(state_dict_b.keys())
     for name in state_dict_a.keys():
         if any(key in name for key in diff_keys):
-            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
+            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False
         else:
-            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
+            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True


 def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
@@ -105,7 +105,8 @@ def load_reference_model(
 def load_train_dataset(**kwargs) -> "Dataset":
     model_args, data_args, training_args, _, _ = get_train_args(kwargs)
     tokenizer_module = load_tokenizer(model_args)
-    dataset_module = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, kwargs["stage"], **tokenizer_module)
     return dataset_module["train_dataset"]
src/llamafactory/train/trainer_utils.py
@@ -26,6 +26,7 @@ from transformers.modeling_utils import is_fsdp_enabled
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
+from typing_extensions import override

 from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
@@ -60,9 +61,11 @@ class DummyOptimizer(torch.optim.Optimizer):
         self.optimizer_dict = optimizer_dict
         super().__init__([dummy_tensor], {"lr": lr})

+    @override
     def zero_grad(self, set_to_none: bool = True) -> None:
         pass

+    @override
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
         pass
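For context, typing_extensions.override is the PEP 698 decorator that marks a method as intentionally overriding a base-class method, letting static type checkers flag the override if the parent method is renamed or its signature drifts; at runtime it only sets a marker attribute on the function. A small self-contained illustration (the class names here are hypothetical, not from the repository):

    from typing import Optional
    from typing_extensions import override


    class BaseOptimizer:
        def step(self, closure: Optional[object] = None) -> None: ...


    class NoOpOptimizer(BaseOptimizer):
        @override  # a type checker errors here if BaseOptimizer.step is removed or renamed
        def step(self, closure: Optional[object] = None) -> None:
            pass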
src/llamafactory/train/tuner.py
@@ -72,7 +72,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
     processor = tokenizer_module["processor"]
-    get_template_and_fix_tokenizer(tokenizer, data_args.template)
+    get_template_and_fix_tokenizer(tokenizer, data_args)
     model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab

     if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
@@ -132,12 +132,12 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
         if model_args.export_hub_model_id is not None:
             tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

-        if model_args.visual_inputs and processor is not None:
+        if processor is not None:
             getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
             if model_args.export_hub_model_id is not None:
                 getattr(processor, "image_processor").push_to_hub(
                     model_args.export_hub_model_id, token=model_args.hf_hub_token
                 )

-    except Exception:
-        logger.warning("Cannot save tokenizer, please copy the files manually.")
+    except Exception as e:
+        logger.warning("Cannot save tokenizer, please copy the files manually: {}.".format(e))
src/llamafactory/webui/chatter.py
@@ -14,9 +14,7 @@
 import json
 import os
-from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
-
-from numpy.typing import NDArray
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple

 from ..chat import ChatModel
 from ..data import Role
@@ -90,7 +88,6 @@ class WebChatModel(ChatModel):
             template=get("top.template"),
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            visual_inputs=get("top.visual_inputs"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
@@ -135,7 +132,8 @@ class WebChatModel(ChatModel):
         messages: Sequence[Dict[str, str]],
         system: str,
         tools: str,
-        image: Optional[NDArray],
+        image: Optional[Any],
+        video: Optional[Any],
         max_new_tokens: int,
         top_p: float,
         temperature: float,
@@ -143,7 +141,7 @@ class WebChatModel(ChatModel):
         chatbot[-1][1] = ""
         response = ""
         for new_text in self.stream_chat(
-            messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             if tools:
src/llamafactory/webui/common.py
@@ -115,39 +115,29 @@ def get_model_path(model_name: str) -> str:
     return model_path


-def get_prefix(model_name: str) -> str:
-    r"""
-    Gets the prefix of the model name to obtain the model family.
-    """
-    return model_name.split("-")[0]
-
-
-def get_model_info(model_name: str) -> Tuple[str, str, bool]:
+def get_model_info(model_name: str) -> Tuple[str, str]:
     r"""
     Gets the necessary information of this model.

     Returns:
         model_path (str)
         template (str)
-        visual (bool)
     """
-    return get_model_path(model_name), get_template(model_name), get_visual(model_name)
+    return get_model_path(model_name), get_template(model_name)


 def get_template(model_name: str) -> str:
     r"""
     Gets the template name if the model is a chat model.
     """
-    if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
-        return DEFAULT_TEMPLATE[get_prefix(model_name)]
-
-    return "default"
+    return DEFAULT_TEMPLATE.get(model_name, "default")


 def get_visual(model_name: str) -> bool:
     r"""
     Judges if the model is a vision language model.
     """
-    return get_prefix(model_name) in VISION_MODELS
+    return model_name in VISION_MODELS


 def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
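The web UI helpers are simplified here: get_model_info no longer reports a visual flag, and both the template and the vision check are looked up by the full model name instead of by a name prefix. A hedged illustration of the new behaviour with hypothetical table entries (the real DEFAULT_TEMPLATE and VISION_MODELS contents are defined elsewhere in the package and are not shown in this diff):

    # Hypothetical entries, keyed by full model name as the new helpers expect.
    DEFAULT_TEMPLATE = {"LLaVA-1.5-7B-Chat": "llava"}
    VISION_MODELS = {"LLaVA-1.5-7B-Chat"}

    def get_template(model_name: str) -> str:
        return DEFAULT_TEMPLATE.get(model_name, "default")

    def get_visual(model_name: str) -> bool:
        return model_name in VISION_MODELS

    print(get_template("LLaVA-1.5-7B-Chat"))   # "llava"
    print(get_template("SomeBase-7B"))         # "default": unknown names fall back
    print(get_visual("SomeBase-7B"))           # False: only registered names count as vision models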
src/llamafactory/webui/components/chatbot.py
@@ -43,8 +43,12 @@ def create_chat_box(
                 system = gr.Textbox(show_label=False)
                 tools = gr.Textbox(show_label=False, lines=3)

-            with gr.Column() as image_box:
-                image = gr.Image(sources=["upload"], type="numpy")
+            with gr.Column() as mm_box:
+                with gr.Tab("Image"):
+                    image = gr.Image(sources=["upload"], type="pil")
+
+                with gr.Tab("Video"):
+                    video = gr.Video(sources=["upload"])

             query = gr.Textbox(show_label=False, lines=8)
             submit_btn = gr.Button(variant="primary")
@@ -63,7 +67,7 @@ def create_chat_box(
         [chatbot, messages, query],
     ).then(
         engine.chatter.stream,
-        [chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
+        [chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature],
         [chatbot, messages],
     )
     clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
@@ -76,8 +80,9 @@ def create_chat_box(
             role=role,
             system=system,
             tools=tools,
-            image_box=image_box,
+            mm_box=mm_box,
             image=image,
+            video=video,
             query=query,
             submit_btn=submit_btn,
             max_new_tokens=max_new_tokens,
src/llamafactory/webui/components/export.py
@@ -46,7 +46,6 @@ def save_model(
     finetuning_type: str,
     checkpoint_path: Union[str, List[str]],
     template: str,
-    visual_inputs: bool,
     export_size: int,
     export_quantization_bit: str,
     export_quantization_dataset: str,
@@ -78,7 +77,6 @@ def save_model(
         model_name_or_path=model_path,
         finetuning_type=finetuning_type,
         template=template,
-        visual_inputs=visual_inputs,
         export_dir=export_dir,
         export_hub_model_id=export_hub_model_id or None,
         export_size=export_size,
@@ -129,7 +127,6 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem_by_id("top.finetuning_type"),
             engine.manager.get_elem_by_id("top.checkpoint_path"),
             engine.manager.get_elem_by_id("top.template"),
-            engine.manager.get_elem_by_id("top.visual_inputs"),
             export_size,
             export_quantization_bit,
             export_quantization_dataset,
src/llamafactory/webui/components/infer.py
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING, Dict

 from ...extras.packages import is_gradio_available
+from ..common import get_visual
 from .chatbot import create_chat_box
@@ -64,10 +65,10 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
         lambda: ([], []), outputs=[chatbot, messages]
     ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])

-    engine.manager.get_elem_by_id("top.visual_inputs").change(
-        lambda enabled: gr.Column(visible=enabled),
-        [engine.manager.get_elem_by_id("top.visual_inputs")],
-        [chat_elems["image_box"]],
+    engine.manager.get_elem_by_id("top.model_name").change(
+        lambda model_name: gr.Column(visible=get_visual(model_name)),
+        [engine.manager.get_elem_by_id("top.model_name")],
+        [chat_elems["mm_box"]],
     )

     return elem_dict
src/llamafactory/webui/components/top.py
@@ -43,14 +43,13 @@ def create_top() -> Dict[str, "Component"]:
     with gr.Accordion(open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=1)
-            quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
-            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
-            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
-            booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
-            visual_inputs = gr.Checkbox(scale=1)
+            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
+            quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
+            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
+            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
+            booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)

-    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
+    model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
     model_name.input(save_config, inputs=[lang, model_name], queue=False)
@@ -73,5 +72,4 @@ def create_top() -> Dict[str, "Component"]:
         template=template,
         rope_scaling=rope_scaling,
         booster=booster,
-        visual_inputs=visual_inputs,
     )
src/llamafactory/webui/engine.py
@@ -59,7 +59,7 @@ class Engine:
             init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
             init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
             init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
-            init_dict["infer.image_box"] = {"visible": False}
+            init_dict["infer.mm_box"] = {"visible": False}

             if user_config.get("last_model", None):
                 init_dict["top.model_name"] = {"value": user_config["last_model"]}
src/llamafactory/webui/locales.py
@@ -148,7 +148,7 @@ LOCALES = {
         },
         "zh": {
             "label": "提示模板",
-            "info": "构建提示词时使用的模板",
+            "info": "构建提示词时使用的模板。",
         },
         "ko": {
             "label": "프롬프트 템플릿",
@@ -183,20 +183,6 @@ LOCALES = {
             "label": "부스터",
         },
     },
-    "visual_inputs": {
-        "en": {
-            "label": "Visual inputs",
-        },
-        "ru": {
-            "label": "визуальные входы",
-        },
-        "zh": {
-            "label": "图像输入",
-        },
-        "ko": {
-            "label": "시각적 입력",
-        },
-    },
     "training_stage": {
         "en": {
             "label": "Stage",
@@ -1705,6 +1691,20 @@ LOCALES = {
             "label": "이미지 (선택 사항)",
         },
     },
+    "video": {
+        "en": {
+            "label": "Video (optional)",
+        },
+        "ru": {
+            "label": "Видео (по желанию)",
+        },
+        "zh": {
+            "label": "视频(非必填)",
+        },
+        "ko": {
+            "label": "비디오 (선택 사항)",
+        },
+    },
     "query": {
         "en": {
             "placeholder": "Input...",
src/llamafactory/webui/manager.py
@@ -75,5 +75,4 @@ class Manager:
             self._id_to_elem["top.template"],
             self._id_to_elem["top.rope_scaling"],
             self._id_to_elem["top.booster"],
-            self._id_to_elem["top.visual_inputs"],
         }
src/llamafactory/webui/runner.py
@@ -115,7 +115,7 @@ class Runner:
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            visual_inputs=get("top.visual_inputs"),
+            enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
             cutoff_len=get("train.cutoff_len"),
@@ -251,7 +251,6 @@ class Runner:
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("eval.dataset_dir"),
             eval_dataset=",".join(get("eval.dataset")),
             cutoff_len=get("eval.cutoff_len"),
tests/data/test_mm_plugin.py
New file (mode 100644):

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple

import pytest
import torch
from PIL import Image

from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor

    from llamafactory.data.mm_plugin import BasePlugin


HF_TOKEN = os.environ.get("HF_TOKEN", None)

TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

MM_MESSAGES = [
    {"role": "user", "content": "<image>What is in this image?"},
    {"role": "assistant", "content": "A cat."},
]

TEXT_MESSAGES = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
]

IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]

NO_IMAGES = []

NO_VIDEOS = []

IMGLENS = [1]

NO_IMGLENS = [0]

NO_VIDLENS = [0]

INPUT_IDS = [0, 1, 2, 3, 4]

LABELS = [0, 1, 2, 3, 4]

SEQLENS = [1024]


def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    return image_processor(images=IMAGES, return_tensors="pt")


def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
    assert batch_a.keys() == batch_b.keys()
    for key in batch_a.keys():
        if isinstance(batch_a[key], torch.Tensor):
            assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
        else:
            assert batch_a[key] == batch_b[key]


def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
    model_args = ModelArguments(model_name_or_path=model_name_or_path)
    tokenizer_module = load_tokenizer(model_args)
    return tokenizer_module["tokenizer"], tokenizer_module["processor"]


def _check_plugin(
    plugin: "BasePlugin",
    tokenizer: "PreTrainedTokenizer",
    processor: "ProcessorMixin",
    expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
    expected_input_ids: List[int] = INPUT_IDS,
    expected_labels: List[int] = LABELS,
    expected_mm_inputs: Dict[str, Any] = {},
    expected_no_mm_inputs: Dict[str, Any] = {},
) -> None:
    # test mm_messages
    assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
    assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor) == (
        expected_input_ids,
        expected_labels,
    )
    _is_close(
        plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, SEQLENS, processor),
        expected_mm_inputs,
    )
    # test text_messages
    assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, processor) == TEXT_MESSAGES
    assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, tokenizer, processor) == (
        INPUT_IDS,
        LABELS,
    )
    _is_close(
        plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, SEQLENS, processor),
        expected_no_mm_inputs,
    )


def test_base_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
    base_plugin = get_mm_plugin(name="base", image_token="<image>")
    check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
    _check_plugin(**check_inputs)


def test_llava_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
    image_seqlen = 576
    check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
    check_inputs["expected_mm_messages"] = [
        {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
        for message in MM_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    _check_plugin(**check_inputs)


def test_llava_next_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
    check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
    image_seqlen = 1176
    check_inputs["expected_mm_messages"] = [
        {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
        for message in MM_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    _check_plugin(**check_inputs)


def test_llava_next_video_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
    check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
    image_seqlen = 1176
    check_inputs["expected_mm_messages"] = [
        {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
        for message in MM_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    _check_plugin(**check_inputs)


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
    image_seqlen = 256
    check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
    check_inputs["expected_mm_messages"] = [
        {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
    ]
    check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
    check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
    check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
    _check_plugin(**check_inputs)


def test_qwen2_vl_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
    image_seqlen = 4
    check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
    check_inputs["expected_mm_messages"] = [
        {
            key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
            for key, value in message.items()
        }
        for message in MM_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    _check_plugin(**check_inputs)


def test_video_llava_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
    check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
    image_seqlen = 256
    check_inputs["expected_mm_messages"] = [
        {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
        for message in MM_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    _check_plugin(**check_inputs)
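The _check_plugin helper above exercises the three methods every multimodal plugin exposes: process_messages, process_token_ids, and get_mm_inputs. A minimal sketch of calling one plugin directly, reusing the constants from this test file (it assumes the tiny Llama checkpoint is reachable on the Hugging Face Hub):

    # Poke the base plugin by hand with the constants defined in the test above.
    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
    plugin = get_mm_plugin(name="base", image_token="<image>")

    # Per the assertions in test_base_plugin, the base plugin leaves messages and token ids
    # unchanged and produces no extra multimodal tensors.
    print(plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor))
    print(plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor))
    print(plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, SEQLENS, processor))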
tests/data/test_template.py
@@ -19,6 +19,8 @@ import pytest
 from transformers import AutoTokenizer

 from llamafactory.data import get_template_and_fix_tokenizer
+from llamafactory.data.template import _get_jinja_template
+from llamafactory.hparams import DataArguments


 if TYPE_CHECKING:
@@ -51,7 +53,7 @@ def _check_single_template(
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
     content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
     content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
-    template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     assert content_str == prompt_str + answer_str + extra_str
     assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
@@ -78,7 +80,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_oneturn(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     prompt_str = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
@@ -93,7 +95,7 @@ def test_encode_oneturn(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_multiturn(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
     prompt_str_1 = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
@@ -116,7 +118,8 @@ def test_encode_multiturn(use_fast: bool):
 def test_jinja_template(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
-    get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
+    tokenizer.chat_template = _get_jinja_template(template, tokenizer)  # llama3 template no replace
     assert tokenizer.chat_template != ref_tokenizer.chat_template
     assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
@@ -157,7 +160,7 @@ def test_qwen_template():
     _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")


-@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
+@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
 def test_yi_template():
     prompt_str = (
         "<|im_start|>user\nHow are you<|im_end|>\n"
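The same migration recurs throughout this commit: templates are no longer requested by keyword name but via a DataArguments object. A hedged before/after sketch using the tiny test model, with the other DataArguments fields left at their defaults:

    from transformers import AutoTokenizer

    from llamafactory.data import get_template_and_fix_tokenizer
    from llamafactory.hparams import DataArguments

    tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-Llama-3")

    # Before this commit:  get_template_and_fix_tokenizer(tokenizer, name="llama3")
    # From v0.9.1 onwards:
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
    prompt_ids, answer_ids = template.encode_oneturn(
        tokenizer, [{"role": "user", "content": "How are you"}, {"role": "assistant", "content": "Fine."}]
    )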
tests/e2e/test_chat.py
New file (mode 100644):

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from llamafactory.chat import ChatModel


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "do_sample": False,
    "max_new_tokens": 1,
}

MESSAGES = [
    {"role": "user", "content": "Hi"},
]

EXPECTED_RESPONSE = "_rho"


def test_chat():
    chat_model = ChatModel(INFER_ARGS)
    assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE


def test_stream_chat():
    chat_model = ChatModel(INFER_ARGS)
    response = ""
    for token in chat_model.stream_chat(MESSAGES):
        response += token

    assert response == EXPECTED_RESPONSE
tests/e2e/test_train.py
New file (mode 100644):

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from llamafactory.train.tuner import export_model, run_exp


DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "do_train": True,
    "finetuning_type": "lora",
    "dataset_dir": "REMOTE:" + DEMO_DATA,
    "template": "llama3",
    "cutoff_len": 1,
    "overwrite_cache": False,
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "max_steps": 1,
}

INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "adapter_name_or_path": TINY_LLAMA_ADAPTER,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "export_dir": "llama3_export",
}

OS_NAME = os.environ.get("OS_NAME", "")


@pytest.mark.parametrize(
    "stage,dataset",
    [
        ("pt", "c4_demo"),
        ("sft", "alpaca_en_demo"),
        ("dpo", "dpo_en_demo"),
        ("kto", "kto_en_demo"),
        pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
    ],
)
def test_run_exp(stage: str, dataset: str):
    output_dir = "train_{}".format(stage)
    run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
    assert os.path.exists(output_dir)


def test_export():
    export_model(INFER_ARGS)
    assert os.path.exists("llama3_export")
tests/model/model_utils/test_checkpointing.py
@@ -51,6 +51,12 @@ def test_checkpointing_disable():
         assert getattr(module, "gradient_checkpointing") is False


+def test_unsloth_gradient_checkpointing():
+    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
+    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
+        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"  # classmethod
+
+
 def test_upcast_layernorm():
     model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
     for name, param in model.named_parameters():
tests/model/test_pissa.py
@@ -14,6 +14,8 @@
 import os

+import pytest
+
 from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model
@@ -47,13 +49,17 @@ INFER_ARGS = {
     "infer_dtype": "float16",
 }

+OS_NAME = os.environ.get("OS_NAME", "")

+
+@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
 def test_pissa_train():
     model = load_train_model(**TRAIN_ARGS)
     ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
     compare_model(model, ref_model)


+@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
 def test_pissa_inference():
     model = load_infer_model(**INFER_ARGS)
     ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)