OpenDAS / LLaMA-Factory / Commits

Commit 8293100a authored Jan 16, 2025 by luopl

update to 0.9.2.dev0

parent 2778a3d0

Showing 20 changed files with 949 additions and 219 deletions (+949 -219)
src/llamafactory/train/dpo/workflow.py          +5    -11
src/llamafactory/train/kto/trainer.py           +18   -10
src/llamafactory/train/kto/workflow.py          +1    -0
src/llamafactory/train/ppo/workflow.py          +1    -1
src/llamafactory/train/pt/trainer.py            +24   -15
src/llamafactory/train/rm/trainer.py            +15   -7
src/llamafactory/train/rm/workflow.py           +3    -1
src/llamafactory/train/sft/trainer.py           +47   -48
src/llamafactory/train/sft/workflow.py          +12   -15
src/llamafactory/train/trainer_utils.py         +186  -18
src/llamafactory/train/tuner.py                 +36   -4
src/llamafactory/webui/chatter.py               +2    -1
src/llamafactory/webui/components/export.py     +1    -0
src/llamafactory/webui/components/train.py      +46   -2
src/llamafactory/webui/interface.py             +3    -1
src/llamafactory/webui/locales.py               +228  -6
src/llamafactory/webui/runner.py                +24   -3
tests/data/test_collator.py                     +97   -1
tests/data/test_formatter.py                    +161  -38
tests/data/test_mm_plugin.py                    +39   -37
src/llamafactory/train/dpo/workflow.py

@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import cal_effective_tokens
+from ...extras.misc import calculate_tps
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer

@@ -48,6 +48,7 @@ def run_dpo(
     data_collator = PairwiseDataCollatorWithPadding(
         template=template,
+        model=model,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         **tokenizer_module,

@@ -65,12 +66,6 @@ def run_dpo(
     # Update arguments
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

-    effective_token_num = 0.0
-    if finetuning_args.include_effective_tokens_per_second:
-        for data in dataset_module["train_dataset"]:
-            effective_token_num += len(data["chosen_input_ids"])
-            effective_token_num += len(data["rejected_input_ids"])
-
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         model=model,

@@ -86,13 +81,12 @@ def run_dpo(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
-        trainer.save_model()
         if finetuning_args.include_effective_tokens_per_second:
-            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
-                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="rm"
             )

+        trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
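For context, run_dpo no longer counts chosen/rejected tokens inline; that logic now sits behind calculate_tps. A minimal sketch of what a throughput helper with this signature could compute, assumed rather than copied from llamafactory.extras.misc:

from typing import Any, Dict, Sequence

def calculate_tps_sketch(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: str) -> float:
    # count effective tokens the way the removed inline loop did
    if stage == "rm":  # pairwise data: both chosen and rejected sequences count
        token_count = sum(len(d["chosen_input_ids"]) + len(d["rejected_input_ids"]) for d in dataset)
    else:  # e.g. "sft": one input sequence per example
        token_count = sum(len(d["input_ids"]) for d in dataset)

    return token_count * metrics["epoch"] / metrics["train_runtime"]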
src/llamafactory/train/kto/trainer.py

@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 import torch
 from transformers import Trainer

@@ -28,9 +28,9 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach

 if TYPE_CHECKING:

@@ -50,6 +50,9 @@ class CustomKTOTrainer(KTOTrainer):
         disable_dropout: bool = True,
         **kwargs,
     ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         if disable_dropout:
             disable_dropout_in_model(model)

         if ref_model is not None:

@@ -77,6 +80,7 @@ class CustomKTOTrainer(KTOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx

         Trainer.__init__(self, model=model, **kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")

@@ -119,6 +123,9 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
        Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
        """
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
         return Trainer._get_train_sampler(self)

     @override

@@ -135,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
        Runs forward pass and computes the log probabilities.
        """
-        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
             "attention_mask": batch[f"{prefix}attention_mask"],

@@ -245,17 +252,18 @@ class CustomKTOTrainer(KTOTrainer):
         return losses, metrics

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss
src/llamafactory/train/kto/workflow.py

@@ -47,6 +47,7 @@ def run_kto(
     data_collator = KTODataCollatorWithPadding(
         template=template,
+        model=model,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         **tokenizer_module,
src/llamafactory/train/ppo/workflow.py

@@ -46,7 +46,7 @@ def run_ppo(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
-    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
+    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, model=model, **tokenizer_module)

     # Create reference model and reward model
     ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
src/llamafactory/train/pt/trainer.py

@@ -13,19 +13,19 @@
 # limitations under the License.

 from types import MethodType
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

+import torch
 from transformers import Trainer
 from typing_extensions import override

-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ...extras.packages import is_transformers_version_greater_than
+from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


 if TYPE_CHECKING:
-    import torch
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedModel, ProcessorMixin

     from ...hparams import FinetuningArguments

@@ -38,15 +38,15 @@ class CustomTrainer(Trainer):
     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
     ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args

         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))

-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

@@ -67,17 +67,26 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
-            # other model should not scale the loss
+        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss
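For context on the compute_loss overrides above: on the affected transformers releases the parent Trainer passes num_items_in_batch but, for models that do not accept loss kwargs, no longer divides the returned loss by the accumulation steps itself (assumed behavior per the linked pull requests), so these wrappers rescale it manually. A minimal sketch of that rescaling decision, not LLaMA-Factory code:

def rescale_for_accumulation(loss, gradient_accumulation_steps, num_items_in_batch=None, model_accepts_loss_kwargs=False):
    # Illustration only: divide by the accumulation steps exactly when the parent
    # Trainer passed num_items_in_batch but the model did not consume it.
    if num_items_in_batch is not None and not model_accepts_loss_kwargs:
        return loss / gradient_accumulation_steps
    return loss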
src/llamafactory/train/rm/trainer.py

@@ -25,8 +25,8 @@ from transformers import Trainer
 from typing_extensions import override

 from ...extras import logging
-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@@ -48,7 +48,11 @@ class PairwiseTrainer(Trainer):
     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
     ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         super().__init__(**kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
         self.add_callback(FixValueHeadModelCallback)

@@ -56,9 +60,6 @@ class PairwiseTrainer(Trainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))

-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

@@ -78,6 +79,13 @@ class PairwiseTrainer(Trainer):
             create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
     @override
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs

@@ -100,8 +108,8 @@ class PairwiseTrainer(Trainer):
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
+            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0-4.46.1

         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
src/llamafactory/train/rm/workflow.py

@@ -44,7 +44,9 @@ def run_rm(
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)
+    data_collator = PairwiseDataCollatorWithPadding(
+        template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
+    )

     # Update arguments
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
src/llamafactory/train/sft/trainer.py

@@ -27,14 +27,14 @@ from typing_extensions import override

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ...extras.packages import is_transformers_version_greater_than
+from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput

     from ...hparams import FinetuningArguments

@@ -51,15 +51,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
     ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+        else:
+            self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
+
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args

         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))

-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

@@ -80,18 +82,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
-            # other model should not scale the loss
+        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss

@@ -102,41 +113,30 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
     ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
        Removes the prompt part in the generated tokens.

        Subclass and override to inject custom behavior.
        """
-        labels = inputs["labels"] if "labels" in inputs else None
-        if self.args.predict_with_generate:
-            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
-            labels = labels.detach().clone() if labels is not None else None  # backup labels
-            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-            if prompt_len > label_len:
-                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
-            if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
-                inputs["labels"] = inputs["labels"][:, :prompt_len]
-
-        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
-            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        if self.args.predict_with_generate:  # do not pass labels to model when generate
+            labels = inputs.pop("labels", None)
+        else:
+            labels = inputs.get("labels")
+
+        loss, generated_tokens, _ = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
         )
         if generated_tokens is not None and self.args.predict_with_generate:
-            generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
+            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
             generated_tokens = generated_tokens.contiguous()

         return loss, generated_tokens, labels

-    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Pads the tensor to the same length as the target tensor.
-        """
-        assert self.tokenizer.pad_token_id is not None, "Pad token is required."
-        padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
-        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
-        return padded_tensor.contiguous()  # in contiguous memory
-
-    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
+    def save_predictions(
+        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
+    ) -> None:
         r"""
        Saves model predictions to `output_dir`.

@@ -149,24 +149,23 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

         labels = np.where(
-            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
+            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
         )
         preds = np.where(
-            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
+            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.processing_class.pad_token_id,
         )

         for i in range(len(preds)):
-            pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
+            pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
             if len(pad_len):  # move pad token to last
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

-        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
-        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
-
-        with open(output_prediction_file, "w", encoding="utf-8") as writer:
-            res: List[str] = []
-            for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
-                res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
-
-            writer.write("\n".join(res))
+        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
+        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
+        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
+
+        with open(output_prediction_file, "w", encoding="utf-8") as f:
+            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
+                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
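For reference, the rewritten save_predictions emits one JSON object per line instead of a single joined string. A sketch of what one line of the prediction file looks like, with placeholder values rather than values from the source:

import json

record = {"prompt": "<decoded input>", "predict": "<decoded prediction>", "label": "<decoded label>"}
print(json.dumps(record, ensure_ascii=False))  # written with a trailing "\n" per example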
src/llamafactory/train/sft/workflow.py

@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import cal_effective_tokens, get_logits_processor
+from ...extras.logging import get_logger
+from ...extras.misc import calculate_tps, get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push

@@ -33,6 +34,9 @@ if TYPE_CHECKING:
     from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


+logger = get_logger(__name__)
+
+
 def run_sft(
     model_args: "ModelArguments",
     data_args: "DataArguments",

@@ -52,6 +56,7 @@ def run_sft(
     data_collator = SFTDataCollatorWith4DAttentionMask(
         template=template,
+        model=model if not training_args.predict_with_generate else None,
         pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         block_diag_attn=model_args.block_diag_attn,

@@ -65,11 +70,6 @@ def run_sft(
         training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
         training_args.remove_unused_columns = False  # important for multimodal dataset

-    effective_token_num = 0.0
-    if finetuning_args.include_effective_tokens_per_second:
-        for data in dataset_module["train_dataset"]:
-            effective_token_num += len(data["input_ids"])
-
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:

@@ -91,7 +91,7 @@ def run_sft(
     )

     # Keyword arguments for `model.generate`
-    gen_kwargs = generating_args.to_dict()
+    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
     gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()

@@ -99,12 +99,12 @@ def run_sft(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
-        trainer.save_model()
         if finetuning_args.include_effective_tokens_per_second:
-            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
-                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="sft"
             )

+        trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()

@@ -117,19 +117,16 @@ def run_sft(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
-        if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
-            metrics.pop("eval_loss", None)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

     # Predict
     if training_args.do_predict:
+        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
-        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
-            predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)

     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
src/llamafactory/train/trainer_utils.py

@@ -17,7 +17,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import Trainer

@@ -30,20 +32,29 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available
+from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params


 if is_galore_available():
-    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
+
+
+if is_apollo_available():
+    from apollo_torch import APOLLOAdamW  # type: ignore
+
+
+if is_ray_available():
+    from ray.train import RunConfig, ScalingConfig
+    from ray.train.torch import TorchTrainer


 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
+    from transformers import PreTrainedModel, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, RayArguments, TrainingArguments


 logger = logging.get_logger(__name__)

@@ -51,7 +62,7 @@ logger = logging.get_logger(__name__)
 class DummyOptimizer(torch.optim.Optimizer):
     r"""
-    A dummy optimizer used for the GaLore algorithm.
+    A dummy optimizer used for the GaLore or APOLLO algorithm.
     """

     def __init__(

@@ -74,7 +85,7 @@ def create_modelcard_and_push(
     trainer: "Trainer",
     model_args: "ModelArguments",
     data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> None:
     kwargs = {

@@ -187,7 +198,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 def _create_galore_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":

@@ -231,9 +242,10 @@ def _create_galore_optimizer(
     elif training_args.optim == "adafactor":
         optim_class = GaLoreAdafactor
     else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")

     if finetuning_args.galore_layerwise:
+        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer GaLore does not support gradient accumulation.")

@@ -265,13 +277,100 @@ def _create_galore_optimizer(
         ]
         optimizer = optim_class(param_groups, **optim_kwargs)

-    logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+    logger.info_rank0(
+        f"Using GaLore optimizer with args: {galore_kwargs}. "
+        "It may cause hanging at the start of training, wait patiently."
+    )
     return optimizer


+def _create_apollo_optimizer(
+    model: "PreTrainedModel",
+    training_args: "TrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
+        apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
+    else:
+        apollo_targets = finetuning_args.apollo_target
+
+    apollo_params: List["torch.nn.Parameter"] = []
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
+            for param in module.parameters():
+                if param.requires_grad and len(param.shape) > 1:
+                    apollo_params.append(param)
+
+    apollo_kwargs = {
+        "rank": finetuning_args.apollo_rank,
+        "proj": finetuning_args.apollo_proj,
+        "proj_type": finetuning_args.apollo_proj_type,
+        "update_proj_gap": finetuning_args.apollo_update_interval,
+        "scale": finetuning_args.apollo_scale,
+        "scale_type": finetuning_args.apollo_scale_type,
+        "scale_front": finetuning_args.apollo_scale_front,
+    }
+
+    id_apollo_params = {id(param) for param in apollo_params}
+    decay_params, nodecay_params = [], []  # they are non-apollo parameters
+    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
+    decay_param_names = _get_decay_parameter_names(model)
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            trainable_params.append(param)
+            if id(param) not in id_apollo_params:
+                if name in decay_param_names:
+                    decay_params.append(param)
+                else:
+                    nodecay_params.append(param)
+
+    _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+
+    if training_args.optim == "adamw_torch":
+        optim_class = APOLLOAdamW
+    else:
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
+
+    if finetuning_args.apollo_layerwise:
+        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
+        if training_args.gradient_accumulation_steps != 1:
+            raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
+
+        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        for param in nodecay_params:
+            param_groups = [dict(params=[param], weight_decay=0.0)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        for param in decay_params:
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        for param in apollo_params:  # apollo params have weight decay
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        def optimizer_hook(param: "torch.nn.Parameter"):
+            if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
+
+        for param in trainable_params:
+            param.register_post_accumulate_grad_hook(optimizer_hook)
+
+        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
+    else:
+        param_groups = [
+            dict(params=nodecay_params, weight_decay=0.0),
+            dict(params=decay_params, weight_decay=training_args.weight_decay),
+            dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
+        ]
+        optimizer = optim_class(param_groups, **optim_kwargs)
+
+    logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
+    return optimizer
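The layerwise branch of _create_apollo_optimizer relies on PyTorch's post-accumulate-grad hooks so each parameter steps its own optimizer. A standalone sketch of that pattern, with plain SGD in place of APOLLOAdamW so it runs without the apollo-torch dependency:

import torch

model = torch.nn.Linear(8, 4)
optimizer_dict = {param: torch.optim.SGD([param], lr=0.1) for param in model.parameters()}

def optimizer_hook(param: torch.nn.Parameter) -> None:
    # step and reset this parameter's dedicated optimizer right after its gradient is ready
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(optimizer_hook)  # requires torch >= 2.1

loss = model(torch.randn(2, 8)).sum()
loss.backward()  # every parameter is updated inside its hook; no explicit optimizer.step() call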
@@ -311,7 +410,7 @@ def _create_loraplus_optimizer(
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     default_lr = training_args.learning_rate

@@ -330,7 +429,7 @@ def _create_badam_optimizer(
 def _create_badam_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     decay_params, nodecay_params = [], []

@@ -350,7 +449,7 @@ def _create_badam_optimizer(
     ]

     if finetuning_args.badam_mode == "layer":
-        from badam import BlockOptimizer
+        from badam import BlockOptimizer  # type: ignore

         base_optimizer = optim_class(param_groups, **optim_kwargs)
         optimizer = BlockOptimizer(

@@ -372,9 +471,9 @@ def _create_badam_optimizer(
         )
     elif finetuning_args.badam_mode == "ratio":
-        from badam import BlockOptimizerRatio
+        from badam import BlockOptimizerRatio  # type: ignore

         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(

@@ -397,12 +496,15 @@ def _create_adam_mini_optimizer(
 def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
 ) -> "torch.optim.Optimizer":
-    from adam_mini import Adam_mini
+    from adam_mini import Adam_mini  # type: ignore

     hidden_size = getattr(model.config, "hidden_size", None)
     num_q_head = getattr(model.config, "num_attention_heads", None)

@@ -414,7 +516,7 @@ def create_custom_optimizer(
 def create_custom_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
         return _create_galore_optimizer(model, training_args, finetuning_args)

+    if finetuning_args.use_apollo:
+        return _create_apollo_optimizer(model, training_args, finetuning_args)
+
     if finetuning_args.loraplus_lr_ratio is not None:
         return _create_loraplus_optimizer(model, training_args, finetuning_args)

 def create_custom_scheduler(
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     num_training_steps: int,
     optimizer: Optional["torch.optim.Optimizer"] = None,
 ) -> None:

@@ -457,3 +559,69 @@ def get_batch_logps(
     labels[labels == label_pad_token_id] = 0  # dummy token
     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
     return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+
+
+def nested_detach(
+    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    clone: bool = False,
+):
+    r"""
+    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
+    """
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
+
+    if isinstance(tensors, torch.Tensor):
+        if clone:
+            return tensors.detach().clone()
+        else:
+            return tensors.detach()
+    else:
+        return tensors
+
+
+def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
+    r"""
+    Gets the callback for logging to SwanLab.
+    """
+    import swanlab  # type: ignore
+    from swanlab.integration.transformers import SwanLabCallback  # type: ignore
+
+    if finetuning_args.swanlab_api_key is not None:
+        swanlab.login(api_key=finetuning_args.swanlab_api_key)
+
+    swanlab_callback = SwanLabCallback(
+        project=finetuning_args.swanlab_project,
+        workspace=finetuning_args.swanlab_workspace,
+        experiment_name=finetuning_args.swanlab_run_name,
+        mode=finetuning_args.swanlab_mode,
+        config={"Framework": "🦙LlamaFactory"},
+    )
+    return swanlab_callback
+
+
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: "RayArguments",
+) -> "TorchTrainer":
+    if not ray_args.use_ray:
+        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
+
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.ray_num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            placement_strategy=ray_args.placement_strategy,
+            use_gpu=True,
+        ),
+        run_config=RunConfig(
+            name=ray_args.ray_run_name,
+            storage_path=Path("./saves").absolute().as_posix(),
+        ),
+    )
+    return trainer
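A small usage sketch for the new nested_detach helper added above; the import path follows this commit's layout, and CustomKTOTrainer.forward now calls it the same way:

import torch
from llamafactory.train.trainer_utils import nested_detach

batch = {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])}
detached = nested_detach(batch, clone=True)  # same dict structure, every tensor detached and cloned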
src/llamafactory/train/tuner.py

@@ -22,15 +22,21 @@ from transformers import PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..hparams import get_infer_args, get_train_args
+from ..extras.packages import is_ray_available
+from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
-from .callbacks import LogCallback
+from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
 from .kto import run_kto
 from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
+from .trainer_utils import get_ray_trainer, get_swanlab_callback
+
+
+if is_ray_available():
+    from ray.train.huggingface.transformers import RayTrainReportCallback


 if TYPE_CHECKING:

@@ -40,10 +46,20 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-    callbacks.append(LogCallback())
+def _training_function(config: Dict[str, Any]) -> None:
+    args = config.get("args")
+    callbacks: List[Any] = config.get("callbacks")
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

+    callbacks.append(LogCallback())
+    if finetuning_args.pissa_convert:
+        callbacks.append(PissaConvertCallback())
+
+    if finetuning_args.use_swanlab:
+        callbacks.append(get_swanlab_callback(finetuning_args))
+
+    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
+
     if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":

@@ -60,6 +76,22 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")


+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
+    args = read_args(args)
+    ray_args = get_ray_args(args)
+    callbacks = callbacks or []
+    if ray_args.use_ray:
+        callbacks.append(RayTrainReportCallback())
+        trainer = get_ray_trainer(
+            training_function=_training_function,
+            train_loop_config={"args": args, "callbacks": callbacks},
+            ray_args=ray_args,
+        )
+        trainer.fit()
+    else:
+        _training_function(config={"args": args, "callbacks": callbacks})
+
+
 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     model_args, data_args, finetuning_args, _ = get_infer_args(args)
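Usage sketch for the restructured entry point: run_exp now normalizes its arguments via read_args and dispatches to a Ray TorchTrainer when ray_args.use_ray is set (e.g. USE_RAY=1), otherwise it calls _training_function directly. The argument keys below are illustrative, not a complete config:

from llamafactory.train.tuner import run_exp

run_exp(
    args={
        "stage": "sft",
        "do_train": True,
        "model_name_or_path": "llamafactory/tiny-random-Llama-3",
        "dataset": "identity",
        "template": "llama3",
        "output_dir": "saves/demo-run",
    }
)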
src/llamafactory/webui/chatter.py

@@ -91,6 +91,7 @@ class WebChatModel(ChatModel):
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
+            trust_remote_code=True,
         )

         if checkpoint_path:

@@ -157,7 +158,7 @@ class WebChatModel(ChatModel):
             result = response

         if isinstance(result, list):
-            tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
+            tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
             tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
             output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
             bot_text = "```json\n" + tool_calls + "\n```"
src/llamafactory/webui/components/export.py

@@ -84,6 +84,7 @@ def save_model(
         export_quantization_dataset=export_quantization_dataset,
         export_device=export_device,
         export_legacy_format=export_legacy_format,
+        trust_remote_code=True,
     )

     if checkpoint_path:
src/llamafactory/webui/components/train.py

@@ -234,8 +234,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Row():
             use_galore = gr.Checkbox()
             galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
-            galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
-            galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
+            galore_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
+            galore_scale = gr.Slider(minimum=0, maximum=100, value=2.0, step=0.1)
             galore_target = gr.Textbox(value="all")

     input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})

@@ -250,6 +250,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )

+    with gr.Accordion(open=False) as apollo_tab:
+        with gr.Row():
+            use_apollo = gr.Checkbox()
+            apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
+            apollo_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
+            apollo_scale = gr.Slider(minimum=0, maximum=100, value=32.0, step=0.1)
+            apollo_target = gr.Textbox(value="all")
+
+    input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
+    elem_dict.update(
+        dict(
+            apollo_tab=apollo_tab,
+            use_apollo=use_apollo,
+            apollo_rank=apollo_rank,
+            apollo_update_interval=apollo_update_interval,
+            apollo_scale=apollo_scale,
+            apollo_target=apollo_target,
+        )
+    )
+
     with gr.Accordion(open=False) as badam_tab:
         with gr.Row():
             use_badam = gr.Checkbox()

@@ -270,6 +290,30 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )

+    with gr.Accordion(open=False) as swanlab_tab:
+        with gr.Row():
+            use_swanlab = gr.Checkbox()
+            swanlab_project = gr.Textbox(value="llamafactory")
+            swanlab_run_name = gr.Textbox()
+            swanlab_workspace = gr.Textbox()
+            swanlab_api_key = gr.Textbox()
+            swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
+
+    input_elems.update(
+        {use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
+    )
+    elem_dict.update(
+        dict(
+            swanlab_tab=swanlab_tab,
+            use_swanlab=use_swanlab,
+            swanlab_project=swanlab_project,
+            swanlab_run_name=swanlab_run_name,
+            swanlab_workspace=swanlab_workspace,
+            swanlab_api_key=swanlab_api_key,
+            swanlab_mode=swanlab_mode,
+        )
+    )
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         arg_save_btn = gr.Button()
src/llamafactory/webui/interface.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import os
+import platform

 from ..extras.packages import is_gradio_available
 from .common import save_config

@@ -34,8 +35,9 @@ if is_gradio_available():

 def create_ui(demo_mode: bool = False) -> "gr.Blocks":
     engine = Engine(demo_mode=demo_mode, pure_chat=False)
+    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

-    with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
         if demo_mode:
             gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
             gr.HTML(
src/llamafactory/webui/locales.py

@@ -30,15 +30,19 @@ LOCALES = {
     "model_name": {
         "en": {
             "label": "Model name",
+            "info": "Input the name prefix to search for the model.",
         },
         "ru": {
             "label": "Название модели",
+            "info": "Введите префикс имени для поиска модели.",
         },
         "zh": {
             "label": "模型名称",
+            "info": "输入首单词以检索模型。",
         },
         "ko": {
             "label": "모델 이름",
+            "info": "모델을 검색하기 위해 이름 접두어를 입력하세요.",
         },
     },
     "model_path": {

@@ -464,7 +468,7 @@ LOCALES = {
     "val_size": {
         "en": {
             "label": "Val size",
-            "info": "Proportion of data in the dev set.",
+            "info": "Percentage of validation set from the entire dataset.",
         },
         "ru": {
             "label": "Размер валидации",

@@ -1115,7 +1119,7 @@ LOCALES = {
             "info": "Нормализация оценок в тренировке PPO.",
         },
         "zh": {
-            "label": "奖励模型",
+            "label": "归一化分数",
             "info": "PPO 训练中归一化奖励分数。",
         },
         "ko": {

@@ -1158,19 +1162,19 @@ LOCALES = {
     "use_galore": {
         "en": {
             "label": "Use GaLore",
-            "info": "Enable gradient low-Rank projection.",
+            "info": "Use GaLore optimizer.",
         },
         "ru": {
             "label": "Использовать GaLore",
-            "info": "Включить проекцию градиента на низкоранговое пространство.",
+            "info": "Используйте оптимизатор GaLore.",
         },
         "zh": {
             "label": "使用 GaLore",
-            "info": "使用梯度低秩投影。",
+            "info": "使用 GaLore 优化器。",
         },
         "ko": {
             "label": "GaLore 사용",
-            "info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
+            "info": "GaLore 최적화를 사용하세요.",
         },
     },
     "galore_rank": {

@@ -1245,6 +1249,110 @@ LOCALES = {
             "info": "GaLore를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
         },
     },
+    "apollo_tab": {
+        "en": {
+            "label": "APOLLO configurations",
+        },
+        "ru": {
+            "label": "Конфигурации APOLLO",
+        },
+        "zh": {
+            "label": "APOLLO 参数设置",
+        },
+        "ko": {
+            "label": "APOLLO 구성",
+        },
+    },
+    "use_apollo": {
+        "en": {
+            "label": "Use APOLLO",
+            "info": "Use APOLLO optimizer.",
+        },
+        "ru": {
+            "label": "Использовать APOLLO",
+            "info": "Используйте оптимизатор APOLLO.",
+        },
+        "zh": {
+            "label": "使用 APOLLO",
+            "info": "使用 APOLLO 优化器。",
+        },
+        "ko": {
+            "label": "APOLLO 사용",
+            "info": "APOLLO 최적화를 사용하세요.",
+        },
+    },
+    "apollo_rank": {
+        "en": {
+            "label": "APOLLO rank",
+            "info": "The rank of APOLLO gradients.",
+        },
+        "ru": {
+            "label": "Ранг APOLLO",
+            "info": "Ранг градиентов APOLLO.",
+        },
+        "zh": {
+            "label": "APOLLO 秩",
+            "info": "APOLLO 梯度的秩大小。",
+        },
+        "ko": {
+            "label": "APOLLO 랭크",
+            "info": "APOLLO 그레디언트의 랭크.",
+        },
+    },
+    "apollo_update_interval": {
+        "en": {
+            "label": "Update interval",
+            "info": "Number of steps to update the APOLLO projection.",
+        },
+        "ru": {
+            "label": "Интервал обновления",
+            "info": "Количество шагов для обновления проекции APOLLO.",
+        },
+        "zh": {
+            "label": "更新间隔",
+            "info": "相邻两次投影更新的步数。",
+        },
+        "ko": {
+            "label": "업데이트 간격",
+            "info": "APOLLO 프로젝션을 업데이트할 간격의 스텝 수.",
+        },
+    },
+    "apollo_scale": {
+        "en": {
+            "label": "APOLLO scale",
+            "info": "APOLLO scaling coefficient.",
+        },
+        "ru": {
+            "label": "LoRA Alpha",
+            "info": "Коэффициент масштабирования APOLLO.",
+        },
+        "zh": {
+            "label": "APOLLO 缩放系数",
+            "info": "APOLLO 缩放系数大小。",
+        },
+        "ko": {
+            "label": "APOLLO 스케일",
+            "info": "APOLLO 스케일링 계수.",
+        },
+    },
+    "apollo_target": {
+        "en": {
+            "label": "APOLLO modules",
+            "info": "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules.",
+        },
+        "ru": {
+            "label": "Модули APOLLO",
+            "info": "Имена модулей для применения APOLLO. Используйте запятые для разделения нескольких модулей.",
+        },
+        "zh": {
+            "label": "APOLLO 作用模块",
+            "info": "应用 APOLLO 的模块名称。使用英文逗号分隔多个名称。",
+        },
+        "ko": {
+            "label": "APOLLO 모듈",
+            "info": "APOLLO를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
+        },
+    },
     "badam_tab": {
         "en": {
             "label": "BAdam configurations",

@@ -1349,6 +1457,120 @@ LOCALES = {
             "info": "비율-BAdam의 업데이트 비율.",
         },
     },
+    "swanlab_tab": {
+        "en": {
+            "label": "SwanLab configurations",
+        },
+        "ru": {
+            "label": "Конфигурации SwanLab",
+        },
+        "zh": {
+            "label": "SwanLab 参数设置",
+        },
+        "ko": {
+            "label": "SwanLab 설정",
+        },
+    },
+    "use_swanlab": {
+        "en": {
+            "label": "Use SwanLab",
+            "info": "Enable SwanLab for experiment tracking and visualization.",
+        },
+        "ru": {
+            "label": "Использовать SwanLab",
+            "info": "Включить SwanLab для отслеживания и визуализации экспериментов.",
+        },
+        "zh": {
+            "label": "使用 SwanLab",
+            "info": "启用 SwanLab 进行实验跟踪和可视化。",
+        },
+        "ko": {
+            "label": "SwanLab 사용",
+            "info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.",
+        },
+    },
+    "swanlab_project": {
+        "en": {
+            "label": "SwanLab project",
+        },
+        "ru": {
+            "label": "SwanLab Проект",
+        },
+        "zh": {
+            "label": "SwanLab 项目名",
+        },
+        "ko": {
+            "label": "SwanLab 프로젝트",
+        },
+    },
+    "swanlab_run_name": {
+        "en": {
+            "label": "SwanLab experiment name (optional)",
+        },
+        "ru": {
+            "label": "SwanLab Имя эксперимента (опционально)",
+        },
+        "zh": {
+            "label": "SwanLab 实验名(非必填)",
+        },
+        "ko": {
+            "label": "SwanLab 실험 이름 (선택 사항)",
+        },
+    },
+    "swanlab_workspace": {
+        "en": {
+            "label": "SwanLab workspace (optional)",
+            "info": "Workspace for SwanLab. Defaults to the personal workspace.",
+        },
+        "ru": {
+            "label": "SwanLab Рабочая область (опционально)",
+            "info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.",
+        },
+        "zh": {
+            "label": "SwanLab 工作区(非必填)",
+            "info": "SwanLab 的工作区,默认在个人工作区下。",
+        },
+        "ko": {
+            "label": "SwanLab 작업 영역 (선택 사항)",
+            "info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.",
+        },
+    },
+    "swanlab_api_key": {
+        "en": {
+            "label": "SwanLab API key (optional)",
+            "info": "API key for SwanLab.",
+        },
+        "ru": {
+            "label": "SwanLab API ключ (опционально)",
+            "info": "API ключ для SwanLab.",
+        },
+        "zh": {
+            "label": "SwanLab API密钥(非必填)",
+            "info": "用于在编程环境登录 SwanLab,已登录则无需填写。",
+        },
+        "ko": {
+            "label": "SwanLab API 키 (선택 사항)",
+            "info": "SwanLab의 API 키.",
+        },
+    },
+    "swanlab_mode": {
+        "en": {
+            "label": "SwanLab mode",
+            "info": "Cloud or offline version.",
+        },
+        "ru": {
+            "label": "SwanLab Режим",
+            "info": "Версия в облаке или локальная версия.",
+        },
+        "zh": {
+            "label": "SwanLab 模式",
+            "info": "使用云端版或离线版 SwanLab。",
+        },
+        "ko": {
+            "label": "SwanLab 모드",
+            "info": "클라우드 버전 또는 오프라인 버전.",
+        },
+    },
     "cmd_preview_btn": {
         "en": {
             "value": "Preview command",
src/llamafactory/webui/runner.py
View file @
8293100a
...
...
@@ -19,9 +19,10 @@ from subprocess import Popen, TimeoutExpired
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Generator
,
Optional
from
transformers.trainer
import
TRAINING_ARGS_NAME
from
transformers.utils
import
is_torch_npu_available
from
..extras.constants
import
LLAMABOARD_CONFIG
,
PEFT_METHODS
,
TRAINING_STAGES
from
..extras.misc
import
is_gpu_or_npu_available
,
torch_gc
from
..extras.misc
import
is_gpu_or_npu_available
,
torch_gc
,
use_ray
from
..extras.packages
import
is_gradio_available
,
is_transformers_version_equal_to_4_46
from
.common
import
DEFAULT_CACHE_DIR
,
DEFAULT_CONFIG_DIR
,
QUANTIZATION_BITS
,
get_save_dir
,
load_config
from
.locales
import
ALERTS
,
LOCALES
...
...
@@ -146,12 +147,15 @@ class Runner:
shift_attn
=
get
(
"train.shift_attn"
),
report_to
=
"all"
if
get
(
"train.report_to"
)
else
"none"
,
use_galore
=
get
(
"train.use_galore"
),
use_apollo
=
get
(
"train.use_apollo"
),
use_badam
=
get
(
"train.use_badam"
),
use_swanlab
=
get
(
"train.use_swanlab"
),
output_dir
=
get_save_dir
(
model_name
,
finetuning_type
,
get
(
"train.output_dir"
)),
fp16
=
(
get
(
"train.compute_type"
)
==
"fp16"
),
bf16
=
(
get
(
"train.compute_type"
)
==
"bf16"
),
pure_bf16
=
(
get
(
"train.compute_type"
)
==
"pure_bf16"
),
plot_loss
=
True
,
trust_remote_code
=
True
,
ddp_timeout
=
180000000
,
include_num_input_tokens_seen
=
False
if
is_transformers_version_equal_to_4_46
()
else
True
,
# FIXME
)
...
...
@@ -170,6 +174,7 @@ class Runner:
if
get
(
"top.quantization_bit"
)
in
QUANTIZATION_BITS
:
args
[
"quantization_bit"
]
=
int
(
get
(
"top.quantization_bit"
))
args
[
"quantization_method"
]
=
get
(
"top.quantization_method"
)
args
[
"double_quantization"
]
=
not
is_torch_npu_available
()
# freeze config
if
args
[
"finetuning_type"
]
==
"freeze"
:
...
...
@@ -220,6 +225,13 @@ class Runner:
args
[
"galore_scale"
]
=
get
(
"train.galore_scale"
)
args
[
"galore_target"
]
=
get
(
"train.galore_target"
)
# apollo config
if
args
[
"use_apollo"
]:
args
[
"apollo_rank"
]
=
get
(
"train.apollo_rank"
)
args
[
"apollo_update_interval"
]
=
get
(
"train.apollo_update_interval"
)
args
[
"apollo_scale"
]
=
get
(
"train.apollo_scale"
)
args
[
"apollo_target"
]
=
get
(
"train.apollo_target"
)
# badam config
if
args
[
"use_badam"
]:
args
[
"badam_mode"
]
=
get
(
"train.badam_mode"
)
...
...
@@ -227,6 +239,14 @@ class Runner:
args
[
"badam_switch_interval"
]
=
get
(
"train.badam_switch_interval"
)
args
[
"badam_update_ratio"
]
=
get
(
"train.badam_update_ratio"
)
# swanlab config
if
get
(
"train.use_swanlab"
):
args
[
"swanlab_project"
]
=
get
(
"train.swanlab_project"
)
args
[
"swanlab_run_name"
]
=
get
(
"train.swanlab_run_name"
)
args
[
"swanlab_workspace"
]
=
get
(
"train.swanlab_workspace"
)
args
[
"swanlab_api_key"
]
=
get
(
"train.swanlab_api_key"
)
args
[
"swanlab_mode"
]
=
get
(
"train.swanlab_mode"
)
# eval config
if
get
(
"train.val_size"
)
>
1e-6
and
args
[
"stage"
]
!=
"ppo"
:
args
[
"val_size"
]
=
get
(
"train.val_size"
)
...
...
@@ -268,6 +288,7 @@ class Runner:
top_p
=
get
(
"eval.top_p"
),
temperature
=
get
(
"eval.temperature"
),
output_dir
=
get_save_dir
(
model_name
,
finetuning_type
,
get
(
"eval.output_dir"
)),
trust_remote_code
=
True
,
)
if
get
(
"eval.predict"
):
...
...
@@ -383,12 +404,12 @@ class Runner:
                continue

            if self.do_train:
-                if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
+                if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
                    finish_info = ALERTS["info_finished"][lang]
                else:
                    finish_info = ALERTS["err_failed"][lang]
            else:
-                if os.path.exists(os.path.join(output_path, "all_results.json")):
+                if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
                    finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
                else:
                    finish_info = ALERTS["err_failed"][lang]
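The change above lets the success check pass when a run is delegated to Ray, where the expected artifact (TRAINING_ARGS_NAME for training, all_results.json for evaluation) may never appear in the local output directory. A hedged sketch of the decision, with the helper name and the boolean parameter invented for illustration:

import os

def looks_finished(output_path: str, marker: str, ran_on_ray: bool) -> bool:
    # Either the expected artifact exists locally, or the job ran through Ray
    # and the local file check is not meaningful.
    return os.path.exists(os.path.join(output_path, marker)) or ran_on_ray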
...
...
tests/data/test_collator.py
View file @ 8293100a
...
...
@@ -12,9 +12,105 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

import os

import torch
from PIL import Image

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer

TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-from llamafactory.data.collator import prepare_4d_attention_mask
def test_base_collator():
    model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
    tokenizer_module = load_tokenizer(model_args)
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    data_collator = MultiModalDataCollatorForSeq2Seq(
        template=template,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    p = tokenizer_module["tokenizer"].pad_token_id
    q = IGNORE_INDEX
    features = [
        {
            "input_ids": [0, 1, 2, 3, 4, 5],
            "attention_mask": [1, 1, 1, 1, 1, 1],
            "labels": [q, q, 2, 3, 4, 5],
        },
        {
            "input_ids": [6, 7],
            "attention_mask": [1, 1],
            "labels": [q, 7],
        },
    ]
    batch_input = data_collator(features)
    expected_input = {
        "input_ids": [
            [0, 1, 2, 3, 4, 5, p, p],
            [6, 7, p, p, p, p, p, p],
        ],
        "attention_mask": [
            [1, 1, 1, 1, 1, 1, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0],
        ],
        "labels": [
            [q, q, 2, 3, 4, 5, q, q],
            [q, 7, q, q, q, q, q, q],
        ],
    }
    for k in batch_input.keys():
        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
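The expected tensors above follow directly from pad_to_multiple_of=8: the longest row in the batch has six tokens, so every row is right-padded to eight, with pad_token_id in input_ids, zeros in attention_mask, and IGNORE_INDEX in labels. A small sketch of the length arithmetic (a plain helper written for illustration, not part of the test file):

import math

def padded_length(longest: int, multiple: int) -> int:
    # Round the longest sequence in the batch up to the nearest multiple.
    return math.ceil(longest / multiple) * multiple

assert padded_length(6, 8) == 8    # both rows in the test end up 8 tokens wide
assert padded_length(9, 8) == 16   # a 9-token batch would be padded to 16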
def test_multimodal_collator():
    model_args, data_args, *_ = get_infer_args(
        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
    )
    tokenizer_module = load_tokenizer(model_args)
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    data_collator = MultiModalDataCollatorForSeq2Seq(
        template=template,
        pad_to_multiple_of=4,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    p = tokenizer_module["tokenizer"].pad_token_id
    q = IGNORE_INDEX
    s = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_start|>")
    e = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_end|>")
    m = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|image_pad|>")
    fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
    features = [
        {
            "input_ids": [0, 1, 2, 3],
            "attention_mask": [1, 1, 1, 1],
            "labels": [0, 1, 2, 3],
        },
    ]
    batch_input = data_collator(features)
    expected_input = {
        "input_ids": [
            [0, 1, 2, 3, s, m, m, m, m, e, p, p],
        ],
        "attention_mask": [
            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
        "labels": [
            [0, 1, 2, 3, q, q, q, q, q, q, q, q],
        ],
        **tokenizer_module["processor"].image_processor(fake_image),
    }
    for k in batch_input.keys():
        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
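The expected width of 12 here is worth spelling out: a text-only example of 4 tokens gets the features of a dummy 64x64 white image attached, whose placeholder expands to <|vision_start|>, four <|image_pad|> tokens and <|vision_end|> (6 extra positions, all masked out in attention_mask and labels), giving 10 tokens, which pad_to_multiple_of=4 rounds up to 12.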
def test_4d_attention_mask():
...
...
tests/data/test_formatter.py
View file @ 8293100a
...
...
@@ -13,10 +13,29 @@
 # limitations under the License.

import json
from datetime import datetime

from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter

FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}

TOOLS = [
    {
        "name": "test_tool",
        "description": "tool_desc",
        "parameters": {
            "type": "object",
            "properties": {
                "foo": {"type": "string", "description": "foo_desc"},
                "bar": {"type": "number", "description": "bar_desc"},
            },
            "required": ["foo"],
        },
    }
]


def test_empty_formatter():
    formatter = EmptyFormatter(slots=["\n"])
    assert formatter.apply() == ["\n"]
...
...
@@ -28,39 +47,27 @@ def test_string_formatter():
 def test_function_formatter():
-    formatter = FunctionFormatter(slots=[], tool_format="default")
-    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
+    tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
-        """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        "</s>",
     ]
 def test_multi_function_formatter():
-    formatter = FunctionFormatter(slots=[], tool_format="default")
-    tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
+    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
+    tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
-        """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
-        """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        "</s>",
     ]
 def test_default_tool_formatter():
     formatter = ToolFormatter(tool_format="default")
-    tools = [
-        {
-            "name": "test_tool",
-            "description": "tool_desc",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "foo": {"type": "string", "description": "foo_desc"},
-                    "bar": {"type": "number", "description": "bar_desc"},
-                },
-                "required": ["foo"],
-            },
-        }
-    ]
-    assert formatter.apply(content=json.dumps(tools)) == [
+    assert formatter.apply(content=json.dumps(TOOLS)) == [
         "You have access to the following tools:\n"
         "> Tool Name: test_tool\n"
         "Tool Description: tool_desc\n"
...
...
@@ -94,26 +101,18 @@ def test_default_multi_tool_extractor():
     ]


 def test_glm4_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
 def test_glm4_tool_formatter():
     formatter = ToolFormatter(tool_format="glm4")
-    tools = [
-        {
-            "name": "test_tool",
-            "description": "tool_desc",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "foo": {"type": "string", "description": "foo_desc"},
-                    "bar": {"type": "number", "description": "bar_desc"},
-                },
-                "required": ["foo"],
-            },
-        }
-    ]
-    assert formatter.apply(content=json.dumps(tools)) == [
+    assert formatter.apply(content=json.dumps(TOOLS)) == [
         "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
         "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
-        "## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4))
+        f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
     ]
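The new expected string builds the embedded JSON with ensure_ascii=False, which matters whenever a tool name or description contains non-ASCII text, since json.dumps escapes such characters by default. A quick illustration of the standard-library behavior (not code from the commit):

import json

print(json.dumps({"description": "工具"}))                      # {"description": "\u5de5\u5177"}
print(json.dumps({"description": "工具"}, ensure_ascii=False))  # {"description": "工具"}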
...
...
@@ -121,3 +120,127 @@ def test_glm4_tool_extractor():
     formatter = ToolFormatter(tool_format="glm4")
     result = """test_tool\n{"foo": "bar", "size": 10}\n"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_llama3_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
    assert formatter.apply(content=tool_calls) == [
        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
        "<|eot_id|>",
    ]


def test_llama3_tool_formatter():
    formatter = ToolFormatter(tool_format="llama3")
    date = datetime.now().strftime("%d %b %Y")
    wrapped_tool = {"type": "function", "function": TOOLS[0]}
    assert formatter.apply(content=json.dumps(TOOLS)) == [
        f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
        "You have access to the following functions. To call a function, please respond with JSON for a function call. "
        """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
        f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
    ]


def test_llama3_tool_extractor():
    formatter = ToolFormatter(tool_format="llama3")
    result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_function_formatter():
    formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
    tool_calls = json.dumps(FUNCTION)
    assert formatter.apply(content=tool_calls) == [
        "[TOOL_CALLS] ",
        """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
        "</s>",
    ]


def test_mistral_multi_function_formatter():
    formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
    tool_calls = json.dumps([FUNCTION] * 2)
    assert formatter.apply(content=tool_calls) == [
        "[TOOL_CALLS] ",
        """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
        """{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
        "</s>",
    ]


def test_mistral_tool_formatter():
    formatter = ToolFormatter(tool_format="mistral")
    wrapped_tool = {"type": "function", "function": TOOLS[0]}
    assert formatter.apply(content=json.dumps(TOOLS)) == [
        "[AVAILABLE_TOOLS] " + json.dumps([wrapped_tool], ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
    ]


def test_mistral_tool_extractor():
    formatter = ToolFormatter(tool_format="mistral")
    result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


def test_mistral_multi_tool_extractor():
    formatter = ToolFormatter(tool_format="mistral")
    result = (
        """[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
        """{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
    )
    assert formatter.extract(result) == [
        ("test_tool", """{"foo": "bar", "size": 10}"""),
        ("another_tool", """{"foo": "job", "size": 2}"""),
    ]
def test_qwen_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
    tool_calls = json.dumps(FUNCTION)
    assert formatter.apply(content=tool_calls) == [
        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
        "<|im_end|>",
    ]


def test_qwen_multi_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
    tool_calls = json.dumps([FUNCTION] * 2)
    assert formatter.apply(content=tool_calls) == [
        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
        "<|im_end|>",
    ]


def test_qwen_tool_formatter():
    formatter = ToolFormatter(tool_format="qwen")
    wrapped_tool = {"type": "function", "function": TOOLS[0]}
    assert formatter.apply(content=json.dumps(TOOLS)) == [
        "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}"
        "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
        """<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
        """"arguments": <args-json-object>}\n</tool_call><|im_end|>\n"""
    ]


def test_qwen_tool_extractor():
    formatter = ToolFormatter(tool_format="qwen")
    result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


def test_qwen_multi_tool_extractor():
    formatter = ToolFormatter(tool_format="qwen")
    result = (
        """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
        """<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
    )
    assert formatter.extract(result) == [
        ("test_tool", """{"foo": "bar", "size": 10}"""),
        ("another_tool", """{"foo": "job", "size": 2}"""),
    ]
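Taken together, these tests pin down the round trip between FunctionFormatter.apply and ToolFormatter.extract for each template family: apply renders a tool call into the string the model is trained to emit, and extract parses that same string back into (name, arguments) tuples. A hedged sketch of the round trip for the qwen format, written for illustration (it mirrors the assertions above rather than adding new behavior):

import json
from llamafactory.data.formatter import FunctionFormatter, ToolFormatter

FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}  # same constant as in this file

formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
rendered = formatter.apply(content=json.dumps(FUNCTION))[0]  # the <tool_call>...</tool_call> slot
extractor = ToolFormatter(tool_format="qwen")
assert extractor.extract(rendered) == [("tool_name", json.dumps(FUNCTION["arguments"]))]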
tests/data/test_mm_plugin.py
View file @ 8293100a
...
...
@@ -13,14 +13,14 @@
 # limitations under the License.
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence
 import pytest
 import torch
 from PIL import Image
 from llamafactory.data.mm_plugin import get_mm_plugin
-from llamafactory.hparams import ModelArguments
+from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
...
...
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor
     from llamafactory.data.mm_plugin import BasePlugin
+    from llamafactory.model.loader import TokenizerModule

 HF_TOKEN = os.getenv("HF_TOKEN")
...
...
@@ -82,10 +83,9 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
         assert batch_a[key] == batch_b[key]

-def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
-    model_args = ModelArguments(model_name_or_path=model_name_or_path)
-    tokenizer_module = load_tokenizer(model_args)
-    return tokenizer_module["tokenizer"], tokenizer_module["processor"]
+def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
+    model_args, *_ = get_infer_args({"model_name_or_path": model_name_or_path, "template": "default"})
+    return load_tokenizer(model_args)
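After this change the helper returns the whole TokenizerModule dict instead of a (tokenizer, processor) tuple, which lets the tests below splice it straight into check_inputs with ** unpacking. A minimal, runnable sketch of that dict-merge pattern, using stand-in values rather than real tokenizer objects:

tokenizer_module = {"tokenizer": "tok", "processor": "proc"}  # stand-ins for the real objects
some_plugin = "base-plugin"                                   # placeholder, not a real plugin
merged = {"plugin": some_plugin, **tokenizer_module}
explicit = {
    "plugin": some_plugin,
    "tokenizer": tokenizer_module["tokenizer"],
    "processor": tokenizer_module["processor"],
}
assert merged == explicit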
def _check_plugin(
...
...
@@ -121,73 +121,75 @@ def _check_plugin(
 def test_base_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
+    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
     base_plugin = get_mm_plugin(name="base", image_token="<image>")
-    check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
+    check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)
 def test_llava_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
-    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
     image_seqlen = 576
-    check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
+    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
+    check_inputs = {"plugin": llava_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
 def test_llava_next_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
-    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
-    check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 1176
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
+    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
+    check_inputs = {"plugin": llava_next_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
 def test_llava_next_video_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
-    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
-    check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 1176
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
+    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
+    check_inputs = {"plugin": llava_next_video_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 def test_paligemma_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
-    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
     image_seqlen = 256
-    check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
+    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
+    check_inputs = {"plugin": paligemma_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
     ]
-    check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
+    check_inputs["expected_input_ids"] = [
+        tokenizer_module["tokenizer"].convert_tokens_to_ids(paligemma_plugin.image_token)
+    ] * image_seqlen + INPUT_IDS
     check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
     check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
     _check_plugin(**check_inputs)
 def test_pixtral_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
-    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
     image_slice_height, image_slice_width = 2, 2
-    check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
+    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
+    check_inputs = {"plugin": pixtral_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace(
...
...
@@ -199,17 +201,17 @@ def test_pixtral_plugin():
         }
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     check_inputs["expected_mm_inputs"].pop("image_sizes")
     check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
     _check_plugin(**check_inputs)
 def test_qwen2_vl_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
-    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
     image_seqlen = 4
-    check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
+    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
+    check_inputs = {"plugin": qwen2_vl_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
...
...
@@ -217,18 +219,18 @@ def test_qwen2_vl_plugin():
         }
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
 def test_video_llava_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
-    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
-    check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 256
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
+    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
+    check_inputs = {"plugin": video_llava_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)