Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
LLaMA-Factory
Commits
7ea81099
Commit
7ea81099
authored
Apr 07, 2025
by
chenych
Browse files
update llama4
parent
84987715
Changes
139
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
207 additions
and
246 deletions
+207
-246
src/llamafactory/train/ppo/trainer.py
src/llamafactory/train/ppo/trainer.py
+27
-36
src/llamafactory/train/ppo/workflow.py
src/llamafactory/train/ppo/workflow.py
+4
-4
src/llamafactory/train/pt/trainer.py
src/llamafactory/train/pt/trainer.py
+5
-3
src/llamafactory/train/pt/workflow.py
src/llamafactory/train/pt/workflow.py
+3
-3
src/llamafactory/train/rm/metric.py
src/llamafactory/train/rm/metric.py
+4
-6
src/llamafactory/train/rm/trainer.py
src/llamafactory/train/rm/trainer.py
+8
-12
src/llamafactory/train/rm/workflow.py
src/llamafactory/train/rm/workflow.py
+3
-3
src/llamafactory/train/sft/metric.py
src/llamafactory/train/sft/metric.py
+11
-16
src/llamafactory/train/sft/trainer.py
src/llamafactory/train/sft/trainer.py
+17
-14
src/llamafactory/train/sft/workflow.py
src/llamafactory/train/sft/workflow.py
+3
-3
src/llamafactory/train/test_utils.py
src/llamafactory/train/test_utils.py
+5
-5
src/llamafactory/train/trainer_utils.py
src/llamafactory/train/trainer_utils.py
+48
-34
src/llamafactory/train/tuner.py
src/llamafactory/train/tuner.py
+9
-5
src/llamafactory/webui/chatter.py
src/llamafactory/webui/chatter.py
+15
-18
src/llamafactory/webui/common.py
src/llamafactory/webui/common.py
+26
-60
src/llamafactory/webui/components/chatbot.py
src/llamafactory/webui/components/chatbot.py
+3
-5
src/llamafactory/webui/components/data.py
src/llamafactory/webui/components/data.py
+6
-10
src/llamafactory/webui/components/eval.py
src/llamafactory/webui/components/eval.py
+2
-2
src/llamafactory/webui/components/export.py
src/llamafactory/webui/components/export.py
+5
-4
src/llamafactory/webui/components/infer.py
src/llamafactory/webui/components/infer.py
+3
-3
No files found.
src/llamafactory/train/ppo/trainer.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
#
# This code is inspired by the HuggingFace's TRL library.
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
...
@@ -20,7 +20,7 @@ import os
...
@@ -20,7 +20,7 @@ import os
import
sys
import
sys
import
warnings
import
warnings
from
types
import
MethodType
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
import
torch
from
accelerate.utils
import
DistributedDataParallelKwargs
from
accelerate.utils
import
DistributedDataParallelKwargs
...
@@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)
...
@@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)
class
CustomPPOTrainer
(
PPOTrainer
,
Trainer
):
class
CustomPPOTrainer
(
PPOTrainer
,
Trainer
):
r
"""
r
"""Inherit PPOTrainer."""
Inherits PPOTrainer.
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args
:
"Seq2SeqTrainingArguments"
,
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
finetuning_args
:
"FinetuningArguments"
,
generating_args
:
"GeneratingArguments"
,
generating_args
:
"GeneratingArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]],
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]],
model
:
"AutoModelForCausalLMWithValueHead"
,
model
:
"AutoModelForCausalLMWithValueHead"
,
reward_model
:
Optional
[
"AutoModelForCausalLMWithValueHead"
],
reward_model
:
Optional
[
"AutoModelForCausalLMWithValueHead"
],
ref_model
:
Optional
[
"AutoModelForCausalLMWithValueHead"
],
ref_model
:
Optional
[
"AutoModelForCausalLMWithValueHead"
],
...
@@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
def
ppo_train
(
self
,
resume_from_checkpoint
:
Optional
[
str
]
=
None
)
->
None
:
def
ppo_train
(
self
,
resume_from_checkpoint
:
Optional
[
str
]
=
None
)
->
None
:
r
"""
r
"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
if
resume_from_checkpoint
is
not
None
:
if
resume_from_checkpoint
is
not
None
:
raise
ValueError
(
"`resume_from_checkpoint` will be supported in the future version."
)
raise
ValueError
(
"`resume_from_checkpoint` will be supported in the future version."
)
...
@@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger
.
info_rank0
(
f
" Num Epochs =
{
num_train_epochs
:,
}
"
)
logger
.
info_rank0
(
f
" Num Epochs =
{
num_train_epochs
:,
}
"
)
logger
.
info_rank0
(
f
" Instantaneous batch size per device =
{
self
.
args
.
per_device_train_batch_size
:,
}
"
)
logger
.
info_rank0
(
f
" Instantaneous batch size per device =
{
self
.
args
.
per_device_train_batch_size
:,
}
"
)
logger
.
info_rank0
(
logger
.
info_rank0
(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}"
.
format
(
f
" Total train batch size (w. parallel, buffer, distributed & accumulation) =
{
total_train_batch_size
:,
}
"
total_train_batch_size
)
)
)
logger
.
info_rank0
(
f
" Gradient Accumulation steps =
{
self
.
args
.
gradient_accumulation_steps
:,
}
"
)
logger
.
info_rank0
(
f
" Gradient Accumulation steps =
{
self
.
args
.
gradient_accumulation_steps
:,
}
"
)
logger
.
info_rank0
(
f
" Num optimization epochs per batch =
{
self
.
finetuning_args
.
ppo_epochs
:,
}
"
)
logger
.
info_rank0
(
f
" Num optimization epochs per batch =
{
self
.
finetuning_args
.
ppo_epochs
:,
}
"
)
...
@@ -247,9 +241,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -247,9 +241,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
tokenizer
.
padding_side
=
"right"
# change padding side
self
.
tokenizer
.
padding_side
=
"right"
# change padding side
queries
,
responses
,
rewards
=
[],
[],
[]
queries
,
responses
,
rewards
=
[],
[],
[]
for
idx
in
range
(
0
,
self
.
config
.
batch_size
,
self
.
config
.
mini_batch_size
):
for
idx
in
range
(
0
,
self
.
config
.
batch_size
,
self
.
config
.
mini_batch_size
):
mini_batch_queries
,
mini_batch_responses
=
self
.
get_inputs
(
mini_batch
=
{
batch
[
idx
:
idx
+
self
.
config
.
mini_batch_size
]
"input_ids"
:
batch
[
"input_ids"
][
idx
:
idx
+
self
.
config
.
mini_batch_size
],
)
"attention_mask"
:
batch
[
"attention_mask"
][
idx
:
idx
+
self
.
config
.
mini_batch_size
],
}
mini_batch_queries
,
mini_batch_responses
=
self
.
get_inputs
(
mini_batch
)
mini_batch_rewards
=
self
.
get_rewards
(
mini_batch_queries
,
mini_batch_responses
)
mini_batch_rewards
=
self
.
get_rewards
(
mini_batch_queries
,
mini_batch_responses
)
queries
.
extend
(
mini_batch_queries
)
queries
.
extend
(
mini_batch_queries
)
responses
.
extend
(
mini_batch_responses
)
responses
.
extend
(
mini_batch_responses
)
...
@@ -339,21 +335,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -339,21 +335,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return
lr_scheduler
return
lr_scheduler
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
get_inputs
(
self
,
batch
:
Dict
[
str
,
"torch.Tensor"
])
->
Tuple
[
List
[
"torch.Tensor"
],
List
[
"torch.Tensor"
]]:
def
get_inputs
(
self
,
batch
:
dict
[
str
,
"torch.Tensor"
])
->
tuple
[
list
[
"torch.Tensor"
],
list
[
"torch.Tensor"
]]:
r
"""
r
"""Generate model's responses given queries."""
Generates model's responses given queries.
"""
if
batch
[
"input_ids"
].
size
(
0
)
==
1
:
# handle llama2 ppo with gradient accumulation > 1
if
batch
[
"input_ids"
].
size
(
0
)
==
1
:
# handle llama2 ppo with gradient accumulation > 1
start_index
=
(
batch
[
"input_ids"
][
0
]
!=
self
.
tokenizer
.
pad_token_id
).
nonzero
()[
0
].
item
()
start_index
=
(
batch
[
"input_ids"
][
0
]
!=
self
.
tokenizer
.
pad_token_id
).
nonzero
()[
0
].
item
()
for
k
,
v
in
batch
.
items
():
for
k
,
v
in
batch
.
items
():
batch
[
k
]
=
v
[:,
start_index
:]
batch
[
k
]
=
v
[:,
start_index
:]
with
unwrap_model_for_generation
(
self
.
model
,
self
.
accelerator
)
as
unwrapped_model
:
with
unwrap_model_for_generation
(
self
.
model
,
self
.
accelerator
)
as
unwrapped_model
:
unwrapped_model
:
"
AutoModelForCausalLMWithValueHead
"
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
unwrapped_model
:
AutoModelForCausalLMWithValueHead
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
if
self
.
model_args
.
upcast_layernorm
:
if
self
.
model_args
.
upcast_layernorm
:
layernorm_params
=
dump_layernorm
(
unwrapped_model
)
layernorm_params
=
dump_layernorm
(
unwrapped_model
)
generate_output
:
"
torch.Tensor
"
=
unwrapped_model
.
generate
(
generate_output
:
torch
.
Tensor
=
unwrapped_model
.
generate
(
generation_config
=
self
.
generation_config
,
logits_processor
=
get_logits_processor
(),
**
batch
generation_config
=
self
.
generation_config
,
logits_processor
=
get_logits_processor
(),
**
batch
)
)
if
self
.
model_args
.
upcast_layernorm
:
if
self
.
model_args
.
upcast_layernorm
:
...
@@ -381,11 +375,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -381,11 +375,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
get_rewards
(
def
get_rewards
(
self
,
self
,
queries
:
List
[
"torch.Tensor"
],
queries
:
list
[
"torch.Tensor"
],
responses
:
List
[
"torch.Tensor"
],
responses
:
list
[
"torch.Tensor"
],
)
->
List
[
"torch.Tensor"
]:
)
->
list
[
"torch.Tensor"
]:
r
"""
r
"""Compute scores using given reward model.
Computes scores using given reward model.
Both inputs and outputs are put on CPU.
Both inputs and outputs are put on CPU.
"""
"""
...
@@ -394,8 +387,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -394,8 +387,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages
=
self
.
tokenizer
.
batch_decode
(
token_ids
,
skip_special_tokens
=
False
)
messages
=
self
.
tokenizer
.
batch_decode
(
token_ids
,
skip_special_tokens
=
False
)
return
get_rewards_from_server
(
self
.
reward_model
,
messages
)
return
get_rewards_from_server
(
self
.
reward_model
,
messages
)
batch
:
D
ict
[
str
,
"
torch.Tensor
"
]
=
self
.
prepare_model_inputs
(
queries
,
responses
)
batch
:
d
ict
[
str
,
torch
.
Tensor
]
=
self
.
prepare_model_inputs
(
queries
,
responses
)
unwrapped_model
:
"
AutoModelForCausalLMWithValueHead
"
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
unwrapped_model
:
AutoModelForCausalLMWithValueHead
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
if
self
.
finetuning_args
.
reward_model_type
==
"lora"
:
if
self
.
finetuning_args
.
reward_model_type
==
"lora"
:
replace_model
(
unwrapped_model
,
target
=
"reward"
)
replace_model
(
unwrapped_model
,
target
=
"reward"
)
...
@@ -404,7 +397,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -404,7 +397,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model
=
self
.
reward_model
reward_model
=
self
.
reward_model
with
unwrap_model_for_generation
(
reward_model
,
self
.
accelerator
),
self
.
amp_context
:
# support bf16
with
unwrap_model_for_generation
(
reward_model
,
self
.
accelerator
),
self
.
amp_context
:
# support bf16
values
:
"
torch.Tensor
"
=
reward_model
(
**
batch
,
return_dict
=
True
,
use_cache
=
False
)[
-
1
]
values
:
torch
.
Tensor
=
reward_model
(
**
batch
,
return_dict
=
True
,
use_cache
=
False
)[
-
1
]
if
self
.
finetuning_args
.
reward_model_type
==
"lora"
:
if
self
.
finetuning_args
.
reward_model_type
==
"lora"
:
replace_model
(
unwrapped_model
,
target
=
"default"
)
replace_model
(
unwrapped_model
,
target
=
"default"
)
...
@@ -419,12 +412,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -419,12 +412,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
model
:
"AutoModelForCausalLMWithValueHead"
,
model
:
"AutoModelForCausalLMWithValueHead"
,
queries
:
"torch.Tensor"
,
queries
:
"torch.Tensor"
,
responses
:
"torch.Tensor"
,
responses
:
"torch.Tensor"
,
model_inputs
:
D
ict
[
str
,
Any
],
model_inputs
:
d
ict
[
str
,
Any
],
return_logits
:
bool
=
False
,
return_logits
:
bool
=
False
,
response_masks
:
Optional
[
"torch.Tensor"
]
=
None
,
response_masks
:
Optional
[
"torch.Tensor"
]
=
None
,
)
->
Tuple
[
"torch.Tensor"
,
Optional
[
"torch.Tensor"
],
"torch.Tensor"
,
"torch.Tensor"
]:
)
->
tuple
[
"torch.Tensor"
,
Optional
[
"torch.Tensor"
],
"torch.Tensor"
,
"torch.Tensor"
]:
r
"""
r
"""Calculate model outputs in multiple batches.
Calculates model outputs in multiple batches.
Subclass and override to inject custom behavior.
Subclass and override to inject custom behavior.
"""
"""
...
@@ -483,8 +475,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -483,8 +475,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@
override
@
override
def
save_model
(
self
,
output_dir
:
Optional
[
str
]
=
None
)
->
None
:
def
save_model
(
self
,
output_dir
:
Optional
[
str
]
=
None
)
->
None
:
r
"""
r
"""Save model checkpoint.
Saves model checkpoint.
Subclass and override to inject custom behavior.
Subclass and override to inject custom behavior.
"""
"""
...
@@ -508,5 +499,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -508,5 +499,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
model
.
save_checkpoint
(
output_dir
)
self
.
model
.
save_checkpoint
(
output_dir
)
elif
self
.
args
.
should_save
:
elif
self
.
args
.
should_save
:
unwrapped_model
:
"
AutoModelForCausalLMWithValueHead
"
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
unwrapped_model
:
AutoModelForCausalLMWithValueHead
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
self
.
_save
(
output_dir
,
state_dict
=
unwrapped_model
.
state_dict
())
self
.
_save
(
output_dir
,
state_dict
=
unwrapped_model
.
state_dict
())
src/llamafactory/train/ppo/workflow.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
#
# This code is inspired by the HuggingFace's TRL library.
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
...data
import
MultiModalDataCollatorForSeq2Seq
,
get_dataset
,
get_template_and_fix_tokenizer
from
...data
import
MultiModalDataCollatorForSeq2Seq
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.ploting
import
plot_loss
from
...extras.ploting
import
plot_loss
...
@@ -37,7 +37,7 @@ def run_ppo(
...
@@ -37,7 +37,7 @@ def run_ppo(
training_args
:
"Seq2SeqTrainingArguments"
,
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
finetuning_args
:
"FinetuningArguments"
,
generating_args
:
"GeneratingArguments"
,
generating_args
:
"GeneratingArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
,
):
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
tokenizer
=
tokenizer_module
[
"tokenizer"
]
...
@@ -53,7 +53,7 @@ def run_ppo(
...
@@ -53,7 +53,7 @@ def run_ppo(
reward_model
=
create_reward_model
(
model
,
model_args
,
finetuning_args
)
reward_model
=
create_reward_model
(
model
,
model_args
,
finetuning_args
)
# Initialize our Trainer
# Initialize our Trainer
ppo_trainer
:
"
CustomPPOTrainer
"
=
CustomPPOTrainer
(
ppo_trainer
:
CustomPPOTrainer
=
CustomPPOTrainer
(
model_args
=
model_args
,
model_args
=
model_args
,
training_args
=
training_args
,
training_args
=
training_args
,
finetuning_args
=
finetuning_args
,
finetuning_args
=
finetuning_args
,
...
...
src/llamafactory/train/pt/trainer.py
View file @
7ea81099
...
@@ -31,9 +31,7 @@ if TYPE_CHECKING:
...
@@ -31,9 +31,7 @@ if TYPE_CHECKING:
class
CustomTrainer
(
Trainer
):
class
CustomTrainer
(
Trainer
):
r
"""
r
"""Inherit Trainer for custom optimizer."""
Inherits Trainer for custom optimizer.
"""
def
__init__
(
def
__init__
(
self
,
finetuning_args
:
"FinetuningArguments"
,
processor
:
Optional
[
"ProcessorMixin"
],
**
kwargs
self
,
finetuning_args
:
"FinetuningArguments"
,
processor
:
Optional
[
"ProcessorMixin"
],
**
kwargs
...
@@ -72,3 +70,7 @@ class CustomTrainer(Trainer):
...
@@ -72,3 +70,7 @@ class CustomTrainer(Trainer):
return
torch
.
utils
.
data
.
SequentialSampler
(
self
.
train_dataset
)
return
torch
.
utils
.
data
.
SequentialSampler
(
self
.
train_dataset
)
return
super
().
_get_train_sampler
()
return
super
().
_get_train_sampler
()
@
override
def
compute_loss
(
self
,
model
,
inputs
,
*
args
,
**
kwargs
):
return
super
().
compute_loss
(
model
,
inputs
,
*
args
,
**
kwargs
)
src/llamafactory/train/pt/workflow.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
#
# This code is inspired by the HuggingFace's transformers library.
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
transformers
import
DataCollatorForLanguageModeling
from
transformers
import
DataCollatorForLanguageModeling
...
@@ -38,7 +38,7 @@ def run_pt(
...
@@ -38,7 +38,7 @@ def run_pt(
data_args
:
"DataArguments"
,
data_args
:
"DataArguments"
,
training_args
:
"Seq2SeqTrainingArguments"
,
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
finetuning_args
:
"FinetuningArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
,
):
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
tokenizer
=
tokenizer_module
[
"tokenizer"
]
...
...
src/llamafactory/train/rm/metric.py
View file @
7ea81099
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
numpy
as
np
import
numpy
as
np
...
@@ -26,11 +26,9 @@ if TYPE_CHECKING:
...
@@ -26,11 +26,9 @@ if TYPE_CHECKING:
@
dataclass
@
dataclass
class
ComputeAccuracy
:
class
ComputeAccuracy
:
r
"""
r
"""Compute reward accuracy and support `batch_eval_metrics`."""
Computes reward accuracy and supports `batch_eval_metrics`.
"""
def
_dump
(
self
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
_dump
(
self
)
->
Optional
[
d
ict
[
str
,
float
]]:
result
=
None
result
=
None
if
hasattr
(
self
,
"score_dict"
):
if
hasattr
(
self
,
"score_dict"
):
result
=
{
k
:
float
(
np
.
mean
(
v
))
for
k
,
v
in
self
.
score_dict
.
items
()}
result
=
{
k
:
float
(
np
.
mean
(
v
))
for
k
,
v
in
self
.
score_dict
.
items
()}
...
@@ -41,7 +39,7 @@ class ComputeAccuracy:
...
@@ -41,7 +39,7 @@ class ComputeAccuracy:
def
__post_init__
(
self
):
def
__post_init__
(
self
):
self
.
_dump
()
self
.
_dump
()
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
d
ict
[
str
,
float
]]:
chosen_scores
,
rejected_scores
=
numpify
(
eval_preds
.
predictions
[
0
]),
numpify
(
eval_preds
.
predictions
[
1
])
chosen_scores
,
rejected_scores
=
numpify
(
eval_preds
.
predictions
[
0
]),
numpify
(
eval_preds
.
predictions
[
1
])
if
not
chosen_scores
.
shape
:
if
not
chosen_scores
.
shape
:
self
.
score_dict
[
"accuracy"
].
append
(
chosen_scores
>
rejected_scores
)
self
.
score_dict
[
"accuracy"
].
append
(
chosen_scores
>
rejected_scores
)
...
...
src/llamafactory/train/rm/trainer.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
#
# This code is inspired by the HuggingFace's transformers library.
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
import
json
import
json
import
os
import
os
from
types
import
MethodType
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
import
torch
from
transformers
import
Trainer
from
transformers
import
Trainer
...
@@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)
...
@@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)
class
PairwiseTrainer
(
Trainer
):
class
PairwiseTrainer
(
Trainer
):
r
"""
r
"""Inherits Trainer to compute pairwise loss."""
Inherits Trainer to compute pairwise loss.
"""
def
__init__
(
def
__init__
(
self
,
finetuning_args
:
"FinetuningArguments"
,
processor
:
Optional
[
"ProcessorMixin"
],
**
kwargs
self
,
finetuning_args
:
"FinetuningArguments"
,
processor
:
Optional
[
"ProcessorMixin"
],
**
kwargs
...
@@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):
...
@@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):
@
override
@
override
def
compute_loss
(
def
compute_loss
(
self
,
model
:
"PreTrainedModel"
,
inputs
:
Dict
[
str
,
"torch.Tensor"
],
return_outputs
:
bool
=
False
,
**
kwargs
self
,
model
:
"PreTrainedModel"
,
inputs
:
dict
[
str
,
"torch.Tensor"
],
return_outputs
:
bool
=
False
,
**
kwargs
)
->
Union
[
"torch.Tensor"
,
Tuple
[
"torch.Tensor"
,
List
[
"torch.Tensor"
]]]:
)
->
Union
[
"torch.Tensor"
,
tuple
[
"torch.Tensor"
,
list
[
"torch.Tensor"
]]]:
r
"""
r
"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
Subclass and override to inject custom behavior.
Subclass and override to inject custom behavior.
...
@@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
...
@@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
return
loss
return
loss
def
save_predictions
(
self
,
predict_results
:
"PredictionOutput"
)
->
None
:
def
save_predictions
(
self
,
predict_results
:
"PredictionOutput"
)
->
None
:
r
"""
r
"""Save model predictions to `output_dir`.
Saves model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
A custom behavior that not contained in Seq2SeqTrainer.
"""
"""
...
@@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
...
@@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
chosen_scores
,
rejected_scores
=
predict_results
.
predictions
chosen_scores
,
rejected_scores
=
predict_results
.
predictions
with
open
(
output_prediction_file
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
with
open
(
output_prediction_file
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
res
:
L
ist
[
str
]
=
[]
res
:
l
ist
[
str
]
=
[]
for
c_score
,
r_score
in
zip
(
chosen_scores
,
rejected_scores
):
for
c_score
,
r_score
in
zip
(
chosen_scores
,
rejected_scores
):
res
.
append
(
json
.
dumps
({
"chosen"
:
round
(
float
(
c_score
),
2
),
"rejected"
:
round
(
float
(
r_score
),
2
)}))
res
.
append
(
json
.
dumps
({
"chosen"
:
round
(
float
(
c_score
),
2
),
"rejected"
:
round
(
float
(
r_score
),
2
)}))
...
...
src/llamafactory/train/rm/workflow.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
#
# This code is inspired by the HuggingFace's transformers library.
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.ploting
import
plot_loss
from
...extras.ploting
import
plot_loss
...
@@ -37,7 +37,7 @@ def run_rm(
...
@@ -37,7 +37,7 @@ def run_rm(
data_args
:
"DataArguments"
,
data_args
:
"DataArguments"
,
training_args
:
"Seq2SeqTrainingArguments"
,
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
finetuning_args
:
"FinetuningArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
,
):
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
tokenizer
=
tokenizer_module
[
"tokenizer"
]
...
...
src/llamafactory/train/sft/metric.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc., THUDM, and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc., THUDM, and the LlamaFactory team.
#
#
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
# limitations under the License.
# limitations under the License.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -37,17 +37,15 @@ if is_jieba_available():
...
@@ -37,17 +37,15 @@ if is_jieba_available():
if
is_nltk_available
():
if
is_nltk_available
():
from
nltk.translate.bleu_score
import
SmoothingFunction
,
sentence_bleu
from
nltk.translate.bleu_score
import
SmoothingFunction
,
sentence_bleu
# type: ignore
if
is_rouge_available
():
if
is_rouge_available
():
from
rouge_chinese
import
Rouge
from
rouge_chinese
import
Rouge
# type: ignore
def
eval_logit_processor
(
logits
:
"torch.Tensor"
,
labels
:
"torch.Tensor"
)
->
"torch.Tensor"
:
def
eval_logit_processor
(
logits
:
"torch.Tensor"
,
labels
:
"torch.Tensor"
)
->
"torch.Tensor"
:
r
"""
r
"""Compute the token with the largest likelihood to reduce memory footprint."""
Computes the token with the largest likelihood to reduce memory footprint.
"""
if
isinstance
(
logits
,
(
list
,
tuple
)):
if
isinstance
(
logits
,
(
list
,
tuple
)):
if
logits
[
0
].
dim
()
==
3
:
# (batch_size, seq_len, vocab_size)
if
logits
[
0
].
dim
()
==
3
:
# (batch_size, seq_len, vocab_size)
logits
=
logits
[
0
]
logits
=
logits
[
0
]
...
@@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
...
@@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
@
dataclass
@
dataclass
class
ComputeAccuracy
:
class
ComputeAccuracy
:
r
"""
r
"""Compute accuracy and support `batch_eval_metrics`."""
Computes accuracy and supports `batch_eval_metrics`.
"""
def
_dump
(
self
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
_dump
(
self
)
->
Optional
[
d
ict
[
str
,
float
]]:
result
=
None
result
=
None
if
hasattr
(
self
,
"score_dict"
):
if
hasattr
(
self
,
"score_dict"
):
result
=
{
k
:
float
(
np
.
mean
(
v
))
for
k
,
v
in
self
.
score_dict
.
items
()}
result
=
{
k
:
float
(
np
.
mean
(
v
))
for
k
,
v
in
self
.
score_dict
.
items
()}
...
@@ -77,7 +73,7 @@ class ComputeAccuracy:
...
@@ -77,7 +73,7 @@ class ComputeAccuracy:
def
__post_init__
(
self
):
def
__post_init__
(
self
):
self
.
_dump
()
self
.
_dump
()
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
d
ict
[
str
,
float
]]:
preds
,
labels
=
numpify
(
eval_preds
.
predictions
),
numpify
(
eval_preds
.
label_ids
)
preds
,
labels
=
numpify
(
eval_preds
.
predictions
),
numpify
(
eval_preds
.
label_ids
)
for
i
in
range
(
len
(
preds
)):
for
i
in
range
(
len
(
preds
)):
pred
,
label
=
preds
[
i
,
:
-
1
],
labels
[
i
,
1
:]
pred
,
label
=
preds
[
i
,
:
-
1
],
labels
[
i
,
1
:]
...
@@ -90,15 +86,14 @@ class ComputeAccuracy:
...
@@ -90,15 +86,14 @@ class ComputeAccuracy:
@
dataclass
@
dataclass
class
ComputeSimilarity
:
class
ComputeSimilarity
:
r
"""
r
"""Compute text similarity scores and support `batch_eval_metrics`.
Computes text similarity scores and supports `batch_eval_metrics`.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
"""
"""
tokenizer
:
"PreTrainedTokenizer"
tokenizer
:
"PreTrainedTokenizer"
def
_dump
(
self
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
_dump
(
self
)
->
Optional
[
d
ict
[
str
,
float
]]:
result
=
None
result
=
None
if
hasattr
(
self
,
"score_dict"
):
if
hasattr
(
self
,
"score_dict"
):
result
=
{
k
:
float
(
np
.
mean
(
v
))
for
k
,
v
in
self
.
score_dict
.
items
()}
result
=
{
k
:
float
(
np
.
mean
(
v
))
for
k
,
v
in
self
.
score_dict
.
items
()}
...
@@ -109,7 +104,7 @@ class ComputeSimilarity:
...
@@ -109,7 +104,7 @@ class ComputeSimilarity:
def
__post_init__
(
self
):
def
__post_init__
(
self
):
self
.
_dump
()
self
.
_dump
()
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
d
ict
[
str
,
float
]]:
preds
,
labels
=
numpify
(
eval_preds
.
predictions
),
numpify
(
eval_preds
.
label_ids
)
preds
,
labels
=
numpify
(
eval_preds
.
predictions
),
numpify
(
eval_preds
.
label_ids
)
preds
=
np
.
where
(
preds
!=
IGNORE_INDEX
,
preds
,
self
.
tokenizer
.
pad_token_id
)
preds
=
np
.
where
(
preds
!=
IGNORE_INDEX
,
preds
,
self
.
tokenizer
.
pad_token_id
)
...
...
src/llamafactory/train/sft/trainer.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
#
# This code is inspired by the HuggingFace's transformers library.
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
import
json
import
json
import
os
import
os
from
types
import
MethodType
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -44,23 +44,24 @@ logger = logging.get_logger(__name__)
...
@@ -44,23 +44,24 @@ logger = logging.get_logger(__name__)
class
CustomSeq2SeqTrainer
(
Seq2SeqTrainer
):
class
CustomSeq2SeqTrainer
(
Seq2SeqTrainer
):
r
"""
r
"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""
def
__init__
(
def
__init__
(
self
,
self
,
finetuning_args
:
"FinetuningArguments"
,
finetuning_args
:
"FinetuningArguments"
,
processor
:
Optional
[
"ProcessorMixin"
],
processor
:
Optional
[
"ProcessorMixin"
],
gen_kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
gen_kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
if
is_transformers_version_greater_than
(
"4.46"
):
if
is_transformers_version_greater_than
(
"4.46"
):
kwargs
[
"processing_class"
]
=
kwargs
.
pop
(
"tokenizer"
)
kwargs
[
"processing_class"
]
=
kwargs
.
pop
(
"tokenizer"
)
else
:
else
:
self
.
processing_class
:
"
PreTrainedTokenizer
"
=
kwargs
.
get
(
"tokenizer"
)
self
.
processing_class
:
PreTrainedTokenizer
=
kwargs
.
get
(
"tokenizer"
)
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
if
processor
is
not
None
:
self
.
model_accepts_loss_kwargs
=
False
self
.
finetuning_args
=
finetuning_args
self
.
finetuning_args
=
finetuning_args
if
gen_kwargs
is
not
None
:
if
gen_kwargs
is
not
None
:
# https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
# https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
...
@@ -95,17 +96,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
...
@@ -95,17 +96,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return
super
().
_get_train_sampler
()
return
super
().
_get_train_sampler
()
@
override
def
compute_loss
(
self
,
model
,
inputs
,
*
args
,
**
kwargs
):
return
super
().
compute_loss
(
model
,
inputs
,
*
args
,
**
kwargs
)
@
override
@
override
def
prediction_step
(
def
prediction_step
(
self
,
self
,
model
:
"torch.nn.Module"
,
model
:
"torch.nn.Module"
,
inputs
:
D
ict
[
str
,
Union
[
"torch.Tensor"
,
Any
]],
inputs
:
d
ict
[
str
,
Union
[
"torch.Tensor"
,
Any
]],
prediction_loss_only
:
bool
,
prediction_loss_only
:
bool
,
ignore_keys
:
Optional
[
L
ist
[
str
]]
=
None
,
ignore_keys
:
Optional
[
l
ist
[
str
]]
=
None
,
**
gen_kwargs
,
**
gen_kwargs
,
)
->
Tuple
[
Optional
[
float
],
Optional
[
"torch.Tensor"
],
Optional
[
"torch.Tensor"
]]:
)
->
tuple
[
Optional
[
float
],
Optional
[
"torch.Tensor"
],
Optional
[
"torch.Tensor"
]]:
r
"""
r
"""Remove the prompt part in the generated tokens.
Removes the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
Subclass and override to inject custom behavior.
"""
"""
...
@@ -126,8 +130,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
...
@@ -126,8 +130,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def
save_predictions
(
def
save_predictions
(
self
,
dataset
:
"Dataset"
,
predict_results
:
"PredictionOutput"
,
skip_special_tokens
:
bool
=
True
self
,
dataset
:
"Dataset"
,
predict_results
:
"PredictionOutput"
,
skip_special_tokens
:
bool
=
True
)
->
None
:
)
->
None
:
r
"""
r
"""Save model predictions to `output_dir`.
Saves model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
A custom behavior that not contained in Seq2SeqTrainer.
"""
"""
...
...
src/llamafactory/train/sft/workflow.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
#
# This code is inspired by the HuggingFace's transformers library.
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
...data
import
SFTDataCollatorWith4DAttentionMask
,
get_dataset
,
get_template_and_fix_tokenizer
from
...data
import
SFTDataCollatorWith4DAttentionMask
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
...
@@ -43,7 +43,7 @@ def run_sft(
...
@@ -43,7 +43,7 @@ def run_sft(
training_args
:
"Seq2SeqTrainingArguments"
,
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
finetuning_args
:
"FinetuningArguments"
,
generating_args
:
"GeneratingArguments"
,
generating_args
:
"GeneratingArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
,
):
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
tokenizer
=
tokenizer_module
[
"tokenizer"
]
...
...
src/llamafactory/train/test_utils.py
View file @
7ea81099
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Dict
,
Optional
,
Sequence
,
Set
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
import
torch
from
peft
import
PeftModel
from
peft
import
PeftModel
...
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
...
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
from
..data.data_utils
import
DatasetModule
from
..data.data_utils
import
DatasetModule
def
compare_model
(
model_a
:
"torch.nn.Module"
,
model_b
:
"torch.nn.Module"
,
diff_keys
:
Sequence
[
str
]
=
[])
->
None
:
def
compare_model
(
model_a
:
"torch.nn.Module"
,
model_b
:
"torch.nn.Module"
,
diff_keys
:
list
[
str
]
=
[])
->
None
:
state_dict_a
=
model_a
.
state_dict
()
state_dict_a
=
model_a
.
state_dict
()
state_dict_b
=
model_b
.
state_dict
()
state_dict_b
=
model_b
.
state_dict
()
assert
set
(
state_dict_a
.
keys
())
==
set
(
state_dict_b
.
keys
())
assert
set
(
state_dict_a
.
keys
())
==
set
(
state_dict_b
.
keys
())
...
@@ -43,7 +43,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
...
@@ -43,7 +43,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert
torch
.
allclose
(
state_dict_a
[
name
],
state_dict_b
[
name
],
rtol
=
1e-4
,
atol
=
1e-5
)
is
True
assert
torch
.
allclose
(
state_dict_a
[
name
],
state_dict_b
[
name
],
rtol
=
1e-4
,
atol
=
1e-5
)
is
True
def
check_lora_model
(
model
:
"LoraModel"
)
->
T
uple
[
S
et
[
str
],
S
et
[
str
]]:
def
check_lora_model
(
model
:
"LoraModel"
)
->
t
uple
[
s
et
[
str
],
s
et
[
str
]]:
linear_modules
,
extra_modules
=
set
(),
set
()
linear_modules
,
extra_modules
=
set
(),
set
()
for
name
,
param
in
model
.
named_parameters
():
for
name
,
param
in
model
.
named_parameters
():
if
any
(
module
in
name
for
module
in
[
"lora_A"
,
"lora_B"
]):
if
any
(
module
in
name
for
module
in
[
"lora_A"
,
"lora_B"
]):
...
@@ -83,7 +83,7 @@ def load_reference_model(
...
@@ -83,7 +83,7 @@ def load_reference_model(
)
->
Union
[
"PreTrainedModel"
,
"LoraModel"
]:
)
->
Union
[
"PreTrainedModel"
,
"LoraModel"
]:
current_device
=
get_current_device
()
current_device
=
get_current_device
()
if
add_valuehead
:
if
add_valuehead
:
model
:
"
AutoModelForCausalLMWithValueHead
"
=
AutoModelForCausalLMWithValueHead
.
from_pretrained
(
model
:
AutoModelForCausalLMWithValueHead
=
AutoModelForCausalLMWithValueHead
.
from_pretrained
(
model_path
,
torch_dtype
=
torch
.
float16
,
device_map
=
current_device
model_path
,
torch_dtype
=
torch
.
float16
,
device_map
=
current_device
)
)
if
not
is_trainable
:
if
not
is_trainable
:
...
@@ -111,7 +111,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":
...
@@ -111,7 +111,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":
def
patch_valuehead_model
()
->
None
:
def
patch_valuehead_model
()
->
None
:
def
post_init
(
self
:
"AutoModelForCausalLMWithValueHead"
,
state_dict
:
D
ict
[
str
,
"torch.Tensor"
])
->
None
:
def
post_init
(
self
:
"AutoModelForCausalLMWithValueHead"
,
state_dict
:
d
ict
[
str
,
"torch.Tensor"
])
->
None
:
state_dict
=
{
k
[
7
:]:
state_dict
[
k
]
for
k
in
state_dict
.
keys
()
if
k
.
startswith
(
"v_head."
)}
state_dict
=
{
k
[
7
:]:
state_dict
[
k
]
for
k
in
state_dict
.
keys
()
if
k
.
startswith
(
"v_head."
)}
self
.
v_head
.
load_state_dict
(
state_dict
,
strict
=
False
)
self
.
v_head
.
load_state_dict
(
state_dict
,
strict
=
False
)
del
state_dict
del
state_dict
...
...
src/llamafactory/train/trainer_utils.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
#
# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
...
@@ -21,7 +21,7 @@ import json
...
@@ -21,7 +21,7 @@ import json
import
os
import
os
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
from
transformers
import
Trainer
from
transformers
import
Trainer
...
@@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)
...
@@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)
class
DummyOptimizer
(
torch
.
optim
.
Optimizer
):
class
DummyOptimizer
(
torch
.
optim
.
Optimizer
):
r
"""
r
"""A dummy optimizer used for the GaLore or APOLLO algorithm."""
A dummy optimizer used for the GaLore or APOLLO algorithm.
"""
def
__init__
(
def
__init__
(
self
,
lr
:
float
=
1e-3
,
optimizer_dict
:
Optional
[
D
ict
[
"torch.nn.Parameter"
,
"torch.optim.Optimizer"
]]
=
None
self
,
lr
:
float
=
1e-3
,
optimizer_dict
:
Optional
[
d
ict
[
"torch.nn.Parameter"
,
"torch.optim.Optimizer"
]]
=
None
)
->
None
:
)
->
None
:
dummy_tensor
=
torch
.
randn
(
1
,
1
)
dummy_tensor
=
torch
.
randn
(
1
,
1
)
self
.
optimizer_dict
=
optimizer_dict
self
.
optimizer_dict
=
optimizer_dict
...
@@ -112,8 +110,7 @@ def create_modelcard_and_push(
...
@@ -112,8 +110,7 @@ def create_modelcard_and_push(
def
create_ref_model
(
def
create_ref_model
(
model_args
:
"ModelArguments"
,
finetuning_args
:
"FinetuningArguments"
,
add_valuehead
:
bool
=
False
model_args
:
"ModelArguments"
,
finetuning_args
:
"FinetuningArguments"
,
add_valuehead
:
bool
=
False
)
->
Optional
[
Union
[
"PreTrainedModel"
,
"AutoModelForCausalLMWithValueHead"
]]:
)
->
Optional
[
Union
[
"PreTrainedModel"
,
"AutoModelForCausalLMWithValueHead"
]]:
r
"""
r
"""Create reference model for PPO/DPO training. Evaluation mode is not supported.
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
"""
...
@@ -148,9 +145,7 @@ def create_ref_model(
...
@@ -148,9 +145,7 @@ def create_ref_model(
def
create_reward_model
(
def
create_reward_model
(
model
:
"AutoModelForCausalLMWithValueHead"
,
model_args
:
"ModelArguments"
,
finetuning_args
:
"FinetuningArguments"
model
:
"AutoModelForCausalLMWithValueHead"
,
model_args
:
"ModelArguments"
,
finetuning_args
:
"FinetuningArguments"
)
->
Optional
[
"AutoModelForCausalLMWithValueHead"
]:
)
->
Optional
[
"AutoModelForCausalLMWithValueHead"
]:
r
"""
r
"""Create reward model for PPO training."""
Creates reward model for PPO training.
"""
if
finetuning_args
.
reward_model_type
==
"api"
:
if
finetuning_args
.
reward_model_type
==
"api"
:
assert
finetuning_args
.
reward_model
.
startswith
(
"http"
),
"Please provide full url."
assert
finetuning_args
.
reward_model
.
startswith
(
"http"
),
"Please provide full url."
logger
.
info_rank0
(
f
"Use reward server
{
finetuning_args
.
reward_model
}
"
)
logger
.
info_rank0
(
f
"Use reward server
{
finetuning_args
.
reward_model
}
"
)
...
@@ -189,10 +184,8 @@ def create_reward_model(
...
@@ -189,10 +184,8 @@ def create_reward_model(
return
reward_model
return
reward_model
def
_get_decay_parameter_names
(
model
:
"PreTrainedModel"
)
->
List
[
str
]:
def
_get_decay_parameter_names
(
model
:
"PreTrainedModel"
)
->
list
[
str
]:
r
"""
r
"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
"""
decay_parameters
=
get_parameter_names
(
model
,
ALL_LAYERNORM_LAYERS
)
decay_parameters
=
get_parameter_names
(
model
,
ALL_LAYERNORM_LAYERS
)
decay_parameters
=
[
name
for
name
in
decay_parameters
if
"bias"
not
in
name
]
decay_parameters
=
[
name
for
name
in
decay_parameters
if
"bias"
not
in
name
]
return
decay_parameters
return
decay_parameters
...
@@ -208,7 +201,7 @@ def _create_galore_optimizer(
...
@@ -208,7 +201,7 @@ def _create_galore_optimizer(
else
:
else
:
galore_targets
=
finetuning_args
.
galore_target
galore_targets
=
finetuning_args
.
galore_target
galore_params
:
L
ist
[
"
torch.nn.Parameter
"
]
=
[]
galore_params
:
l
ist
[
torch
.
nn
.
Parameter
]
=
[]
for
name
,
module
in
model
.
named_modules
():
for
name
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
torch
.
nn
.
Linear
)
and
any
(
target
in
name
for
target
in
galore_targets
):
if
isinstance
(
module
,
torch
.
nn
.
Linear
)
and
any
(
target
in
name
for
target
in
galore_targets
):
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
...
@@ -224,7 +217,7 @@ def _create_galore_optimizer(
...
@@ -224,7 +217,7 @@ def _create_galore_optimizer(
id_galore_params
=
{
id
(
param
)
for
param
in
galore_params
}
id_galore_params
=
{
id
(
param
)
for
param
in
galore_params
}
decay_params
,
nodecay_params
=
[],
[]
# they are non-galore parameters
decay_params
,
nodecay_params
=
[],
[]
# they are non-galore parameters
trainable_params
:
L
ist
[
"
torch.nn.Parameter
"
]
=
[]
# galore_params + decay_params + nodecay_params
trainable_params
:
l
ist
[
torch
.
nn
.
Parameter
]
=
[]
# galore_params + decay_params + nodecay_params
decay_param_names
=
_get_decay_parameter_names
(
model
)
decay_param_names
=
_get_decay_parameter_names
(
model
)
for
name
,
param
in
model
.
named_parameters
():
for
name
,
param
in
model
.
named_parameters
():
if
param
.
requires_grad
:
if
param
.
requires_grad
:
...
@@ -251,7 +244,7 @@ def _create_galore_optimizer(
...
@@ -251,7 +244,7 @@ def _create_galore_optimizer(
if
training_args
.
gradient_accumulation_steps
!=
1
:
if
training_args
.
gradient_accumulation_steps
!=
1
:
raise
ValueError
(
"Per-layer GaLore does not support gradient accumulation."
)
raise
ValueError
(
"Per-layer GaLore does not support gradient accumulation."
)
optimizer_dict
:
D
ict
[
"
torch.Tensor
"
,
"
torch.optim.Optimizer
"
]
=
{}
optimizer_dict
:
d
ict
[
torch
.
Tensor
,
torch
.
optim
.
Optimizer
]
=
{}
for
param
in
nodecay_params
:
for
param
in
nodecay_params
:
param_groups
=
[
dict
(
params
=
[
param
],
weight_decay
=
0.0
)]
param_groups
=
[
dict
(
params
=
[
param
],
weight_decay
=
0.0
)]
optimizer_dict
[
param
]
=
optim_class
(
param_groups
,
**
optim_kwargs
)
optimizer_dict
[
param
]
=
optim_class
(
param_groups
,
**
optim_kwargs
)
...
@@ -296,7 +289,7 @@ def _create_apollo_optimizer(
...
@@ -296,7 +289,7 @@ def _create_apollo_optimizer(
else
:
else
:
apollo_targets
=
finetuning_args
.
apollo_target
apollo_targets
=
finetuning_args
.
apollo_target
apollo_params
:
L
ist
[
"
torch.nn.Parameter
"
]
=
[]
apollo_params
:
l
ist
[
torch
.
nn
.
Parameter
]
=
[]
for
name
,
module
in
model
.
named_modules
():
for
name
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
torch
.
nn
.
Linear
)
and
any
(
target
in
name
for
target
in
apollo_targets
):
if
isinstance
(
module
,
torch
.
nn
.
Linear
)
and
any
(
target
in
name
for
target
in
apollo_targets
):
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
...
@@ -315,7 +308,7 @@ def _create_apollo_optimizer(
...
@@ -315,7 +308,7 @@ def _create_apollo_optimizer(
id_apollo_params
=
{
id
(
param
)
for
param
in
apollo_params
}
id_apollo_params
=
{
id
(
param
)
for
param
in
apollo_params
}
decay_params
,
nodecay_params
=
[],
[]
# they are non-apollo parameters
decay_params
,
nodecay_params
=
[],
[]
# they are non-apollo parameters
trainable_params
:
L
ist
[
"
torch.nn.Parameter
"
]
=
[]
# apollo_params + decay_params + nodecay_params
trainable_params
:
l
ist
[
torch
.
nn
.
Parameter
]
=
[]
# apollo_params + decay_params + nodecay_params
decay_param_names
=
_get_decay_parameter_names
(
model
)
decay_param_names
=
_get_decay_parameter_names
(
model
)
for
name
,
param
in
model
.
named_parameters
():
for
name
,
param
in
model
.
named_parameters
():
if
param
.
requires_grad
:
if
param
.
requires_grad
:
...
@@ -338,7 +331,7 @@ def _create_apollo_optimizer(
...
@@ -338,7 +331,7 @@ def _create_apollo_optimizer(
if
training_args
.
gradient_accumulation_steps
!=
1
:
if
training_args
.
gradient_accumulation_steps
!=
1
:
raise
ValueError
(
"Per-layer APOLLO does not support gradient accumulation."
)
raise
ValueError
(
"Per-layer APOLLO does not support gradient accumulation."
)
optimizer_dict
:
D
ict
[
"
torch.Tensor
"
,
"
torch.optim.Optimizer
"
]
=
{}
optimizer_dict
:
d
ict
[
torch
.
Tensor
,
torch
.
optim
.
Optimizer
]
=
{}
for
param
in
nodecay_params
:
for
param
in
nodecay_params
:
param_groups
=
[
dict
(
params
=
[
param
],
weight_decay
=
0.0
)]
param_groups
=
[
dict
(
params
=
[
param
],
weight_decay
=
0.0
)]
optimizer_dict
[
param
]
=
optim_class
(
param_groups
,
**
optim_kwargs
)
optimizer_dict
[
param
]
=
optim_class
(
param_groups
,
**
optim_kwargs
)
...
@@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
...
@@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
embedding_lr
=
finetuning_args
.
loraplus_lr_embedding
embedding_lr
=
finetuning_args
.
loraplus_lr_embedding
decay_param_names
=
_get_decay_parameter_names
(
model
)
decay_param_names
=
_get_decay_parameter_names
(
model
)
param_dict
:
D
ict
[
str
,
L
ist
[
"
torch.nn.Parameter
"
]]
=
{
param_dict
:
d
ict
[
str
,
l
ist
[
torch
.
nn
.
Parameter
]]
=
{
"lora_a"
:
[],
"lora_a"
:
[],
"lora_b"
:
[],
"lora_b"
:
[],
"lora_b_nodecay"
:
[],
"lora_b_nodecay"
:
[],
...
@@ -522,9 +515,25 @@ def create_custom_scheduler(
...
@@ -522,9 +515,25 @@ def create_custom_scheduler(
num_training_steps
:
int
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
,
)
->
None
:
)
->
None
:
if
training_args
.
lr_scheduler_type
==
"warmup_stable_decay"
:
num_warmup_steps
=
training_args
.
get_warmup_steps
(
num_training_steps
)
remaining_steps
=
num_training_steps
-
num_warmup_steps
num_stable_steps
=
remaining_steps
//
3
# use 1/3 for stable by default
num_decay_steps
=
remaining_steps
-
num_stable_steps
scheduler_kwargs
=
training_args
.
lr_scheduler_kwargs
or
{}
default_kwargs
=
{
"num_stable_steps"
:
num_stable_steps
,
"num_decay_steps"
:
num_decay_steps
,
}
for
key
,
value
in
default_kwargs
.
items
():
if
key
not
in
scheduler_kwargs
:
scheduler_kwargs
[
key
]
=
value
training_args
.
lr_scheduler_kwargs
=
scheduler_kwargs
if
optimizer
is
not
None
and
isinstance
(
optimizer
,
DummyOptimizer
):
if
optimizer
is
not
None
and
isinstance
(
optimizer
,
DummyOptimizer
):
optimizer_dict
=
optimizer
.
optimizer_dict
optimizer_dict
=
optimizer
.
optimizer_dict
scheduler_dict
:
D
ict
[
"
torch.nn.Parameter
"
,
"
torch.optim.lr_scheduler.LRScheduler
"
]
=
{}
scheduler_dict
:
d
ict
[
torch
.
nn
.
Parameter
,
torch
.
optim
.
lr_scheduler
.
LRScheduler
]
=
{}
for
param
in
optimizer_dict
.
keys
():
for
param
in
optimizer_dict
.
keys
():
scheduler_dict
[
param
]
=
get_scheduler
(
scheduler_dict
[
param
]
=
get_scheduler
(
...
@@ -544,13 +553,13 @@ def create_custom_scheduler(
...
@@ -544,13 +553,13 @@ def create_custom_scheduler(
def
get_batch_logps
(
def
get_batch_logps
(
logits
:
"torch.Tensor"
,
labels
:
"torch.Tensor"
,
label_pad_token_id
:
int
=
IGNORE_INDEX
logits
:
"torch.Tensor"
,
labels
:
"torch.Tensor"
,
label_pad_token_id
:
int
=
IGNORE_INDEX
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
]:
)
->
tuple
[
"torch.Tensor"
,
"torch.Tensor"
]:
r
"""
r
"""Compute the log probabilities of the given labels under the given logits.
Computes the log probabilities of the given labels under the given logits.
Returns:
Returns:
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
"""
"""
if
logits
.
shape
[:
-
1
]
!=
labels
.
shape
:
if
logits
.
shape
[:
-
1
]
!=
labels
.
shape
:
raise
ValueError
(
"Logits (batchsize x seqlen) and labels must have the same shape."
)
raise
ValueError
(
"Logits (batchsize x seqlen) and labels must have the same shape."
)
...
@@ -564,12 +573,10 @@ def get_batch_logps(
...
@@ -564,12 +573,10 @@ def get_batch_logps(
def
nested_detach
(
def
nested_detach
(
tensors
:
Union
[
"torch.Tensor"
,
L
ist
[
"torch.Tensor"
],
T
uple
[
"torch.Tensor"
],
D
ict
[
str
,
"torch.Tensor"
]],
tensors
:
Union
[
"torch.Tensor"
,
l
ist
[
"torch.Tensor"
],
t
uple
[
"torch.Tensor"
],
d
ict
[
str
,
"torch.Tensor"
]],
clone
:
bool
=
False
,
clone
:
bool
=
False
,
):
):
r
"""
r
"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
"""
if
isinstance
(
tensors
,
(
list
,
tuple
)):
if
isinstance
(
tensors
,
(
list
,
tuple
)):
return
type
(
tensors
)(
nested_detach
(
t
,
clone
=
clone
)
for
t
in
tensors
)
return
type
(
tensors
)(
nested_detach
(
t
,
clone
=
clone
)
for
t
in
tensors
)
elif
isinstance
(
tensors
,
Mapping
):
elif
isinstance
(
tensors
,
Mapping
):
...
@@ -585,15 +592,22 @@ def nested_detach(
...
@@ -585,15 +592,22 @@ def nested_detach(
def
get_swanlab_callback
(
finetuning_args
:
"FinetuningArguments"
)
->
"TrainerCallback"
:
def
get_swanlab_callback
(
finetuning_args
:
"FinetuningArguments"
)
->
"TrainerCallback"
:
r
"""
r
"""Get the callback for logging to SwanLab."""
Gets the callback for logging to SwanLab.
"""
import
swanlab
# type: ignore
import
swanlab
# type: ignore
from
swanlab.integration.transformers
import
SwanLabCallback
# type: ignore
from
swanlab.integration.transformers
import
SwanLabCallback
# type: ignore
if
finetuning_args
.
swanlab_api_key
is
not
None
:
if
finetuning_args
.
swanlab_api_key
is
not
None
:
swanlab
.
login
(
api_key
=
finetuning_args
.
swanlab_api_key
)
swanlab
.
login
(
api_key
=
finetuning_args
.
swanlab_api_key
)
if
finetuning_args
.
swanlab_lark_webhook_url
is
not
None
:
from
swanlab.plugin.notification
import
LarkCallback
# type: ignore
lark_callback
=
LarkCallback
(
webhook_url
=
finetuning_args
.
swanlab_lark_webhook_url
,
secret
=
finetuning_args
.
swanlab_lark_secret
,
)
swanlab
.
register_callbacks
([
lark_callback
])
class
SwanLabCallbackExtension
(
SwanLabCallback
):
class
SwanLabCallbackExtension
(
SwanLabCallback
):
def
setup
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
model
:
"PreTrainedModel"
,
**
kwargs
):
def
setup
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
model
:
"PreTrainedModel"
,
**
kwargs
):
if
not
state
.
is_world_process_zero
:
if
not
state
.
is_world_process_zero
:
...
@@ -624,7 +638,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
...
@@ -624,7 +638,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
def
get_ray_trainer
(
def
get_ray_trainer
(
training_function
:
Callable
,
training_function
:
Callable
,
train_loop_config
:
D
ict
[
str
,
Any
],
train_loop_config
:
d
ict
[
str
,
Any
],
ray_args
:
"RayArguments"
,
ray_args
:
"RayArguments"
,
)
->
"TorchTrainer"
:
)
->
"TorchTrainer"
:
if
not
ray_args
.
use_ray
:
if
not
ray_args
.
use_ray
:
...
...
src/llamafactory/train/tuner.py
View file @
7ea81099
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
os
import
os
import
shutil
import
shutil
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -38,6 +38,7 @@ from .trainer_utils import get_ray_trainer, get_swanlab_callback
...
@@ -38,6 +38,7 @@ from .trainer_utils import get_ray_trainer, get_swanlab_callback
if
is_ray_available
():
if
is_ray_available
():
import
ray
from
ray.train.huggingface.transformers
import
RayTrainReportCallback
from
ray.train.huggingface.transformers
import
RayTrainReportCallback
...
@@ -48,9 +49,9 @@ if TYPE_CHECKING:
...
@@ -48,9 +49,9 @@ if TYPE_CHECKING:
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
_training_function
(
config
:
D
ict
[
str
,
Any
])
->
None
:
def
_training_function
(
config
:
d
ict
[
str
,
Any
])
->
None
:
args
=
config
.
get
(
"args"
)
args
=
config
.
get
(
"args"
)
callbacks
:
L
ist
[
Any
]
=
config
.
get
(
"callbacks"
)
callbacks
:
l
ist
[
Any
]
=
config
.
get
(
"callbacks"
)
model_args
,
data_args
,
training_args
,
finetuning_args
,
generating_args
=
get_train_args
(
args
)
model_args
,
data_args
,
training_args
,
finetuning_args
,
generating_args
=
get_train_args
(
args
)
callbacks
.
append
(
LogCallback
())
callbacks
.
append
(
LogCallback
())
...
@@ -77,6 +78,9 @@ def _training_function(config: Dict[str, Any]) -> None:
...
@@ -77,6 +78,9 @@ def _training_function(config: Dict[str, Any]) -> None:
else
:
else
:
raise
ValueError
(
f
"Unknown task:
{
finetuning_args
.
stage
}
."
)
raise
ValueError
(
f
"Unknown task:
{
finetuning_args
.
stage
}
."
)
if
is_ray_available
()
and
ray
.
is_initialized
():
return
# if ray is intialized it will destroy the process group on return
try
:
try
:
if
dist
.
is_initialized
():
if
dist
.
is_initialized
():
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
...
@@ -84,7 +88,7 @@ def _training_function(config: Dict[str, Any]) -> None:
...
@@ -84,7 +88,7 @@ def _training_function(config: Dict[str, Any]) -> None:
logger
.
warning
(
f
"Failed to destroy process group:
{
e
}
."
)
logger
.
warning
(
f
"Failed to destroy process group:
{
e
}
."
)
def
run_exp
(
args
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
)
->
None
:
def
run_exp
(
args
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
)
->
None
:
args
=
read_args
(
args
)
args
=
read_args
(
args
)
if
"-h"
in
args
or
"--help"
in
args
:
if
"-h"
in
args
or
"--help"
in
args
:
get_train_args
(
args
)
get_train_args
(
args
)
...
@@ -103,7 +107,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
...
@@ -103,7 +107,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
_training_function
(
config
=
{
"args"
:
args
,
"callbacks"
:
callbacks
})
_training_function
(
config
=
{
"args"
:
args
,
"callbacks"
:
callbacks
})
def
export_model
(
args
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
)
->
None
:
def
export_model
(
args
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
)
->
None
:
model_args
,
data_args
,
finetuning_args
,
_
=
get_infer_args
(
args
)
model_args
,
data_args
,
finetuning_args
,
_
=
get_infer_args
(
args
)
if
model_args
.
export_dir
is
None
:
if
model_args
.
export_dir
is
None
:
...
...
src/llamafactory/webui/chatter.py
View file @
7ea81099
...
@@ -14,7 +14,8 @@
...
@@ -14,7 +14,8 @@
import
json
import
json
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
from
collections.abc
import
Generator
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
transformers.utils
import
is_torch_npu_available
from
transformers.utils
import
is_torch_npu_available
...
@@ -37,15 +38,12 @@ if is_gradio_available():
...
@@ -37,15 +38,12 @@ if is_gradio_available():
def
_escape_html
(
text
:
str
)
->
str
:
def
_escape_html
(
text
:
str
)
->
str
:
r
"""
r
"""Escape HTML characters."""
Escapes HTML characters.
"""
return
text
.
replace
(
"<"
,
"<"
).
replace
(
">"
,
">"
)
return
text
.
replace
(
"<"
,
"<"
).
replace
(
">"
,
">"
)
def
_format_response
(
text
:
str
,
lang
:
str
,
escape_html
:
bool
,
thought_words
:
Tuple
[
str
,
str
])
->
str
:
def
_format_response
(
text
:
str
,
lang
:
str
,
escape_html
:
bool
,
thought_words
:
tuple
[
str
,
str
])
->
str
:
r
"""
r
"""Post-process the response text.
Post-processes the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
"""
...
@@ -74,7 +72,7 @@ class WebChatModel(ChatModel):
...
@@ -74,7 +72,7 @@ class WebChatModel(ChatModel):
def
__init__
(
self
,
manager
:
"Manager"
,
demo_mode
:
bool
=
False
,
lazy_init
:
bool
=
True
)
->
None
:
def
__init__
(
self
,
manager
:
"Manager"
,
demo_mode
:
bool
=
False
,
lazy_init
:
bool
=
True
)
->
None
:
self
.
manager
=
manager
self
.
manager
=
manager
self
.
demo_mode
=
demo_mode
self
.
demo_mode
=
demo_mode
self
.
engine
:
Optional
[
"
BaseEngine
"
]
=
None
self
.
engine
:
Optional
[
BaseEngine
]
=
None
if
not
lazy_init
:
# read arguments from command line
if
not
lazy_init
:
# read arguments from command line
super
().
__init__
()
super
().
__init__
()
...
@@ -124,6 +122,7 @@ class WebChatModel(ChatModel):
...
@@ -124,6 +122,7 @@ class WebChatModel(ChatModel):
enable_liger_kernel
=
(
get
(
"top.booster"
)
==
"liger_kernel"
),
enable_liger_kernel
=
(
get
(
"top.booster"
)
==
"liger_kernel"
),
infer_backend
=
get
(
"infer.infer_backend"
),
infer_backend
=
get
(
"infer.infer_backend"
),
infer_dtype
=
get
(
"infer.infer_dtype"
),
infer_dtype
=
get
(
"infer.infer_dtype"
),
vllm_enforce_eager
=
True
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
)
)
...
@@ -160,14 +159,13 @@ class WebChatModel(ChatModel):
...
@@ -160,14 +159,13 @@ class WebChatModel(ChatModel):
@
staticmethod
@
staticmethod
def
append
(
def
append
(
chatbot
:
L
ist
[
D
ict
[
str
,
str
]],
chatbot
:
l
ist
[
d
ict
[
str
,
str
]],
messages
:
L
ist
[
D
ict
[
str
,
str
]],
messages
:
l
ist
[
d
ict
[
str
,
str
]],
role
:
str
,
role
:
str
,
query
:
str
,
query
:
str
,
escape_html
:
bool
,
escape_html
:
bool
,
)
->
Tuple
[
List
[
Dict
[
str
,
str
]],
List
[
Dict
[
str
,
str
]],
str
]:
)
->
tuple
[
list
[
dict
[
str
,
str
]],
list
[
dict
[
str
,
str
]],
str
]:
r
"""
r
"""Add the user input to chatbot.
Adds the user input to chatbot.
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Output: infer.chatbot, infer.messages, infer.query
Output: infer.chatbot, infer.messages, infer.query
...
@@ -180,8 +178,8 @@ class WebChatModel(ChatModel):
...
@@ -180,8 +178,8 @@ class WebChatModel(ChatModel):
def
stream
(
def
stream
(
self
,
self
,
chatbot
:
L
ist
[
D
ict
[
str
,
str
]],
chatbot
:
l
ist
[
d
ict
[
str
,
str
]],
messages
:
L
ist
[
D
ict
[
str
,
str
]],
messages
:
l
ist
[
d
ict
[
str
,
str
]],
lang
:
str
,
lang
:
str
,
system
:
str
,
system
:
str
,
tools
:
str
,
tools
:
str
,
...
@@ -193,9 +191,8 @@ class WebChatModel(ChatModel):
...
@@ -193,9 +191,8 @@ class WebChatModel(ChatModel):
temperature
:
float
,
temperature
:
float
,
skip_special_tokens
:
bool
,
skip_special_tokens
:
bool
,
escape_html
:
bool
,
escape_html
:
bool
,
)
->
Generator
[
Tuple
[
List
[
Dict
[
str
,
str
]],
List
[
Dict
[
str
,
str
]]],
None
,
None
]:
)
->
Generator
[
tuple
[
list
[
dict
[
str
,
str
]],
list
[
dict
[
str
,
str
]]],
None
,
None
]:
r
"""
r
"""Generate output text in stream.
Generates output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages
Output: infer.chatbot, infer.messages
...
...
src/llamafactory/webui/common.py
View file @
7ea81099
...
@@ -17,7 +17,7 @@ import os
...
@@ -17,7 +17,7 @@ import os
import
signal
import
signal
from
collections
import
defaultdict
from
collections
import
defaultdict
from
datetime
import
datetime
from
datetime
import
datetime
from
typing
import
Any
,
Dict
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
from
psutil
import
Process
from
psutil
import
Process
from
yaml
import
safe_dump
,
safe_load
from
yaml
import
safe_dump
,
safe_load
...
@@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"
...
@@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"
def
abort_process
(
pid
:
int
)
->
None
:
def
abort_process
(
pid
:
int
)
->
None
:
r
"""
r
"""Abort the processes recursively in a bottom-up way."""
Aborts the processes recursively in a bottom-up way.
"""
try
:
try
:
children
=
Process
(
pid
).
children
()
children
=
Process
(
pid
).
children
()
if
children
:
if
children
:
...
@@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:
...
@@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:
def
get_save_dir
(
*
paths
:
str
)
->
os
.
PathLike
:
def
get_save_dir
(
*
paths
:
str
)
->
os
.
PathLike
:
r
"""
r
"""Get the path to saved model checkpoints."""
Gets the path to saved model checkpoints.
"""
if
os
.
path
.
sep
in
paths
[
-
1
]:
if
os
.
path
.
sep
in
paths
[
-
1
]:
logger
.
warning_rank0
(
"Found complex path, some features may be not available."
)
logger
.
warning_rank0
(
"Found complex path, some features may be not available."
)
return
paths
[
-
1
]
return
paths
[
-
1
]
...
@@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:
...
@@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:
def
_get_config_path
()
->
os
.
PathLike
:
def
_get_config_path
()
->
os
.
PathLike
:
r
"""
r
"""Get the path to user config."""
Gets the path to user config.
"""
return
os
.
path
.
join
(
DEFAULT_CACHE_DIR
,
USER_CONFIG
)
return
os
.
path
.
join
(
DEFAULT_CACHE_DIR
,
USER_CONFIG
)
def
load_config
()
->
Dict
[
str
,
Union
[
str
,
Dict
[
str
,
Any
]]]:
def
load_config
()
->
dict
[
str
,
Union
[
str
,
dict
[
str
,
Any
]]]:
r
"""
r
"""Load user config if exists."""
Loads user config if exists.
"""
try
:
try
:
with
open
(
_get_config_path
(),
encoding
=
"utf-8"
)
as
f
:
with
open
(
_get_config_path
(),
encoding
=
"utf-8"
)
as
f
:
return
safe_load
(
f
)
return
safe_load
(
f
)
...
@@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
...
@@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
def
save_config
(
lang
:
str
,
model_name
:
Optional
[
str
]
=
None
,
model_path
:
Optional
[
str
]
=
None
)
->
None
:
def
save_config
(
lang
:
str
,
model_name
:
Optional
[
str
]
=
None
,
model_path
:
Optional
[
str
]
=
None
)
->
None
:
r
"""
r
"""Save user config."""
Saves user config.
"""
os
.
makedirs
(
DEFAULT_CACHE_DIR
,
exist_ok
=
True
)
os
.
makedirs
(
DEFAULT_CACHE_DIR
,
exist_ok
=
True
)
user_config
=
load_config
()
user_config
=
load_config
()
user_config
[
"lang"
]
=
lang
or
user_config
[
"lang"
]
user_config
[
"lang"
]
=
lang
or
user_config
[
"lang"
]
...
@@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
...
@@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
def
get_model_path
(
model_name
:
str
)
->
str
:
def
get_model_path
(
model_name
:
str
)
->
str
:
r
"""
r
"""Get the model path according to the model name."""
Gets the model path according to the model name.
"""
user_config
=
load_config
()
user_config
=
load_config
()
path_dict
:
D
ict
[
"
DownloadSource
"
,
str
]
=
SUPPORTED_MODELS
.
get
(
model_name
,
defaultdict
(
str
))
path_dict
:
d
ict
[
DownloadSource
,
str
]
=
SUPPORTED_MODELS
.
get
(
model_name
,
defaultdict
(
str
))
model_path
=
user_config
[
"path_dict"
].
get
(
model_name
,
""
)
or
path_dict
.
get
(
DownloadSource
.
DEFAULT
,
""
)
model_path
=
user_config
[
"path_dict"
].
get
(
model_name
,
""
)
or
path_dict
.
get
(
DownloadSource
.
DEFAULT
,
""
)
if
(
if
(
use_modelscope
()
use_modelscope
()
...
@@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str:
...
@@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str:
def
get_template
(
model_name
:
str
)
->
str
:
def
get_template
(
model_name
:
str
)
->
str
:
r
"""
r
"""Get the template name if the model is a chat/distill/instruct model."""
Gets the template name if the model is a chat/distill/instruct model.
"""
return
DEFAULT_TEMPLATE
.
get
(
model_name
,
"default"
)
return
DEFAULT_TEMPLATE
.
get
(
model_name
,
"default"
)
def
get_time
()
->
str
:
def
get_time
()
->
str
:
r
"""
r
"""Get current date and time."""
Gets current date and time.
"""
return
datetime
.
now
().
strftime
(
r
"%Y-%m-%d-%H-%M-%S"
)
return
datetime
.
now
().
strftime
(
r
"%Y-%m-%d-%H-%M-%S"
)
def
is_multimodal
(
model_name
:
str
)
->
bool
:
def
is_multimodal
(
model_name
:
str
)
->
bool
:
r
"""
r
"""Judge if the model is a vision language model."""
Judges if the model is a vision language model.
"""
return
model_name
in
MULTIMODAL_SUPPORTED_MODELS
return
model_name
in
MULTIMODAL_SUPPORTED_MODELS
def
load_dataset_info
(
dataset_dir
:
str
)
->
Dict
[
str
,
Dict
[
str
,
Any
]]:
def
load_dataset_info
(
dataset_dir
:
str
)
->
dict
[
str
,
dict
[
str
,
Any
]]:
r
"""
r
"""Load dataset_info.json."""
Loads dataset_info.json.
"""
if
dataset_dir
==
"ONLINE"
or
dataset_dir
.
startswith
(
"REMOTE:"
):
if
dataset_dir
==
"ONLINE"
or
dataset_dir
.
startswith
(
"REMOTE:"
):
logger
.
info_rank0
(
f
"dataset_dir is
{
dataset_dir
}
, using online dataset."
)
logger
.
info_rank0
(
f
"dataset_dir is
{
dataset_dir
}
, using online dataset."
)
return
{}
return
{}
...
@@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
...
@@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return
{}
return
{}
def
load_args
(
config_path
:
str
)
->
Optional
[
Dict
[
str
,
Any
]]:
def
load_args
(
config_path
:
str
)
->
Optional
[
dict
[
str
,
Any
]]:
r
"""
r
"""Load the training configuration from config path."""
Loads the training configuration from config path.
"""
try
:
try
:
with
open
(
config_path
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
config_path
,
encoding
=
"utf-8"
)
as
f
:
return
safe_load
(
f
)
return
safe_load
(
f
)
...
@@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
...
@@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
return
None
return
None
def
save_args
(
config_path
:
str
,
config_dict
:
Dict
[
str
,
Any
])
->
None
:
def
save_args
(
config_path
:
str
,
config_dict
:
dict
[
str
,
Any
])
->
None
:
r
"""
r
"""Save the training configuration to config path."""
Saves the training configuration to config path.
"""
with
open
(
config_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
config_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
safe_dump
(
config_dict
,
f
)
safe_dump
(
config_dict
,
f
)
def
_clean_cmd
(
args
:
Dict
[
str
,
Any
])
->
Dict
[
str
,
Any
]:
def
_clean_cmd
(
args
:
dict
[
str
,
Any
])
->
dict
[
str
,
Any
]:
r
"""
r
"""Remove args with NoneType or False or empty string value."""
Removes args with NoneType or False or empty string value.
"""
no_skip_keys
=
[
"packing"
]
no_skip_keys
=
[
"packing"
]
return
{
k
:
v
for
k
,
v
in
args
.
items
()
if
(
k
in
no_skip_keys
)
or
(
v
is
not
None
and
v
is
not
False
and
v
!=
""
)}
return
{
k
:
v
for
k
,
v
in
args
.
items
()
if
(
k
in
no_skip_keys
)
or
(
v
is
not
None
and
v
is
not
False
and
v
!=
""
)}
def
gen_cmd
(
args
:
Dict
[
str
,
Any
])
->
str
:
def
gen_cmd
(
args
:
dict
[
str
,
Any
])
->
str
:
r
"""
r
"""Generate CLI commands for previewing."""
Generates CLI commands for previewing.
"""
cmd_lines
=
[
"llamafactory-cli train "
]
cmd_lines
=
[
"llamafactory-cli train "
]
for
k
,
v
in
_clean_cmd
(
args
).
items
():
for
k
,
v
in
_clean_cmd
(
args
).
items
():
if
isinstance
(
v
,
dict
):
if
isinstance
(
v
,
dict
):
...
@@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
...
@@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
return
cmd_text
return
cmd_text
def
save_cmd
(
args
:
Dict
[
str
,
Any
])
->
str
:
def
save_cmd
(
args
:
dict
[
str
,
Any
])
->
str
:
r
"""
r
"""Save CLI commands to launch training."""
Saves CLI commands to launch training.
"""
output_dir
=
args
[
"output_dir"
]
output_dir
=
args
[
"output_dir"
]
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
output_dir
,
TRAINING_ARGS
),
"w"
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
os
.
path
.
join
(
output_dir
,
TRAINING_ARGS
),
"w"
,
encoding
=
"utf-8"
)
as
f
:
...
@@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:
...
@@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:
def
load_eval_results
(
path
:
os
.
PathLike
)
->
str
:
def
load_eval_results
(
path
:
os
.
PathLike
)
->
str
:
r
"""
r
"""Get scores after evaluation."""
Gets scores after evaluation.
"""
with
open
(
path
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
path
,
encoding
=
"utf-8"
)
as
f
:
result
=
json
.
dumps
(
json
.
load
(
f
),
indent
=
4
)
result
=
json
.
dumps
(
json
.
load
(
f
),
indent
=
4
)
...
@@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:
...
@@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:
def
create_ds_config
()
->
None
:
def
create_ds_config
()
->
None
:
r
"""
r
"""Create deepspeed config in the current directory."""
Creates deepspeed config in the current directory.
"""
os
.
makedirs
(
DEFAULT_CACHE_DIR
,
exist_ok
=
True
)
os
.
makedirs
(
DEFAULT_CACHE_DIR
,
exist_ok
=
True
)
ds_config
=
{
ds_config
=
{
"train_batch_size"
:
"auto"
,
"train_batch_size"
:
"auto"
,
...
...
src/llamafactory/webui/components/chatbot.py
View file @
7ea81099
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
import
json
import
json
from
typing
import
TYPE_CHECKING
,
Dict
,
Tuple
from
typing
import
TYPE_CHECKING
from
...data
import
Role
from
...data
import
Role
from
...extras.packages
import
is_gradio_available
from
...extras.packages
import
is_gradio_available
...
@@ -31,9 +31,7 @@ if TYPE_CHECKING:
...
@@ -31,9 +31,7 @@ if TYPE_CHECKING:
def
check_json_schema
(
text
:
str
,
lang
:
str
)
->
None
:
def
check_json_schema
(
text
:
str
,
lang
:
str
)
->
None
:
r
"""
r
"""Check if the json schema is valid."""
Checks if the json schema is valid.
"""
try
:
try
:
tools
=
json
.
loads
(
text
)
tools
=
json
.
loads
(
text
)
if
tools
:
if
tools
:
...
@@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None:
...
@@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None:
def
create_chat_box
(
def
create_chat_box
(
engine
:
"Engine"
,
visible
:
bool
=
False
engine
:
"Engine"
,
visible
:
bool
=
False
)
->
T
uple
[
"Component"
,
"Component"
,
D
ict
[
str
,
"Component"
]]:
)
->
t
uple
[
"Component"
,
"Component"
,
d
ict
[
str
,
"Component"
]]:
lang
=
engine
.
manager
.
get_elem_by_id
(
"top.lang"
)
lang
=
engine
.
manager
.
get_elem_by_id
(
"top.lang"
)
with
gr
.
Column
(
visible
=
visible
)
as
chat_box
:
with
gr
.
Column
(
visible
=
visible
)
as
chat_box
:
chatbot
=
gr
.
Chatbot
(
type
=
"messages"
,
show_copy_button
=
True
)
chatbot
=
gr
.
Chatbot
(
type
=
"messages"
,
show_copy_button
=
True
)
...
...
src/llamafactory/webui/components/data.py
View file @
7ea81099
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
json
import
json
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
from
...extras.constants
import
DATA_CONFIG
from
...extras.constants
import
DATA_CONFIG
from
...extras.packages
import
is_gradio_available
from
...extras.packages
import
is_gradio_available
...
@@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int:
...
@@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int:
def
can_preview
(
dataset_dir
:
str
,
dataset
:
list
)
->
"gr.Button"
:
def
can_preview
(
dataset_dir
:
str
,
dataset
:
list
)
->
"gr.Button"
:
r
"""
r
"""Check if the dataset is a local dataset."""
Checks if the dataset is a local dataset.
"""
try
:
try
:
with
open
(
os
.
path
.
join
(
dataset_dir
,
DATA_CONFIG
),
encoding
=
"utf-8"
)
as
f
:
with
open
(
os
.
path
.
join
(
dataset_dir
,
DATA_CONFIG
),
encoding
=
"utf-8"
)
as
f
:
dataset_info
=
json
.
load
(
f
)
dataset_info
=
json
.
load
(
f
)
...
@@ -59,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
...
@@ -59,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
return
gr
.
Button
(
interactive
=
False
)
return
gr
.
Button
(
interactive
=
False
)
def
_load_data_file
(
file_path
:
str
)
->
L
ist
[
Any
]:
def
_load_data_file
(
file_path
:
str
)
->
l
ist
[
Any
]:
with
open
(
file_path
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
file_path
,
encoding
=
"utf-8"
)
as
f
:
if
file_path
.
endswith
(
".json"
):
if
file_path
.
endswith
(
".json"
):
return
json
.
load
(
f
)
return
json
.
load
(
f
)
...
@@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]:
...
@@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]:
return
list
(
f
)
return
list
(
f
)
def
get_preview
(
dataset_dir
:
str
,
dataset
:
list
,
page_index
:
int
)
->
Tuple
[
int
,
list
,
"gr.Column"
]:
def
get_preview
(
dataset_dir
:
str
,
dataset
:
list
,
page_index
:
int
)
->
tuple
[
int
,
list
,
"gr.Column"
]:
r
"""
r
"""Get the preview samples from the dataset."""
Gets the preview samples from the dataset.
"""
with
open
(
os
.
path
.
join
(
dataset_dir
,
DATA_CONFIG
),
encoding
=
"utf-8"
)
as
f
:
with
open
(
os
.
path
.
join
(
dataset_dir
,
DATA_CONFIG
),
encoding
=
"utf-8"
)
as
f
:
dataset_info
=
json
.
load
(
f
)
dataset_info
=
json
.
load
(
f
)
...
@@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
...
@@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
return
len
(
data
),
data
[
PAGE_SIZE
*
page_index
:
PAGE_SIZE
*
(
page_index
+
1
)],
gr
.
Column
(
visible
=
True
)
return
len
(
data
),
data
[
PAGE_SIZE
*
page_index
:
PAGE_SIZE
*
(
page_index
+
1
)],
gr
.
Column
(
visible
=
True
)
def
create_preview_box
(
dataset_dir
:
"gr.Textbox"
,
dataset
:
"gr.Dropdown"
)
->
D
ict
[
str
,
"Component"
]:
def
create_preview_box
(
dataset_dir
:
"gr.Textbox"
,
dataset
:
"gr.Dropdown"
)
->
d
ict
[
str
,
"Component"
]:
data_preview_btn
=
gr
.
Button
(
interactive
=
False
,
scale
=
1
)
data_preview_btn
=
gr
.
Button
(
interactive
=
False
,
scale
=
1
)
with
gr
.
Column
(
visible
=
False
,
elem_classes
=
"modal-box"
)
as
preview_box
:
with
gr
.
Column
(
visible
=
False
,
elem_classes
=
"modal-box"
)
as
preview_box
:
with
gr
.
Row
():
with
gr
.
Row
():
...
...
src/llamafactory/webui/components/eval.py
View file @
7ea81099
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Dict
from
typing
import
TYPE_CHECKING
from
...extras.packages
import
is_gradio_available
from
...extras.packages
import
is_gradio_available
from
..common
import
DEFAULT_DATA_DIR
from
..common
import
DEFAULT_DATA_DIR
...
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
...
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from
..engine
import
Engine
from
..engine
import
Engine
def
create_eval_tab
(
engine
:
"Engine"
)
->
D
ict
[
str
,
"Component"
]:
def
create_eval_tab
(
engine
:
"Engine"
)
->
d
ict
[
str
,
"Component"
]:
input_elems
=
engine
.
manager
.
get_base_elems
()
input_elems
=
engine
.
manager
.
get_base_elems
()
elem_dict
=
dict
()
elem_dict
=
dict
()
...
...
src/llamafactory/webui/components/export.py
View file @
7ea81099
...
@@ -12,7 +12,8 @@
...
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Dict
,
Generator
,
List
,
Union
from
collections.abc
import
Generator
from
typing
import
TYPE_CHECKING
,
Union
from
...extras.constants
import
PEFT_METHODS
from
...extras.constants
import
PEFT_METHODS
from
...extras.misc
import
torch_gc
from
...extras.misc
import
torch_gc
...
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
...
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
GPTQ_BITS
=
[
"8"
,
"4"
,
"3"
,
"2"
]
GPTQ_BITS
=
[
"8"
,
"4"
,
"3"
,
"2"
]
def
can_quantize
(
checkpoint_path
:
Union
[
str
,
L
ist
[
str
]])
->
"gr.Dropdown"
:
def
can_quantize
(
checkpoint_path
:
Union
[
str
,
l
ist
[
str
]])
->
"gr.Dropdown"
:
if
isinstance
(
checkpoint_path
,
list
)
and
len
(
checkpoint_path
)
!=
0
:
if
isinstance
(
checkpoint_path
,
list
)
and
len
(
checkpoint_path
)
!=
0
:
return
gr
.
Dropdown
(
value
=
"none"
,
interactive
=
False
)
return
gr
.
Dropdown
(
value
=
"none"
,
interactive
=
False
)
else
:
else
:
...
@@ -47,7 +48,7 @@ def save_model(
...
@@ -47,7 +48,7 @@ def save_model(
model_name
:
str
,
model_name
:
str
,
model_path
:
str
,
model_path
:
str
,
finetuning_type
:
str
,
finetuning_type
:
str
,
checkpoint_path
:
Union
[
str
,
L
ist
[
str
]],
checkpoint_path
:
Union
[
str
,
l
ist
[
str
]],
template
:
str
,
template
:
str
,
export_size
:
int
,
export_size
:
int
,
export_quantization_bit
:
str
,
export_quantization_bit
:
str
,
...
@@ -106,7 +107,7 @@ def save_model(
...
@@ -106,7 +107,7 @@ def save_model(
yield
ALERTS
[
"info_exported"
][
lang
]
yield
ALERTS
[
"info_exported"
][
lang
]
def
create_export_tab
(
engine
:
"Engine"
)
->
D
ict
[
str
,
"Component"
]:
def
create_export_tab
(
engine
:
"Engine"
)
->
d
ict
[
str
,
"Component"
]:
with
gr
.
Row
():
with
gr
.
Row
():
export_size
=
gr
.
Slider
(
minimum
=
1
,
maximum
=
100
,
value
=
5
,
step
=
1
)
export_size
=
gr
.
Slider
(
minimum
=
1
,
maximum
=
100
,
value
=
5
,
step
=
1
)
export_quantization_bit
=
gr
.
Dropdown
(
choices
=
[
"none"
]
+
GPTQ_BITS
,
value
=
"none"
)
export_quantization_bit
=
gr
.
Dropdown
(
choices
=
[
"none"
]
+
GPTQ_BITS
,
value
=
"none"
)
...
...
src/llamafactory/webui/components/infer.py
View file @
7ea81099
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Dict
from
typing
import
TYPE_CHECKING
from
...extras.packages
import
is_gradio_available
from
...extras.packages
import
is_gradio_available
from
..common
import
is_multimodal
from
..common
import
is_multimodal
...
@@ -29,12 +29,12 @@ if TYPE_CHECKING:
...
@@ -29,12 +29,12 @@ if TYPE_CHECKING:
from
..engine
import
Engine
from
..engine
import
Engine
def
create_infer_tab
(
engine
:
"Engine"
)
->
D
ict
[
str
,
"Component"
]:
def
create_infer_tab
(
engine
:
"Engine"
)
->
d
ict
[
str
,
"Component"
]:
input_elems
=
engine
.
manager
.
get_base_elems
()
input_elems
=
engine
.
manager
.
get_base_elems
()
elem_dict
=
dict
()
elem_dict
=
dict
()
with
gr
.
Row
():
with
gr
.
Row
():
infer_backend
=
gr
.
Dropdown
(
choices
=
[
"huggingface"
,
"vllm"
],
value
=
"huggingface"
)
infer_backend
=
gr
.
Dropdown
(
choices
=
[
"huggingface"
,
"vllm"
,
"sglang"
],
value
=
"huggingface"
)
infer_dtype
=
gr
.
Dropdown
(
choices
=
[
"auto"
,
"float16"
,
"bfloat16"
,
"float32"
],
value
=
"auto"
)
infer_dtype
=
gr
.
Dropdown
(
choices
=
[
"auto"
,
"float16"
,
"bfloat16"
,
"float32"
],
value
=
"auto"
)
with
gr
.
Row
():
with
gr
.
Row
():
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment