OpenDAS / LLaMA-Factory · Commit 8293100a

update to 0.9.2.dev0

Authored Jan 16, 2025 by luopl
Parent: 2778a3d0

Changes: 124 files in the full commit; showing 20 changed files with 949 additions and 219 deletions (+949 / -219).
Files shown:

- src/llamafactory/train/dpo/workflow.py (+5 / -11)
- src/llamafactory/train/kto/trainer.py (+18 / -10)
- src/llamafactory/train/kto/workflow.py (+1 / -0)
- src/llamafactory/train/ppo/workflow.py (+1 / -1)
- src/llamafactory/train/pt/trainer.py (+24 / -15)
- src/llamafactory/train/rm/trainer.py (+15 / -7)
- src/llamafactory/train/rm/workflow.py (+3 / -1)
- src/llamafactory/train/sft/trainer.py (+47 / -48)
- src/llamafactory/train/sft/workflow.py (+12 / -15)
- src/llamafactory/train/trainer_utils.py (+186 / -18)
- src/llamafactory/train/tuner.py (+36 / -4)
- src/llamafactory/webui/chatter.py (+2 / -1)
- src/llamafactory/webui/components/export.py (+1 / -0)
- src/llamafactory/webui/components/train.py (+46 / -2)
- src/llamafactory/webui/interface.py (+3 / -1)
- src/llamafactory/webui/locales.py (+228 / -6)
- src/llamafactory/webui/runner.py (+24 / -3)
- tests/data/test_collator.py (+97 / -1)
- tests/data/test_formatter.py (+161 / -38)
- tests/data/test_mm_plugin.py (+39 / -37)
src/llamafactory/train/dpo/workflow.py

```diff
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import cal_effective_tokens
+from ...extras.misc import calculate_tps
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
@@ -48,6 +48,7 @@ def run_dpo(
     data_collator = PairwiseDataCollatorWithPadding(
         template=template,
+        model=model,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         **tokenizer_module,
@@ -65,12 +66,6 @@ def run_dpo(
     # Update arguments
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

-    effective_token_num = 0.0
-    if finetuning_args.include_effective_tokens_per_second:
-        for data in dataset_module["train_dataset"]:
-            effective_token_num += len(data["chosen_input_ids"])
-            effective_token_num += len(data["rejected_input_ids"])
-
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         model=model,
@@ -86,13 +81,12 @@ def run_dpo(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
         if finetuning_args.include_effective_tokens_per_second:
-            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
-                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="rm"
             )

-        trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
```
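The throughput change above replaces `cal_effective_tokens` (which required pre-counting tokens in the workflow) with `calculate_tps`, which takes the dataset and the metrics dict directly. A minimal sketch of what such a helper plausibly computes — the real implementation lives in `...extras.misc` and may differ:

```python
from typing import Any, Dict, Sequence


def calculate_tps_sketch(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: str) -> float:
    """Hypothetical re-implementation: effective tokens per second of training."""
    if stage == "rm":  # pairwise data: count both the chosen and rejected sequences
        token_num = sum(len(d["chosen_input_ids"]) + len(d["rejected_input_ids"]) for d in dataset)
    else:  # single-sequence stages such as "sft"
        token_num = sum(len(d["input_ids"]) for d in dataset)

    return token_num * metrics["epoch"] / metrics["train_runtime"]
```

`run_dpo` passes `stage="rm"` here, presumably because DPO shares the pairwise (chosen/rejected) data layout with reward modeling.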
src/llamafactory/train/kto/trainer.py

```diff
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 import torch
 from transformers import Trainer
@@ -28,9 +28,9 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach

 if TYPE_CHECKING:
@@ -50,6 +50,9 @@ class CustomKTOTrainer(KTOTrainer):
         disable_dropout: bool = True,
         **kwargs,
     ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         if disable_dropout:
             disable_dropout_in_model(model)
             if ref_model is not None:
@@ -77,6 +80,7 @@ class CustomKTOTrainer(KTOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx

         Trainer.__init__(self, model=model, **kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")
@@ -119,6 +123,9 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
         """
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
         return Trainer._get_train_sampler(self)

     @override
@@ -135,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Runs forward pass and computes the log probabilities.
         """
-        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
             "attention_mask": batch[f"{prefix}attention_mask"],
@@ -245,17 +252,18 @@ class CustomKTOTrainer(KTOTrainer):
         return losses, metrics

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss
```
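The constructor change above is the recurring compatibility shim in this commit: transformers 4.46 renamed the Trainer's `tokenizer` argument to `processing_class`, so the kwarg is re-keyed before initialization. A standalone sketch of the pattern, with a simple version check standing in for `is_transformers_version_greater_than`:

```python
def remap_tokenizer_kwarg(kwargs: dict, transformers_version: str) -> dict:
    """Re-key `tokenizer` to `processing_class` for transformers >= 4.46."""
    major, minor = (int(part) for part in transformers_version.split(".")[:2])
    if (major, minor) >= (4, 46) and "tokenizer" in kwargs:
        kwargs["processing_class"] = kwargs.pop("tokenizer")

    return kwargs


# usage: the old keyword is transparently renamed on newer transformers versions
kwargs = remap_tokenizer_kwarg({"tokenizer": "tok", "model": "m"}, "4.47.1")
assert kwargs == {"processing_class": "tok", "model": "m"}
```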
src/llamafactory/train/kto/workflow.py

```diff
@@ -47,6 +47,7 @@ def run_kto(
     data_collator = KTODataCollatorWithPadding(
         template=template,
+        model=model,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         **tokenizer_module,
```
src/llamafactory/train/ppo/workflow.py

```diff
@@ -46,7 +46,7 @@ def run_ppo(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
-    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
+    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, model=model, **tokenizer_module)

     # Create reference model and reward model
     ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
```
src/llamafactory/train/pt/trainer.py

```diff
@@ -13,19 +13,19 @@
 # limitations under the License.

 from types import MethodType
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

+import torch
 from transformers import Trainer
 from typing_extensions import override

-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ...extras.packages import is_transformers_version_greater_than
+from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
-    import torch
-    from transformers import PreTrainedModel, ProcessorMixin
+    from transformers import ProcessorMixin

     from ...hparams import FinetuningArguments
@@ -38,15 +38,15 @@ class CustomTrainer(Trainer):
     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
     ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args

         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))

-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
@@ -67,17 +67,26 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
+            # other model should not scale the loss
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss
```
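The `compute_loss` override above compensates for transformers 4.46, which can return a loss summed over the gradient-accumulation window for models that do not accept loss kwargs; dividing restores the per-step value. A small sketch of just the rescaling logic (the names are illustrative, not the trainer's API):

```python
import torch


def rescale_loss(loss, return_outputs: bool, gradient_accumulation_steps: int):
    """Undo the extra scaling: divide the loss (or the first element of a
    (loss, *outputs) tuple) by the number of accumulation steps."""
    if return_outputs:
        return (loss[0] / gradient_accumulation_steps, *loss[1:])

    return loss / gradient_accumulation_steps


# a loss summed over 4 accumulation steps reads back as the per-step value
assert rescale_loss(torch.tensor(8.0), False, 4).item() == 2.0
```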
src/llamafactory/train/rm/trainer.py

```diff
@@ -25,8 +25,8 @@ from transformers import Trainer
 from typing_extensions import override

 from ...extras import logging
-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -48,7 +48,11 @@ class PairwiseTrainer(Trainer):
     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
     ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         super().__init__(**kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
         self.add_callback(FixValueHeadModelCallback)
@@ -56,9 +60,6 @@ class PairwiseTrainer(Trainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))

-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
@@ -78,6 +79,13 @@ class PairwiseTrainer(Trainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
     @override
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
@@ -100,8 +108,8 @@ class PairwiseTrainer(Trainer):
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
+            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0-4.46.1

         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
```
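For context, the pairwise loss touched in the last hunk is the Bradley-Terry objective, `-logsigmoid(chosen_score - rejected_score)` averaged over the batch. A quick numeric check with made-up scores:

```python
import torch

chosen_scores = torch.tensor([2.0, 1.0])
rejected_scores = torch.tensor([0.0, 3.0])

# the correctly ranked pair (margin +2) contributes ~0.127, the mis-ranked
# pair (margin -2) contributes ~2.127; the batch mean is ~1.127
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(round(loss.item(), 3))  # 1.127
```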
src/llamafactory/train/rm/workflow.py

```diff
@@ -44,7 +44,9 @@ def run_rm(
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)
+    data_collator = PairwiseDataCollatorWithPadding(
+        template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
+    )

     # Update arguments
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
```
src/llamafactory/train/sft/trainer.py

```diff
@@ -27,14 +27,14 @@ from typing_extensions import override
 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ...extras.packages import is_transformers_version_greater_than
+from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput

     from ...hparams import FinetuningArguments
@@ -51,15 +51,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
     ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+        else:
+            self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
+
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args

         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))

-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
@@ -80,18 +82,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
+            # other model should not scale the loss
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss
@@ -102,41 +113,30 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
     ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Removes the prompt part in the generated tokens.

         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"] if "labels" in inputs else None
-        if self.args.predict_with_generate:
-            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
-            labels = labels.detach().clone() if labels is not None else None  # backup labels
-            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-            if prompt_len > label_len:
-                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
-            if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
-                inputs["labels"] = inputs["labels"][:, :prompt_len]
+        if self.args.predict_with_generate:  # do not pass labels to model when generate
+            labels = inputs.pop("labels", None)
+        else:
+            labels = inputs.get("labels")

-        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
-            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        loss, generated_tokens, _ = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
         )
         if generated_tokens is not None and self.args.predict_with_generate:
-            generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
+            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
             generated_tokens = generated_tokens.contiguous()

         return loss, generated_tokens, labels

-    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Pads the tensor to the same length as the target tensor.
-        """
-        assert self.tokenizer.pad_token_id is not None, "Pad token is required."
-        padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
-        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
-        return padded_tensor.contiguous()  # in contiguous memory
-
-    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
+    def save_predictions(
+        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
+    ) -> None:
         r"""
         Saves model predictions to `output_dir`.
@@ -149,24 +149,23 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
         labels = np.where(
-            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
+            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
         )
         preds = np.where(
-            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
+            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.processing_class.pad_token_id,
         )

         for i in range(len(preds)):
-            pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
+            pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
             if len(pad_len):  # move pad token to last
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

-        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
-        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
+        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
+        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

-        with open(output_prediction_file, "w", encoding="utf-8") as writer:
-            res: List[str] = []
-            for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
-                res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
-
-            writer.write("\n".join(res))
+        with open(output_prediction_file, "w", encoding="utf-8") as f:
+            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
+                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
```
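`save_predictions` now streams one JSON object per line instead of joining a list, and gains a `skip_special_tokens` flag. Reading the output back is straightforward; the file name shown is the one LLaMA-Factory conventionally writes and is an assumption here:

```python
import json

# each line is {"prompt": ..., "predict": ..., "label": ...}
with open("generated_predictions.jsonl", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

print(records[0]["prompt"], "->", records[0]["predict"])
```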
src/llamafactory/train/sft/workflow.py

```diff
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import cal_effective_tokens, get_logits_processor
+from ...extras.logging import get_logger
+from ...extras.misc import calculate_tps, get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
@@ -33,6 +34,9 @@ if TYPE_CHECKING:
     from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


+logger = get_logger(__name__)
+
+
 def run_sft(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -52,6 +56,7 @@ def run_sft(
     data_collator = SFTDataCollatorWith4DAttentionMask(
         template=template,
+        model=model if not training_args.predict_with_generate else None,
         pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         block_diag_attn=model_args.block_diag_attn,
@@ -65,11 +70,6 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
     training_args.remove_unused_columns = False  # important for multimodal dataset

-    effective_token_num = 0.0
-    if finetuning_args.include_effective_tokens_per_second:
-        for data in dataset_module["train_dataset"]:
-            effective_token_num += len(data["input_ids"])
-
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
@@ -91,7 +91,7 @@ def run_sft(
     )

     # Keyword arguments for `model.generate`
-    gen_kwargs = generating_args.to_dict()
+    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
     gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()
@@ -99,12 +99,12 @@ def run_sft(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
         if finetuning_args.include_effective_tokens_per_second:
-            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
-                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="sft"
             )

-        trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
@@ -117,19 +117,16 @@ def run_sft(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
-        if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
-            metrics.pop("eval_loss", None)
-
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

     # Predict
     if training_args.do_predict:
+        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
-        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
-            predict_results.metrics.pop("predict_loss", None)
-
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)

     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
```
src/llamafactory/train/trainer_utils.py

```diff
@@ -17,7 +17,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import Trainer
@@ -30,20 +32,29 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available
+from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params


 if is_galore_available():
-    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
+
+
+if is_apollo_available():
+    from apollo_torch import APOLLOAdamW  # type: ignore
+
+
+if is_ray_available():
+    from ray.train import RunConfig, ScalingConfig
+    from ray.train.torch import TorchTrainer


 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
+    from transformers import PreTrainedModel, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, RayArguments, TrainingArguments


 logger = logging.get_logger(__name__)
@@ -51,7 +62,7 @@ logger = logging.get_logger(__name__)
 class DummyOptimizer(torch.optim.Optimizer):
     r"""
-    A dummy optimizer used for the GaLore algorithm.
+    A dummy optimizer used for the GaLore or APOLLO algorithm.
     """

     def __init__(
@@ -74,7 +85,7 @@ def create_modelcard_and_push(
     trainer: "Trainer",
     model_args: "ModelArguments",
     data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> None:
     kwargs = {
@@ -187,7 +198,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 def _create_galore_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
@@ -231,9 +242,10 @@ def _create_galore_optimizer(
     elif training_args.optim == "adafactor":
         optim_class = GaLoreAdafactor
     else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")

     if finetuning_args.galore_layerwise:
+        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer GaLore does not support gradient accumulation.")
@@ -265,13 +277,100 @@ def _create_galore_optimizer(
         ]
         optimizer = optim_class(param_groups, **optim_kwargs)

-    logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+    logger.info_rank0(
+        f"Using GaLore optimizer with args: {galore_kwargs}. "
+        "It may cause hanging at the start of training, wait patiently."
+    )
     return optimizer


+def _create_apollo_optimizer(
+    model: "PreTrainedModel",
+    training_args: "TrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
+        apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
+    else:
+        apollo_targets = finetuning_args.apollo_target
+
+    apollo_params: List["torch.nn.Parameter"] = []
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
+            for param in module.parameters():
+                if param.requires_grad and len(param.shape) > 1:
+                    apollo_params.append(param)
+
+    apollo_kwargs = {
+        "rank": finetuning_args.apollo_rank,
+        "proj": finetuning_args.apollo_proj,
+        "proj_type": finetuning_args.apollo_proj_type,
+        "update_proj_gap": finetuning_args.apollo_update_interval,
+        "scale": finetuning_args.apollo_scale,
+        "scale_type": finetuning_args.apollo_scale_type,
+        "scale_front": finetuning_args.apollo_scale_front,
+    }
+    id_apollo_params = {id(param) for param in apollo_params}
+    decay_params, nodecay_params = [], []  # they are non-apollo parameters
+    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
+    decay_param_names = _get_decay_parameter_names(model)
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            trainable_params.append(param)
+            if id(param) not in id_apollo_params:
+                if name in decay_param_names:
+                    decay_params.append(param)
+                else:
+                    nodecay_params.append(param)
+
+    _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+
+    if training_args.optim == "adamw_torch":
+        optim_class = APOLLOAdamW
+    else:
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
+
+    if finetuning_args.apollo_layerwise:
+        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
+        if training_args.gradient_accumulation_steps != 1:
+            raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
+
+        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        for param in nodecay_params:
+            param_groups = [dict(params=[param], weight_decay=0.0)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        for param in decay_params:
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        for param in apollo_params:  # apollo params have weight decay
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        def optimizer_hook(param: "torch.nn.Parameter"):
+            if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
+
+        for param in trainable_params:
+            param.register_post_accumulate_grad_hook(optimizer_hook)
+
+        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
+    else:
+        param_groups = [
+            dict(params=nodecay_params, weight_decay=0.0),
+            dict(params=decay_params, weight_decay=training_args.weight_decay),
+            dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
+        ]
+        optimizer = optim_class(param_groups, **optim_kwargs)
+
+    logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
+    return optimizer
+
+
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     default_lr = training_args.learning_rate
@@ -311,7 +410,7 @@ def _create_loraplus_optimizer(
 def _create_badam_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     decay_params, nodecay_params = [], []
@@ -330,7 +429,7 @@ def _create_badam_optimizer(
     ]

     if finetuning_args.badam_mode == "layer":
-        from badam import BlockOptimizer
+        from badam import BlockOptimizer  # type: ignore

         base_optimizer = optim_class(param_groups, **optim_kwargs)
         optimizer = BlockOptimizer(
@@ -350,7 +449,7 @@ def _create_badam_optimizer(
     elif finetuning_args.badam_mode == "ratio":
-        from badam import BlockOptimizerRatio
+        from badam import BlockOptimizerRatio  # type: ignore

         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
@@ -372,9 +471,9 @@ def _create_badam_optimizer(
 def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
 ) -> "torch.optim.Optimizer":
-    from adam_mini import Adam_mini
+    from adam_mini import Adam_mini  # type: ignore

     hidden_size = getattr(model.config, "hidden_size", None)
     num_q_head = getattr(model.config, "num_attention_heads", None)
@@ -397,12 +496,15 @@ def _create_adam_mini_optimizer(
 def create_custom_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
         return _create_galore_optimizer(model, training_args, finetuning_args)

+    if finetuning_args.use_apollo:
+        return _create_apollo_optimizer(model, training_args, finetuning_args)
+
     if finetuning_args.loraplus_lr_ratio is not None:
         return _create_loraplus_optimizer(model, training_args, finetuning_args)
@@ -414,7 +516,7 @@ def create_custom_optimizer(
 def create_custom_scheduler(
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     num_training_steps: int,
     optimizer: Optional["torch.optim.Optimizer"] = None,
 ) -> None:
@@ -457,3 +559,69 @@ def get_batch_logps(
     labels[labels == label_pad_token_id] = 0  # dummy token
     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
     return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+
+
+def nested_detach(
+    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    clone: bool = False,
+):
+    r"""
+    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
+    """
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
+
+    if isinstance(tensors, torch.Tensor):
+        if clone:
+            return tensors.detach().clone()
+        else:
+            return tensors.detach()
+    else:
+        return tensors
+
+
+def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
+    r"""
+    Gets the callback for logging to SwanLab.
+    """
+    import swanlab  # type: ignore
+    from swanlab.integration.transformers import SwanLabCallback  # type: ignore
+
+    if finetuning_args.swanlab_api_key is not None:
+        swanlab.login(api_key=finetuning_args.swanlab_api_key)
+
+    swanlab_callback = SwanLabCallback(
+        project=finetuning_args.swanlab_project,
+        workspace=finetuning_args.swanlab_workspace,
+        experiment_name=finetuning_args.swanlab_run_name,
+        mode=finetuning_args.swanlab_mode,
+        config={"Framework": "🦙LlamaFactory"},
+    )
+    return swanlab_callback
+
+
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: "RayArguments",
+) -> "TorchTrainer":
+    if not ray_args.use_ray:
+        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
+
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.ray_num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            placement_strategy=ray_args.placement_strategy,
+            use_gpu=True,
+        ),
+        run_config=RunConfig(
+            name=ray_args.ray_run_name,
+            storage_path=Path("./saves").absolute().as_posix(),
+        ),
+    )
+    return trainer
```
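The layerwise branch of the new APOLLO optimizer (like the existing GaLore one) gives every parameter its own single-tensor optimizer and steps it from a `register_post_accumulate_grad_hook`, so updates happen during backward and the Trainer only ever sees a `DummyOptimizer`. A minimal self-contained sketch of that pattern, with plain AdamW standing in for `APOLLOAdamW`:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer_dict = {param: torch.optim.AdamW([param], lr=1e-3) for param in model.parameters()}


def optimizer_hook(param: torch.nn.Parameter) -> None:
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()


for param in model.parameters():
    param.register_post_accumulate_grad_hook(optimizer_hook)

loss = model(torch.randn(8, 4)).sum()
loss.backward()  # each parameter is stepped by its hook; no explicit optimizer.step()
```

Because the hook fires on every backward pass, this design is incompatible with gradient accumulation, which is why the layerwise path raises a `ValueError` when `gradient_accumulation_steps != 1`.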
src/llamafactory/train/tuner.py
View file @
8293100a
...
@@ -22,15 +22,21 @@ from transformers import PreTrainedModel
...
@@ -22,15 +22,21 @@ from transformers import PreTrainedModel
from
..data
import
get_template_and_fix_tokenizer
from
..data
import
get_template_and_fix_tokenizer
from
..extras
import
logging
from
..extras
import
logging
from
..extras.constants
import
V_HEAD_SAFE_WEIGHTS_NAME
,
V_HEAD_WEIGHTS_NAME
from
..extras.constants
import
V_HEAD_SAFE_WEIGHTS_NAME
,
V_HEAD_WEIGHTS_NAME
from
..hparams
import
get_infer_args
,
get_train_args
from
..extras.packages
import
is_ray_available
from
..hparams
import
get_infer_args
,
get_ray_args
,
get_train_args
,
read_args
from
..model
import
load_model
,
load_tokenizer
from
..model
import
load_model
,
load_tokenizer
from
.callbacks
import
LogCallback
from
.callbacks
import
LogCallback
,
PissaConvertCallback
,
ReporterCallback
from
.dpo
import
run_dpo
from
.dpo
import
run_dpo
from
.kto
import
run_kto
from
.kto
import
run_kto
from
.ppo
import
run_ppo
from
.ppo
import
run_ppo
from
.pt
import
run_pt
from
.pt
import
run_pt
from
.rm
import
run_rm
from
.rm
import
run_rm
from
.sft
import
run_sft
from
.sft
import
run_sft
from
.trainer_utils
import
get_ray_trainer
,
get_swanlab_callback
if
is_ray_available
():
from
ray.train.huggingface.transformers
import
RayTrainReportCallback
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -40,10 +46,20 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-    callbacks.append(LogCallback())
+def _training_function(config: Dict[str, Any]) -> None:
+    args = config.get("args")
+    callbacks: List[Any] = config.get("callbacks")
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

+    callbacks.append(LogCallback())
+    if finetuning_args.pissa_convert:
+        callbacks.append(PissaConvertCallback())
+
+    if finetuning_args.use_swanlab:
+        callbacks.append(get_swanlab_callback(finetuning_args))
+
+    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
+
     if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
...
@@ -60,6 +76,22 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")


+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
+    args = read_args(args)
+    ray_args = get_ray_args(args)
+    callbacks = callbacks or []
+    if ray_args.use_ray:
+        callbacks.append(RayTrainReportCallback())
+        trainer = get_ray_trainer(
+            training_function=_training_function,
+            train_loop_config={"args": args, "callbacks": callbacks},
+            ray_args=ray_args,
+        )
+        trainer.fit()
+    else:
+        _training_function(config={"args": args, "callbacks": callbacks})
+
+
 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     model_args, data_args, finetuning_args, _ = get_infer_args(args)
...
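A hypothetical invocation of the reworked entry point; the argument keys mirror LLaMA-Factory's training YAML and are illustrative rather than exhaustive (a registered dataset is required for this to actually train):

    from llamafactory.train.tuner import run_exp

    run_exp(
        args={
            "stage": "sft",
            "do_train": True,
            "model_name_or_path": "llamafactory/tiny-random-Llama-3",
            "dataset": "identity",
            "template": "default",
            "finetuning_type": "lora",
            "output_dir": "saves/demo",
        }
    )
    # with USE_RAY=1 (and ray installed), the same call routes through
    # get_ray_trainer() and trainer.fit() instead of running in-process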
src/llamafactory/webui/chatter.py
View file @
8293100a
...
@@ -91,6 +91,7 @@ class WebChatModel(ChatModel):
     rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
     infer_backend=get("infer.infer_backend"),
     infer_dtype=get("infer.infer_dtype"),
+    trust_remote_code=True,
 )

 if checkpoint_path:
...
@@ -157,7 +158,7 @@ class WebChatModel(ChatModel):
 result = response
 if isinstance(result, list):
-    tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
+    tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
     tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
     output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
     bot_text = "```json\n" + tool_calls + "\n```"
...
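The second hunk switches from tuple indexing (tool[0], tool[1]) to attribute access, implying the chat engine now yields named function-call objects. A minimal stand-in to show the shape; the dataclass below is illustrative, not the real type:

    import json
    from dataclasses import dataclass

    @dataclass
    class FunctionCall:
        name: str
        arguments: str  # JSON-encoded argument dict

    result = [FunctionCall(name="get_weather", arguments='{"city": "Paris"}')]
    tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
    print(json.dumps(tool_calls, indent=4, ensure_ascii=False))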
src/llamafactory/webui/components/export.py
View file @
8293100a
...
@@ -84,6 +84,7 @@ def save_model(
     export_quantization_dataset=export_quantization_dataset,
     export_device=export_device,
     export_legacy_format=export_legacy_format,
+    trust_remote_code=True,
 )

 if checkpoint_path:
...
src/llamafactory/webui/components/train.py
View file @
8293100a
...
@@ -234,8 +234,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Row():
             use_galore = gr.Checkbox()
             galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
-            galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
-            galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
+            galore_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
+            galore_scale = gr.Slider(minimum=0, maximum=100, value=2.0, step=0.1)
             galore_target = gr.Textbox(value="all")

         input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
...
@@ -250,6 +250,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             )
         )

+    with gr.Accordion(open=False) as apollo_tab:
+        with gr.Row():
+            use_apollo = gr.Checkbox()
+            apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
+            apollo_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
+            apollo_scale = gr.Slider(minimum=0, maximum=100, value=32.0, step=0.1)
+            apollo_target = gr.Textbox(value="all")
+
+        input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
+        elem_dict.update(
+            dict(
+                apollo_tab=apollo_tab,
+                use_apollo=use_apollo,
+                apollo_rank=apollo_rank,
+                apollo_update_interval=apollo_update_interval,
+                apollo_scale=apollo_scale,
+                apollo_target=apollo_target,
+            )
+        )
+
     with gr.Accordion(open=False) as badam_tab:
         with gr.Row():
             use_badam = gr.Checkbox()
...
@@ -270,6 +290,30 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             )
         )

+    with gr.Accordion(open=False) as swanlab_tab:
+        with gr.Row():
+            use_swanlab = gr.Checkbox()
+            swanlab_project = gr.Textbox(value="llamafactory")
+            swanlab_run_name = gr.Textbox()
+            swanlab_workspace = gr.Textbox()
+            swanlab_api_key = gr.Textbox()
+            swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
+
+        input_elems.update(
+            {use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
+        )
+        elem_dict.update(
+            dict(
+                swanlab_tab=swanlab_tab,
+                use_swanlab=use_swanlab,
+                swanlab_project=swanlab_project,
+                swanlab_run_name=swanlab_run_name,
+                swanlab_workspace=swanlab_workspace,
+                swanlab_api_key=swanlab_api_key,
+                swanlab_mode=swanlab_mode,
+            )
+        )
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         arg_save_btn = gr.Button()
...
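A runnable toy of the accordion pattern used above, assuming only gradio; labels are hardcoded here, whereas the real UI injects them from LOCALES at render time:

    import gradio as gr

    with gr.Blocks() as demo:
        with gr.Accordion("SwanLab configurations", open=False):
            with gr.Row():
                use_swanlab = gr.Checkbox(label="Use SwanLab")
                swanlab_project = gr.Textbox(value="llamafactory", label="SwanLab project")
                swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud", label="SwanLab mode")

    demo.launch()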
src/llamafactory/webui/interface.py
View file @
8293100a
...
@@ -13,6 +13,7 @@
 # limitations under the License.

 import os
+import platform

 from ..extras.packages import is_gradio_available
 from .common import save_config
...
@@ -34,8 +35,9 @@ if is_gradio_available():
 def create_ui(demo_mode: bool = False) -> "gr.Blocks":
     engine = Engine(demo_mode=demo_mode, pure_chat=False)
+    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

-    with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
         if demo_mode:
             gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
             gr.HTML(
...
src/llamafactory/webui/locales.py
View file @
8293100a
...
@@ -30,15 +30,19 @@ LOCALES = {
     "model_name": {
         "en": {
             "label": "Model name",
+            "info": "Input the name prefix to search for the model.",
         },
         "ru": {
             "label": "Название модели",
+            "info": "Введите префикс имени для поиска модели.",
         },
         "zh": {
             "label": "模型名称",
+            "info": "输入首单词以检索模型。",
         },
         "ko": {
             "label": "모델 이름",
+            "info": "모델을 검색하기 위해 이름 접두어를 입력하세요.",
         },
     },
     "model_path": {
...
@@ -464,7 +468,7 @@ LOCALES = {
     "val_size": {
         "en": {
             "label": "Val size",
-            "info": "Proportion of data in the dev set.",
+            "info": "Percentage of validation set from the entire dataset.",
         },
         "ru": {
             "label": "Размер валидации",
...
@@ -1115,7 +1119,7 @@ LOCALES = {
             "info": "Нормализация оценок в тренировке PPO.",
         },
         "zh": {
-            "label": "奖励模型",
+            "label": "归一化分数",
             "info": "PPO 训练中归一化奖励分数。",
         },
         "ko": {
...
@@ -1158,19 +1162,19 @@ LOCALES = {
     "use_galore": {
         "en": {
             "label": "Use GaLore",
-            "info": "Enable gradient low-Rank projection.",
+            "info": "Use GaLore optimizer.",
         },
         "ru": {
             "label": "Использовать GaLore",
-            "info": "Включить проекцию градиента на низкоранговое пространство.",
+            "info": "Используйте оптимизатор GaLore.",
         },
         "zh": {
             "label": "使用 GaLore",
-            "info": "使用梯度低秩投影。",
+            "info": "使用 GaLore 优化器。",
         },
         "ko": {
             "label": "GaLore 사용",
-            "info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
+            "info": "GaLore 최적화를 사용하세요.",
         },
     },
     "galore_rank": {
...
@@ -1245,6 +1249,110 @@ LOCALES = {
             "info": "GaLore를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
         },
     },
+    "apollo_tab": {
+        "en": {"label": "APOLLO configurations"},
+        "ru": {"label": "Конфигурации APOLLO"},
+        "zh": {"label": "APOLLO 参数设置"},
+        "ko": {"label": "APOLLO 구성"},
+    },
+    "use_apollo": {
+        "en": {"label": "Use APOLLO", "info": "Use APOLLO optimizer."},
+        "ru": {"label": "Использовать APOLLO", "info": "Используйте оптимизатор APOLLO."},
+        "zh": {"label": "使用 APOLLO", "info": "使用 APOLLO 优化器。"},
+        "ko": {"label": "APOLLO 사용", "info": "APOLLO 최적화를 사용하세요."},
+    },
+    "apollo_rank": {
+        "en": {"label": "APOLLO rank", "info": "The rank of APOLLO gradients."},
+        "ru": {"label": "Ранг APOLLO", "info": "Ранг градиентов APOLLO."},
+        "zh": {"label": "APOLLO 秩", "info": "APOLLO 梯度的秩大小。"},
+        "ko": {"label": "APOLLO 랭크", "info": "APOLLO 그레디언트의 랭크."},
+    },
+    "apollo_update_interval": {
+        "en": {"label": "Update interval", "info": "Number of steps to update the APOLLO projection."},
+        "ru": {"label": "Интервал обновления", "info": "Количество шагов для обновления проекции APOLLO."},
+        "zh": {"label": "更新间隔", "info": "相邻两次投影更新的步数。"},
+        "ko": {"label": "업데이트 간격", "info": "APOLLO 프로젝션을 업데이트할 간격의 스텝 수."},
+    },
+    "apollo_scale": {
+        "en": {"label": "APOLLO scale", "info": "APOLLO scaling coefficient."},
+        "ru": {"label": "Масштаб APOLLO", "info": "Коэффициент масштабирования APOLLO."},
+        "zh": {"label": "APOLLO 缩放系数", "info": "APOLLO 缩放系数大小。"},
+        "ko": {"label": "APOLLO 스케일", "info": "APOLLO 스케일링 계수."},
+    },
+    "apollo_target": {
+        "en": {"label": "APOLLO modules", "info": "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules."},
+        "ru": {"label": "Модули APOLLO", "info": "Имена модулей для применения APOLLO. Используйте запятые для разделения нескольких модулей."},
+        "zh": {"label": "APOLLO 作用模块", "info": "应用 APOLLO 的模块名称。使用英文逗号分隔多个名称。"},
+        "ko": {"label": "APOLLO 모듈", "info": "APOLLO를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오."},
+    },
     "badam_tab": {
         "en": {
             "label": "BAdam configurations",
...
@@ -1349,6 +1457,120 @@ LOCALES = {
             "info": "비율-BAdam의 업데이트 비율.",
         },
     },
+    "swanlab_tab": {
+        "en": {"label": "SwanLab configurations"},
+        "ru": {"label": "Конфигурации SwanLab"},
+        "zh": {"label": "SwanLab 参数设置"},
+        "ko": {"label": "SwanLab 설정"},
+    },
+    "use_swanlab": {
+        "en": {"label": "Use SwanLab", "info": "Enable SwanLab for experiment tracking and visualization."},
+        "ru": {"label": "Использовать SwanLab", "info": "Включить SwanLab для отслеживания и визуализации экспериментов."},
+        "zh": {"label": "使用 SwanLab", "info": "启用 SwanLab 进行实验跟踪和可视化。"},
+        "ko": {"label": "SwanLab 사용", "info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다."},
+    },
+    "swanlab_project": {
+        "en": {"label": "SwanLab project"},
+        "ru": {"label": "SwanLab Проект"},
+        "zh": {"label": "SwanLab 项目名"},
+        "ko": {"label": "SwanLab 프로젝트"},
+    },
+    "swanlab_run_name": {
+        "en": {"label": "SwanLab experiment name (optional)"},
+        "ru": {"label": "SwanLab Имя эксперимента (опционально)"},
+        "zh": {"label": "SwanLab 实验名(非必填)"},
+        "ko": {"label": "SwanLab 실험 이름 (선택 사항)"},
+    },
+    "swanlab_workspace": {
+        "en": {"label": "SwanLab workspace (optional)", "info": "Workspace for SwanLab. Defaults to the personal workspace."},
+        "ru": {"label": "SwanLab Рабочая область (опционально)", "info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области."},
+        "zh": {"label": "SwanLab 工作区(非必填)", "info": "SwanLab 的工作区,默认在个人工作区下。"},
+        "ko": {"label": "SwanLab 작업 영역 (선택 사항)", "info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다."},
+    },
+    "swanlab_api_key": {
+        "en": {"label": "SwanLab API key (optional)", "info": "API key for SwanLab."},
+        "ru": {"label": "SwanLab API ключ (опционально)", "info": "API ключ для SwanLab."},
+        "zh": {"label": "SwanLab API密钥(非必填)", "info": "用于在编程环境登录 SwanLab,已登录则无需填写。"},
+        "ko": {"label": "SwanLab API 키 (선택 사항)", "info": "SwanLab의 API 키."},
+    },
+    "swanlab_mode": {
+        "en": {"label": "SwanLab mode", "info": "Cloud or offline version."},
+        "ru": {"label": "SwanLab Режим", "info": "Версия в облаке или локальная версия."},
+        "zh": {"label": "SwanLab 模式", "info": "使用云端版或离线版 SwanLab。"},
+        "ko": {"label": "SwanLab 모드", "info": "클라우드 버전 또는 오프라인 버전."},
+    },
     "cmd_preview_btn": {
         "en": {
             "value": "Preview command",
...
src/llamafactory/webui/runner.py
View file @
8293100a
...
@@ -19,9 +19,10 @@ from subprocess import Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Dict, Generator, Optional

 from transformers.trainer import TRAINING_ARGS_NAME
+from transformers.utils import is_torch_npu_available

 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_gpu_or_npu_available, torch_gc
+from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
 from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
...
@@ -146,12 +147,15 @@ class Runner:
             shift_attn=get("train.shift_attn"),
             report_to="all" if get("train.report_to") else "none",
             use_galore=get("train.use_galore"),
+            use_apollo=get("train.use_apollo"),
             use_badam=get("train.use_badam"),
+            use_swanlab=get("train.use_swanlab"),
             output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
             pure_bf16=(get("train.compute_type") == "pure_bf16"),
             plot_loss=True,
+            trust_remote_code=True,
             ddp_timeout=180000000,
             include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True,  # FIXME
         )
...
@@ -170,6 +174,7 @@ class Runner:
         if get("top.quantization_bit") in QUANTIZATION_BITS:
             args["quantization_bit"] = int(get("top.quantization_bit"))
             args["quantization_method"] = get("top.quantization_method")
+            args["double_quantization"] = not is_torch_npu_available()

         # freeze config
         if args["finetuning_type"] == "freeze":
...
@@ -220,6 +225,13 @@ class Runner:
             args["galore_scale"] = get("train.galore_scale")
             args["galore_target"] = get("train.galore_target")

+        # apollo config
+        if args["use_apollo"]:
+            args["apollo_rank"] = get("train.apollo_rank")
+            args["apollo_update_interval"] = get("train.apollo_update_interval")
+            args["apollo_scale"] = get("train.apollo_scale")
+            args["apollo_target"] = get("train.apollo_target")
+
         # badam config
         if args["use_badam"]:
             args["badam_mode"] = get("train.badam_mode")
...
@@ -227,6 +239,14 @@ class Runner:
             args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")

+        # swanlab config
+        if get("train.use_swanlab"):
+            args["swanlab_project"] = get("train.swanlab_project")
+            args["swanlab_run_name"] = get("train.swanlab_run_name")
+            args["swanlab_workspace"] = get("train.swanlab_workspace")
+            args["swanlab_api_key"] = get("train.swanlab_api_key")
+            args["swanlab_mode"] = get("train.swanlab_mode")
+
         # eval config
         if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
             args["val_size"] = get("train.val_size")
...
@@ -268,6 +288,7 @@ class Runner:
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
             output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
+            trust_remote_code=True,
         )

         if get("eval.predict"):
...
@@ -383,12 +404,12 @@ class Runner:
                 continue

             if self.do_train:
-                if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
+                if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
                     finish_info = ALERTS["info_finished"][lang]
                 else:
                     finish_info = ALERTS["err_failed"][lang]
             else:
-                if os.path.exists(os.path.join(output_path, "all_results.json")):
+                if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
                     finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
                 else:
                     finish_info = ALERTS["err_failed"][lang]
...
tests/data/test_collator.py
View file @
8293100a
...
@@ -12,9 +12,105 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+
 import torch
+from PIL import Image

-from llamafactory.data.collator import prepare_4d_attention_mask
+from llamafactory.data import get_template_and_fix_tokenizer
+from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
+from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.hparams import get_infer_args
+from llamafactory.model import load_tokenizer
+
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+
+def test_base_collator():
+    model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
+    tokenizer_module = load_tokenizer(model_args)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    data_collator = MultiModalDataCollatorForSeq2Seq(
+        template=template,
+        pad_to_multiple_of=8,
+        label_pad_token_id=IGNORE_INDEX,
+        **tokenizer_module,
+    )
+    p = tokenizer_module["tokenizer"].pad_token_id
+    q = IGNORE_INDEX
+    features = [
+        {
+            "input_ids": [0, 1, 2, 3, 4, 5],
+            "attention_mask": [1, 1, 1, 1, 1, 1],
+            "labels": [q, q, 2, 3, 4, 5],
+        },
+        {
+            "input_ids": [6, 7],
+            "attention_mask": [1, 1],
+            "labels": [q, 7],
+        },
+    ]
+    batch_input = data_collator(features)
+    expected_input = {
+        "input_ids": [
+            [0, 1, 2, 3, 4, 5, p, p],
+            [6, 7, p, p, p, p, p, p],
+        ],
+        "attention_mask": [
+            [1, 1, 1, 1, 1, 1, 0, 0],
+            [1, 1, 0, 0, 0, 0, 0, 0],
+        ],
+        "labels": [
+            [q, q, 2, 3, 4, 5, q, q],
+            [q, 7, q, q, q, q, q, q],
+        ],
+    }
+    for k in batch_input.keys():
+        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
+
+
+def test_multimodal_collator():
+    model_args, data_args, *_ = get_infer_args(
+        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
+    )
+    tokenizer_module = load_tokenizer(model_args)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    data_collator = MultiModalDataCollatorForSeq2Seq(
+        template=template,
+        pad_to_multiple_of=4,
+        label_pad_token_id=IGNORE_INDEX,
+        **tokenizer_module,
+    )
+    p = tokenizer_module["tokenizer"].pad_token_id
+    q = IGNORE_INDEX
+    s = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_start|>")
+    e = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_end|>")
+    m = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|image_pad|>")
+    fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
+    features = [
+        {
+            "input_ids": [0, 1, 2, 3],
+            "attention_mask": [1, 1, 1, 1],
+            "labels": [0, 1, 2, 3],
+        },
+    ]
+    batch_input = data_collator(features)
+    expected_input = {
+        "input_ids": [[0, 1, 2, 3, s, m, m, m, m, e, p, p]],
+        "attention_mask": [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
+        "labels": [[0, 1, 2, 3, q, q, q, q, q, q, q, q]],
+        **tokenizer_module["processor"].image_processor(fake_image),
+    }
+    for k in batch_input.keys():
+        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
+
+
 def test_4d_attention_mask():
...
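Why the expected tensors above come out 8 and 12 columns wide: pad_to_multiple_of rounds the longest sequence in the batch up to the next multiple. A quick check of that arithmetic:

    import math

    def padded_length(longest: int, multiple: int) -> int:
        return math.ceil(longest / multiple) * multiple

    assert padded_length(6, 8) == 8    # base test: longest feature has 6 tokens
    assert padded_length(10, 4) == 12  # multimodal test: 4 text + 6 vision tokens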
tests/data/test_formatter.py
View file @
8293100a
...
@@ -13,10 +13,29 @@
 # limitations under the License.

 import json
+from datetime import datetime

 from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter

+FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
+
+TOOLS = [
+    {
+        "name": "test_tool",
+        "description": "tool_desc",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "foo": {"type": "string", "description": "foo_desc"},
+                "bar": {"type": "number", "description": "bar_desc"},
+            },
+            "required": ["foo"],
+        },
+    }
+]
+

 def test_empty_formatter():
     formatter = EmptyFormatter(slots=["\n"])
     assert formatter.apply() == ["\n"]
...
@@ -28,39 +47,27 @@ def test_string_formatter():
 def test_function_formatter():
-    formatter = FunctionFormatter(slots=[], tool_format="default")
-    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
+    tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
-        """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        "</s>",
     ]


 def test_multi_function_formatter():
-    formatter = FunctionFormatter(slots=[], tool_format="default")
-    tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
+    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
+    tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
-        """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
-        """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        "</s>",
     ]


 def test_default_tool_formatter():
     formatter = ToolFormatter(tool_format="default")
-    tools = [
-        {
-            "name": "test_tool",
-            "description": "tool_desc",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "foo": {"type": "string", "description": "foo_desc"},
-                    "bar": {"type": "number", "description": "bar_desc"},
-                },
-                "required": ["foo"],
-            },
-        }
-    ]
-    assert formatter.apply(content=json.dumps(tools)) == [
+    assert formatter.apply(content=json.dumps(TOOLS)) == [
         "You have access to the following tools:\n"
         "> Tool Name: test_tool\n"
         "Tool Description: tool_desc\n"
...
@@ -94,26 +101,18 @@ def test_default_multi_tool_extractor():
     ]


+def test_glm4_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
+    tool_calls = json.dumps(FUNCTION)
+    assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
+
+
 def test_glm4_tool_formatter():
     formatter = ToolFormatter(tool_format="glm4")
-    tools = [
-        {
-            "name": "test_tool",
-            "description": "tool_desc",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "foo": {"type": "string", "description": "foo_desc"},
-                    "bar": {"type": "number", "description": "bar_desc"},
-                },
-                "required": ["foo"],
-            },
-        }
-    ]
-    assert formatter.apply(content=json.dumps(tools)) == [
+    assert formatter.apply(content=json.dumps(TOOLS)) == [
         "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
         "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
-        "## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4))
+        f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
     ]
...
@@ -121,3 +120,127 @@ def test_glm4_tool_extractor():
     formatter = ToolFormatter(tool_format="glm4")
     result = """test_tool\n{"foo": "bar", "size": 10}\n"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
+
+
+def test_llama3_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
+    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+    assert formatter.apply(content=tool_calls) == [
+        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
+        "<|eot_id|>",
+    ]
+
+
+def test_llama3_tool_formatter():
+    formatter = ToolFormatter(tool_format="llama3")
+    date = datetime.now().strftime("%d %b %Y")
+    wrapped_tool = {"type": "function", "function": TOOLS[0]}
+    assert formatter.apply(content=json.dumps(TOOLS)) == [
+        f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
+        "You have access to the following functions. To call a function, please respond with JSON for a function call. "
+        """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
+        f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
+    ]
+
+
+def test_llama3_tool_extractor():
+    formatter = ToolFormatter(tool_format="llama3")
+    result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
+    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
+
+
+def test_mistral_function_formatter():
+    formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
+    tool_calls = json.dumps(FUNCTION)
+    assert formatter.apply(content=tool_calls) == [
+        "[TOOL_CALLS] ",
+        """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
+        "</s>",
+    ]
+
+
+def test_mistral_multi_function_formatter():
+    formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
+    tool_calls = json.dumps([FUNCTION] * 2)
+    assert formatter.apply(content=tool_calls) == [
+        "[TOOL_CALLS] ",
+        """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
+        """{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
+        "</s>",
+    ]
+
+
+def test_mistral_tool_formatter():
+    formatter = ToolFormatter(tool_format="mistral")
+    wrapped_tool = {"type": "function", "function": TOOLS[0]}
+    assert formatter.apply(content=json.dumps(TOOLS)) == [
+        "[AVAILABLE_TOOLS] " + json.dumps([wrapped_tool], ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
+    ]
+
+
+def test_mistral_tool_extractor():
+    formatter = ToolFormatter(tool_format="mistral")
+    result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
+    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
+
+
+def test_mistral_multi_tool_extractor():
+    formatter = ToolFormatter(tool_format="mistral")
+    result = (
+        """[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
+        """{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
+    )
+    assert formatter.extract(result) == [
+        ("test_tool", """{"foo": "bar", "size": 10}"""),
+        ("another_tool", """{"foo": "job", "size": 2}"""),
+    ]
+
+
+def test_qwen_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
+    tool_calls = json.dumps(FUNCTION)
+    assert formatter.apply(content=tool_calls) == [
+        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
+        "<|im_end|>",
+    ]
+
+
+def test_qwen_multi_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
+    tool_calls = json.dumps([FUNCTION] * 2)
+    assert formatter.apply(content=tool_calls) == [
+        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
+        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
+        "<|im_end|>",
+    ]
+
+
+def test_qwen_tool_formatter():
+    formatter = ToolFormatter(tool_format="qwen")
+    wrapped_tool = {"type": "function", "function": TOOLS[0]}
+    assert formatter.apply(content=json.dumps(TOOLS)) == [
+        "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+        "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
+        f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}"
+        "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
+        """<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
+        """"arguments": <args-json-object>}\n</tool_call><|im_end|>\n"""
+    ]
+
+
+def test_qwen_tool_extractor():
+    formatter = ToolFormatter(tool_format="qwen")
+    result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
+    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
+
+
+def test_qwen_multi_tool_extractor():
+    formatter = ToolFormatter(tool_format="qwen")
+    result = (
+        """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
+        """<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
+    )
+    assert formatter.extract(result) == [
+        ("test_tool", """{"foo": "bar", "size": 10}"""),
+        ("another_tool", """{"foo": "job", "size": 2}"""),
+    ]
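A round-trip sketch tying the new tests together: apply() renders a function call into a template's tool syntax for training targets, and extract() parses generated text back into (name, json_arguments) pairs. This mirrors the qwen test case above:

    import json
    from llamafactory.data.formatter import ToolFormatter

    formatter = ToolFormatter(tool_format="qwen")
    generated = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
    assert formatter.extract(generated) == [("test_tool", """{"foo": "bar", "size": 10}""")]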
tests/data/test_mm_plugin.py
View file @
8293100a
...
@@ -13,14 +13,14 @@
 # limitations under the License.

 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence

 import pytest
 import torch
 from PIL import Image

 from llamafactory.data.mm_plugin import get_mm_plugin
-from llamafactory.hparams import ModelArguments
+from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
...
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor

     from llamafactory.data.mm_plugin import BasePlugin
+    from llamafactory.model.loader import TokenizerModule

 HF_TOKEN = os.getenv("HF_TOKEN")
...
@@ -82,10 +83,9 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
         assert batch_a[key] == batch_b[key]

-def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
-    model_args = ModelArguments(model_name_or_path=model_name_or_path)
-    tokenizer_module = load_tokenizer(model_args)
-    return tokenizer_module["tokenizer"], tokenizer_module["processor"]
+def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
+    model_args, *_ = get_infer_args({"model_name_or_path": model_name_or_path, "template": "default"})
+    return load_tokenizer(model_args)

 def _check_plugin(
...
@@ -121,73 +121,75 @@ def _check_plugin(
 def test_base_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
+    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
     base_plugin = get_mm_plugin(name="base", image_token="<image>")
-    check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
+    check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)


 def test_llava_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
-    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
     image_seqlen = 576
-    check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
+    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
+    check_inputs = {"plugin": llava_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)


 def test_llava_next_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
-    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
     image_seqlen = 1176
-    check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
+    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
+    check_inputs = {"plugin": llava_next_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)


 def test_llava_next_video_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
-    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
     image_seqlen = 1176
-    check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
+    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
+    check_inputs = {"plugin": llava_next_video_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)


 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 def test_paligemma_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
-    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
     image_seqlen = 256
-    check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
+    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
+    check_inputs = {"plugin": paligemma_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
     ]
-    check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
+    check_inputs["expected_input_ids"] = [
+        tokenizer_module["tokenizer"].convert_tokens_to_ids(paligemma_plugin.image_token)
+    ] * image_seqlen + INPUT_IDS
     check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
     check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
     _check_plugin(**check_inputs)


 def test_pixtral_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
-    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
     image_slice_height, image_slice_width = 2, 2
-    check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
+    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
+    check_inputs = {"plugin": pixtral_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {
             key: value.replace(
...
@@ -199,17 +201,17 @@ def test_pixtral_plugin():
         }
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     check_inputs["expected_mm_inputs"].pop("image_sizes")
     check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
     _check_plugin(**check_inputs)


 def test_qwen2_vl_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
-    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
     image_seqlen = 4
-    check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
+    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
+    check_inputs = {"plugin": qwen2_vl_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {
             key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
...
@@ -217,18 +219,18 @@ def test_qwen2_vl_plugin():
         }
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
    _check_plugin(**check_inputs)


 def test_video_llava_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
-    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
     image_seqlen = 256
-    check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
+    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
+    check_inputs = {"plugin": video_llava_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)