Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
LLaMA-Factory
Commits
7ea81099
"vscode:/vscode.git/clone" did not exist on "4018b3c5339307844c9cdcfc61e91009930f7c8f"
Commit
7ea81099
authored
Apr 07, 2025
by
chenych
Browse files
update llama4
parent
84987715
Changes
139
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
207 additions
and
246 deletions
+207
-246
src/llamafactory/train/ppo/trainer.py
src/llamafactory/train/ppo/trainer.py
+27
-36
src/llamafactory/train/ppo/workflow.py
src/llamafactory/train/ppo/workflow.py
+4
-4
src/llamafactory/train/pt/trainer.py
src/llamafactory/train/pt/trainer.py
+5
-3
src/llamafactory/train/pt/workflow.py
src/llamafactory/train/pt/workflow.py
+3
-3
src/llamafactory/train/rm/metric.py
src/llamafactory/train/rm/metric.py
+4
-6
src/llamafactory/train/rm/trainer.py
src/llamafactory/train/rm/trainer.py
+8
-12
src/llamafactory/train/rm/workflow.py
src/llamafactory/train/rm/workflow.py
+3
-3
src/llamafactory/train/sft/metric.py
src/llamafactory/train/sft/metric.py
+11
-16
src/llamafactory/train/sft/trainer.py
src/llamafactory/train/sft/trainer.py
+17
-14
src/llamafactory/train/sft/workflow.py
src/llamafactory/train/sft/workflow.py
+3
-3
src/llamafactory/train/test_utils.py
src/llamafactory/train/test_utils.py
+5
-5
src/llamafactory/train/trainer_utils.py
src/llamafactory/train/trainer_utils.py
+48
-34
src/llamafactory/train/tuner.py
src/llamafactory/train/tuner.py
+9
-5
src/llamafactory/webui/chatter.py
src/llamafactory/webui/chatter.py
+15
-18
src/llamafactory/webui/common.py
src/llamafactory/webui/common.py
+26
-60
src/llamafactory/webui/components/chatbot.py
src/llamafactory/webui/components/chatbot.py
+3
-5
src/llamafactory/webui/components/data.py
src/llamafactory/webui/components/data.py
+6
-10
src/llamafactory/webui/components/eval.py
src/llamafactory/webui/components/eval.py
+2
-2
src/llamafactory/webui/components/export.py
src/llamafactory/webui/components/export.py
+5
-4
src/llamafactory/webui/components/infer.py
src/llamafactory/webui/components/infer.py
+3
-3
No files found.
src/llamafactory/train/ppo/trainer.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
...
...
@@ -20,7 +20,7 @@ import os
import
sys
import
warnings
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
from
accelerate.utils
import
DistributedDataParallelKwargs
...
...
@@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)
class
CustomPPOTrainer
(
PPOTrainer
,
Trainer
):
r
"""
Inherits PPOTrainer.
"""
r
"""Inherit PPOTrainer."""
def
__init__
(
self
,
...
...
@@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
generating_args
:
"GeneratingArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]],
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]],
model
:
"AutoModelForCausalLMWithValueHead"
,
reward_model
:
Optional
[
"AutoModelForCausalLMWithValueHead"
],
ref_model
:
Optional
[
"AutoModelForCausalLMWithValueHead"
],
...
...
@@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
add_callback
(
BAdamCallback
)
def
ppo_train
(
self
,
resume_from_checkpoint
:
Optional
[
str
]
=
None
)
->
None
:
r
"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
r
"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
if
resume_from_checkpoint
is
not
None
:
raise
ValueError
(
"`resume_from_checkpoint` will be supported in the future version."
)
...
...
@@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger
.
info_rank0
(
f
" Num Epochs =
{
num_train_epochs
:,
}
"
)
logger
.
info_rank0
(
f
" Instantaneous batch size per device =
{
self
.
args
.
per_device_train_batch_size
:,
}
"
)
logger
.
info_rank0
(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}"
.
format
(
total_train_batch_size
)
f
" Total train batch size (w. parallel, buffer, distributed & accumulation) =
{
total_train_batch_size
:,
}
"
)
logger
.
info_rank0
(
f
" Gradient Accumulation steps =
{
self
.
args
.
gradient_accumulation_steps
:,
}
"
)
logger
.
info_rank0
(
f
" Num optimization epochs per batch =
{
self
.
finetuning_args
.
ppo_epochs
:,
}
"
)
...
...
@@ -247,9 +241,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
tokenizer
.
padding_side
=
"right"
# change padding side
queries
,
responses
,
rewards
=
[],
[],
[]
for
idx
in
range
(
0
,
self
.
config
.
batch_size
,
self
.
config
.
mini_batch_size
):
mini_batch_queries
,
mini_batch_responses
=
self
.
get_inputs
(
batch
[
idx
:
idx
+
self
.
config
.
mini_batch_size
]
)
mini_batch
=
{
"input_ids"
:
batch
[
"input_ids"
][
idx
:
idx
+
self
.
config
.
mini_batch_size
],
"attention_mask"
:
batch
[
"attention_mask"
][
idx
:
idx
+
self
.
config
.
mini_batch_size
],
}
mini_batch_queries
,
mini_batch_responses
=
self
.
get_inputs
(
mini_batch
)
mini_batch_rewards
=
self
.
get_rewards
(
mini_batch_queries
,
mini_batch_responses
)
queries
.
extend
(
mini_batch_queries
)
responses
.
extend
(
mini_batch_responses
)
...
...
@@ -339,21 +335,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return
lr_scheduler
@
torch
.
no_grad
()
def
get_inputs
(
self
,
batch
:
Dict
[
str
,
"torch.Tensor"
])
->
Tuple
[
List
[
"torch.Tensor"
],
List
[
"torch.Tensor"
]]:
r
"""
Generates model's responses given queries.
"""
def
get_inputs
(
self
,
batch
:
dict
[
str
,
"torch.Tensor"
])
->
tuple
[
list
[
"torch.Tensor"
],
list
[
"torch.Tensor"
]]:
r
"""Generate model's responses given queries."""
if
batch
[
"input_ids"
].
size
(
0
)
==
1
:
# handle llama2 ppo with gradient accumulation > 1
start_index
=
(
batch
[
"input_ids"
][
0
]
!=
self
.
tokenizer
.
pad_token_id
).
nonzero
()[
0
].
item
()
for
k
,
v
in
batch
.
items
():
batch
[
k
]
=
v
[:,
start_index
:]
with
unwrap_model_for_generation
(
self
.
model
,
self
.
accelerator
)
as
unwrapped_model
:
unwrapped_model
:
"
AutoModelForCausalLMWithValueHead
"
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
unwrapped_model
:
AutoModelForCausalLMWithValueHead
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
if
self
.
model_args
.
upcast_layernorm
:
layernorm_params
=
dump_layernorm
(
unwrapped_model
)
generate_output
:
"
torch.Tensor
"
=
unwrapped_model
.
generate
(
generate_output
:
torch
.
Tensor
=
unwrapped_model
.
generate
(
generation_config
=
self
.
generation_config
,
logits_processor
=
get_logits_processor
(),
**
batch
)
if
self
.
model_args
.
upcast_layernorm
:
...
...
@@ -381,11 +375,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@
torch
.
no_grad
()
def
get_rewards
(
self
,
queries
:
List
[
"torch.Tensor"
],
responses
:
List
[
"torch.Tensor"
],
)
->
List
[
"torch.Tensor"
]:
r
"""
Computes scores using given reward model.
queries
:
list
[
"torch.Tensor"
],
responses
:
list
[
"torch.Tensor"
],
)
->
list
[
"torch.Tensor"
]:
r
"""Compute scores using given reward model.
Both inputs and outputs are put on CPU.
"""
...
...
@@ -394,8 +387,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages
=
self
.
tokenizer
.
batch_decode
(
token_ids
,
skip_special_tokens
=
False
)
return
get_rewards_from_server
(
self
.
reward_model
,
messages
)
batch
:
D
ict
[
str
,
"
torch.Tensor
"
]
=
self
.
prepare_model_inputs
(
queries
,
responses
)
unwrapped_model
:
"
AutoModelForCausalLMWithValueHead
"
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
batch
:
d
ict
[
str
,
torch
.
Tensor
]
=
self
.
prepare_model_inputs
(
queries
,
responses
)
unwrapped_model
:
AutoModelForCausalLMWithValueHead
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
if
self
.
finetuning_args
.
reward_model_type
==
"lora"
:
replace_model
(
unwrapped_model
,
target
=
"reward"
)
...
...
@@ -404,7 +397,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model
=
self
.
reward_model
with
unwrap_model_for_generation
(
reward_model
,
self
.
accelerator
),
self
.
amp_context
:
# support bf16
values
:
"
torch.Tensor
"
=
reward_model
(
**
batch
,
return_dict
=
True
,
use_cache
=
False
)[
-
1
]
values
:
torch
.
Tensor
=
reward_model
(
**
batch
,
return_dict
=
True
,
use_cache
=
False
)[
-
1
]
if
self
.
finetuning_args
.
reward_model_type
==
"lora"
:
replace_model
(
unwrapped_model
,
target
=
"default"
)
...
...
@@ -419,12 +412,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
model
:
"AutoModelForCausalLMWithValueHead"
,
queries
:
"torch.Tensor"
,
responses
:
"torch.Tensor"
,
model_inputs
:
D
ict
[
str
,
Any
],
model_inputs
:
d
ict
[
str
,
Any
],
return_logits
:
bool
=
False
,
response_masks
:
Optional
[
"torch.Tensor"
]
=
None
,
)
->
Tuple
[
"torch.Tensor"
,
Optional
[
"torch.Tensor"
],
"torch.Tensor"
,
"torch.Tensor"
]:
r
"""
Calculates model outputs in multiple batches.
)
->
tuple
[
"torch.Tensor"
,
Optional
[
"torch.Tensor"
],
"torch.Tensor"
,
"torch.Tensor"
]:
r
"""Calculate model outputs in multiple batches.
Subclass and override to inject custom behavior.
"""
...
...
@@ -483,8 +475,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@
override
def
save_model
(
self
,
output_dir
:
Optional
[
str
]
=
None
)
->
None
:
r
"""
Saves model checkpoint.
r
"""Save model checkpoint.
Subclass and override to inject custom behavior.
"""
...
...
@@ -508,5 +499,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
model
.
save_checkpoint
(
output_dir
)
elif
self
.
args
.
should_save
:
unwrapped_model
:
"
AutoModelForCausalLMWithValueHead
"
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
unwrapped_model
:
AutoModelForCausalLMWithValueHead
=
self
.
accelerator
.
unwrap_model
(
self
.
model
)
self
.
_save
(
output_dir
,
state_dict
=
unwrapped_model
.
state_dict
())
src/llamafactory/train/ppo/workflow.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
...
...
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
...data
import
MultiModalDataCollatorForSeq2Seq
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.ploting
import
plot_loss
...
...
@@ -37,7 +37,7 @@ def run_ppo(
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
generating_args
:
"GeneratingArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
,
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
...
...
@@ -53,7 +53,7 @@ def run_ppo(
reward_model
=
create_reward_model
(
model
,
model_args
,
finetuning_args
)
# Initialize our Trainer
ppo_trainer
:
"
CustomPPOTrainer
"
=
CustomPPOTrainer
(
ppo_trainer
:
CustomPPOTrainer
=
CustomPPOTrainer
(
model_args
=
model_args
,
training_args
=
training_args
,
finetuning_args
=
finetuning_args
,
...
...
src/llamafactory/train/pt/trainer.py
View file @
7ea81099
...
...
@@ -31,9 +31,7 @@ if TYPE_CHECKING:
class
CustomTrainer
(
Trainer
):
r
"""
Inherits Trainer for custom optimizer.
"""
r
"""Inherit Trainer for custom optimizer."""
def
__init__
(
self
,
finetuning_args
:
"FinetuningArguments"
,
processor
:
Optional
[
"ProcessorMixin"
],
**
kwargs
...
...
@@ -72,3 +70,7 @@ class CustomTrainer(Trainer):
return
torch
.
utils
.
data
.
SequentialSampler
(
self
.
train_dataset
)
return
super
().
_get_train_sampler
()
@
override
def
compute_loss
(
self
,
model
,
inputs
,
*
args
,
**
kwargs
):
return
super
().
compute_loss
(
model
,
inputs
,
*
args
,
**
kwargs
)
src/llamafactory/train/pt/workflow.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
...
@@ -16,7 +16,7 @@
# limitations under the License.
import
math
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
transformers
import
DataCollatorForLanguageModeling
...
...
@@ -38,7 +38,7 @@ def run_pt(
data_args
:
"DataArguments"
,
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
,
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
...
...
src/llamafactory/train/rm/metric.py
View file @
7ea81099
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
numpy
as
np
...
...
@@ -26,11 +26,9 @@ if TYPE_CHECKING:
@
dataclass
class
ComputeAccuracy
:
r
"""
Computes reward accuracy and supports `batch_eval_metrics`.
"""
r
"""Compute reward accuracy and support `batch_eval_metrics`."""
def
_dump
(
self
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
_dump
(
self
)
->
Optional
[
d
ict
[
str
,
float
]]:
result
=
None
if
hasattr
(
self
,
"score_dict"
):
result
=
{
k
:
float
(
np
.
mean
(
v
))
for
k
,
v
in
self
.
score_dict
.
items
()}
...
...
@@ -41,7 +39,7 @@ class ComputeAccuracy:
def
__post_init__
(
self
):
self
.
_dump
()
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
d
ict
[
str
,
float
]]:
chosen_scores
,
rejected_scores
=
numpify
(
eval_preds
.
predictions
[
0
]),
numpify
(
eval_preds
.
predictions
[
1
])
if
not
chosen_scores
.
shape
:
self
.
score_dict
[
"accuracy"
].
append
(
chosen_scores
>
rejected_scores
)
...
...
src/llamafactory/train/rm/trainer.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
...
...
@@ -18,7 +18,7 @@
import
json
import
os
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
from
transformers
import
Trainer
...
...
@@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)
class
PairwiseTrainer
(
Trainer
):
r
"""
Inherits Trainer to compute pairwise loss.
"""
r
"""Inherits Trainer to compute pairwise loss."""
def
__init__
(
self
,
finetuning_args
:
"FinetuningArguments"
,
processor
:
Optional
[
"ProcessorMixin"
],
**
kwargs
...
...
@@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):
@
override
def
compute_loss
(
self
,
model
:
"PreTrainedModel"
,
inputs
:
Dict
[
str
,
"torch.Tensor"
],
return_outputs
:
bool
=
False
,
**
kwargs
)
->
Union
[
"torch.Tensor"
,
Tuple
[
"torch.Tensor"
,
List
[
"torch.Tensor"
]]]:
r
"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
self
,
model
:
"PreTrainedModel"
,
inputs
:
dict
[
str
,
"torch.Tensor"
],
return_outputs
:
bool
=
False
,
**
kwargs
)
->
Union
[
"torch.Tensor"
,
tuple
[
"torch.Tensor"
,
list
[
"torch.Tensor"
]]]:
r
"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.
Subclass and override to inject custom behavior.
...
...
@@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
return
loss
def
save_predictions
(
self
,
predict_results
:
"PredictionOutput"
)
->
None
:
r
"""
Saves model predictions to `output_dir`.
r
"""Save model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
...
...
@@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
chosen_scores
,
rejected_scores
=
predict_results
.
predictions
with
open
(
output_prediction_file
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
res
:
L
ist
[
str
]
=
[]
res
:
l
ist
[
str
]
=
[]
for
c_score
,
r_score
in
zip
(
chosen_scores
,
rejected_scores
):
res
.
append
(
json
.
dumps
({
"chosen"
:
round
(
float
(
c_score
),
2
),
"rejected"
:
round
(
float
(
r_score
),
2
)}))
...
...
src/llamafactory/train/rm/workflow.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
...
...
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.ploting
import
plot_loss
...
...
@@ -37,7 +37,7 @@ def run_rm(
data_args
:
"DataArguments"
,
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
,
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
...
...
src/llamafactory/train/sft/metric.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc., THUDM, and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc., THUDM, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
...
...
@@ -17,7 +17,7 @@
# limitations under the License.
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
numpy
as
np
import
torch
...
...
@@ -37,17 +37,15 @@ if is_jieba_available():
if
is_nltk_available
():
from
nltk.translate.bleu_score
import
SmoothingFunction
,
sentence_bleu
from
nltk.translate.bleu_score
import
SmoothingFunction
,
sentence_bleu
# type: ignore
if
is_rouge_available
():
from
rouge_chinese
import
Rouge
from
rouge_chinese
import
Rouge
# type: ignore
def
eval_logit_processor
(
logits
:
"torch.Tensor"
,
labels
:
"torch.Tensor"
)
->
"torch.Tensor"
:
r
"""
Computes the token with the largest likelihood to reduce memory footprint.
"""
r
"""Compute the token with the largest likelihood to reduce memory footprint."""
if
isinstance
(
logits
,
(
list
,
tuple
)):
if
logits
[
0
].
dim
()
==
3
:
# (batch_size, seq_len, vocab_size)
logits
=
logits
[
0
]
...
...
@@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
@
dataclass
class
ComputeAccuracy
:
r
"""
Computes accuracy and supports `batch_eval_metrics`.
"""
r
"""Compute accuracy and support `batch_eval_metrics`."""
def
_dump
(
self
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
_dump
(
self
)
->
Optional
[
d
ict
[
str
,
float
]]:
result
=
None
if
hasattr
(
self
,
"score_dict"
):
result
=
{
k
:
float
(
np
.
mean
(
v
))
for
k
,
v
in
self
.
score_dict
.
items
()}
...
...
@@ -77,7 +73,7 @@ class ComputeAccuracy:
def
__post_init__
(
self
):
self
.
_dump
()
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
d
ict
[
str
,
float
]]:
preds
,
labels
=
numpify
(
eval_preds
.
predictions
),
numpify
(
eval_preds
.
label_ids
)
for
i
in
range
(
len
(
preds
)):
pred
,
label
=
preds
[
i
,
:
-
1
],
labels
[
i
,
1
:]
...
...
@@ -90,15 +86,14 @@ class ComputeAccuracy:
@
dataclass
class
ComputeSimilarity
:
r
"""
Computes text similarity scores and supports `batch_eval_metrics`.
r
"""Compute text similarity scores and support `batch_eval_metrics`.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
"""
tokenizer
:
"PreTrainedTokenizer"
def
_dump
(
self
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
_dump
(
self
)
->
Optional
[
d
ict
[
str
,
float
]]:
result
=
None
if
hasattr
(
self
,
"score_dict"
):
result
=
{
k
:
float
(
np
.
mean
(
v
))
for
k
,
v
in
self
.
score_dict
.
items
()}
...
...
@@ -109,7 +104,7 @@ class ComputeSimilarity:
def
__post_init__
(
self
):
self
.
_dump
()
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
D
ict
[
str
,
float
]]:
def
__call__
(
self
,
eval_preds
:
"EvalPrediction"
,
compute_result
:
bool
=
True
)
->
Optional
[
d
ict
[
str
,
float
]]:
preds
,
labels
=
numpify
(
eval_preds
.
predictions
),
numpify
(
eval_preds
.
label_ids
)
preds
=
np
.
where
(
preds
!=
IGNORE_INDEX
,
preds
,
self
.
tokenizer
.
pad_token_id
)
...
...
src/llamafactory/train/sft/trainer.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
...
...
@@ -18,7 +18,7 @@
import
json
import
os
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
import
numpy
as
np
import
torch
...
...
@@ -44,23 +44,24 @@ logger = logging.get_logger(__name__)
class
CustomSeq2SeqTrainer
(
Seq2SeqTrainer
):
r
"""
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""
r
"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""
def
__init__
(
self
,
finetuning_args
:
"FinetuningArguments"
,
processor
:
Optional
[
"ProcessorMixin"
],
gen_kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
gen_kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
**
kwargs
,
)
->
None
:
if
is_transformers_version_greater_than
(
"4.46"
):
kwargs
[
"processing_class"
]
=
kwargs
.
pop
(
"tokenizer"
)
else
:
self
.
processing_class
:
"
PreTrainedTokenizer
"
=
kwargs
.
get
(
"tokenizer"
)
self
.
processing_class
:
PreTrainedTokenizer
=
kwargs
.
get
(
"tokenizer"
)
super
().
__init__
(
**
kwargs
)
if
processor
is
not
None
:
self
.
model_accepts_loss_kwargs
=
False
self
.
finetuning_args
=
finetuning_args
if
gen_kwargs
is
not
None
:
# https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
...
...
@@ -95,17 +96,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return
super
().
_get_train_sampler
()
@
override
def
compute_loss
(
self
,
model
,
inputs
,
*
args
,
**
kwargs
):
return
super
().
compute_loss
(
model
,
inputs
,
*
args
,
**
kwargs
)
@
override
def
prediction_step
(
self
,
model
:
"torch.nn.Module"
,
inputs
:
D
ict
[
str
,
Union
[
"torch.Tensor"
,
Any
]],
inputs
:
d
ict
[
str
,
Union
[
"torch.Tensor"
,
Any
]],
prediction_loss_only
:
bool
,
ignore_keys
:
Optional
[
L
ist
[
str
]]
=
None
,
ignore_keys
:
Optional
[
l
ist
[
str
]]
=
None
,
**
gen_kwargs
,
)
->
Tuple
[
Optional
[
float
],
Optional
[
"torch.Tensor"
],
Optional
[
"torch.Tensor"
]]:
r
"""
Removes the prompt part in the generated tokens.
)
->
tuple
[
Optional
[
float
],
Optional
[
"torch.Tensor"
],
Optional
[
"torch.Tensor"
]]:
r
"""Remove the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
"""
...
...
@@ -126,8 +130,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def
save_predictions
(
self
,
dataset
:
"Dataset"
,
predict_results
:
"PredictionOutput"
,
skip_special_tokens
:
bool
=
True
)
->
None
:
r
"""
Saves model predictions to `output_dir`.
r
"""Save model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
...
...
src/llamafactory/train/sft/workflow.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
...
...
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
...data
import
SFTDataCollatorWith4DAttentionMask
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.constants
import
IGNORE_INDEX
...
...
@@ -43,7 +43,7 @@ def run_sft(
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
generating_args
:
"GeneratingArguments"
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
,
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
...
...
src/llamafactory/train/test_utils.py
View file @
7ea81099
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Dict
,
Optional
,
Sequence
,
Set
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
from
peft
import
PeftModel
...
...
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
from
..data.data_utils
import
DatasetModule
def
compare_model
(
model_a
:
"torch.nn.Module"
,
model_b
:
"torch.nn.Module"
,
diff_keys
:
Sequence
[
str
]
=
[])
->
None
:
def
compare_model
(
model_a
:
"torch.nn.Module"
,
model_b
:
"torch.nn.Module"
,
diff_keys
:
list
[
str
]
=
[])
->
None
:
state_dict_a
=
model_a
.
state_dict
()
state_dict_b
=
model_b
.
state_dict
()
assert
set
(
state_dict_a
.
keys
())
==
set
(
state_dict_b
.
keys
())
...
...
@@ -43,7 +43,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert
torch
.
allclose
(
state_dict_a
[
name
],
state_dict_b
[
name
],
rtol
=
1e-4
,
atol
=
1e-5
)
is
True
def
check_lora_model
(
model
:
"LoraModel"
)
->
T
uple
[
S
et
[
str
],
S
et
[
str
]]:
def
check_lora_model
(
model
:
"LoraModel"
)
->
t
uple
[
s
et
[
str
],
s
et
[
str
]]:
linear_modules
,
extra_modules
=
set
(),
set
()
for
name
,
param
in
model
.
named_parameters
():
if
any
(
module
in
name
for
module
in
[
"lora_A"
,
"lora_B"
]):
...
...
@@ -83,7 +83,7 @@ def load_reference_model(
)
->
Union
[
"PreTrainedModel"
,
"LoraModel"
]:
current_device
=
get_current_device
()
if
add_valuehead
:
model
:
"
AutoModelForCausalLMWithValueHead
"
=
AutoModelForCausalLMWithValueHead
.
from_pretrained
(
model
:
AutoModelForCausalLMWithValueHead
=
AutoModelForCausalLMWithValueHead
.
from_pretrained
(
model_path
,
torch_dtype
=
torch
.
float16
,
device_map
=
current_device
)
if
not
is_trainable
:
...
...
@@ -111,7 +111,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":
def
patch_valuehead_model
()
->
None
:
def
post_init
(
self
:
"AutoModelForCausalLMWithValueHead"
,
state_dict
:
D
ict
[
str
,
"torch.Tensor"
])
->
None
:
def
post_init
(
self
:
"AutoModelForCausalLMWithValueHead"
,
state_dict
:
d
ict
[
str
,
"torch.Tensor"
])
->
None
:
state_dict
=
{
k
[
7
:]:
state_dict
[
k
]
for
k
in
state_dict
.
keys
()
if
k
.
startswith
(
"v_head."
)}
self
.
v_head
.
load_state_dict
(
state_dict
,
strict
=
False
)
del
state_dict
...
...
src/llamafactory/train/trainer_utils.py
View file @
7ea81099
# Copyright 202
4
HuggingFace Inc. and the LlamaFactory team.
# Copyright 202
5
HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
...
...
@@ -21,7 +21,7 @@ import json
import
os
from
collections.abc
import
Mapping
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
,
Union
import
torch
from
transformers
import
Trainer
...
...
@@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)
class
DummyOptimizer
(
torch
.
optim
.
Optimizer
):
r
"""
A dummy optimizer used for the GaLore or APOLLO algorithm.
"""
r
"""A dummy optimizer used for the GaLore or APOLLO algorithm."""
def
__init__
(
self
,
lr
:
float
=
1e-3
,
optimizer_dict
:
Optional
[
D
ict
[
"torch.nn.Parameter"
,
"torch.optim.Optimizer"
]]
=
None
self
,
lr
:
float
=
1e-3
,
optimizer_dict
:
Optional
[
d
ict
[
"torch.nn.Parameter"
,
"torch.optim.Optimizer"
]]
=
None
)
->
None
:
dummy_tensor
=
torch
.
randn
(
1
,
1
)
self
.
optimizer_dict
=
optimizer_dict
...
...
@@ -112,8 +110,7 @@ def create_modelcard_and_push(
def
create_ref_model
(
model_args
:
"ModelArguments"
,
finetuning_args
:
"FinetuningArguments"
,
add_valuehead
:
bool
=
False
)
->
Optional
[
Union
[
"PreTrainedModel"
,
"AutoModelForCausalLMWithValueHead"
]]:
r
"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
r
"""Create reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
...
...
@@ -148,9 +145,7 @@ def create_ref_model(
def
create_reward_model
(
model
:
"AutoModelForCausalLMWithValueHead"
,
model_args
:
"ModelArguments"
,
finetuning_args
:
"FinetuningArguments"
)
->
Optional
[
"AutoModelForCausalLMWithValueHead"
]:
r
"""
Creates reward model for PPO training.
"""
r
"""Create reward model for PPO training."""
if
finetuning_args
.
reward_model_type
==
"api"
:
assert
finetuning_args
.
reward_model
.
startswith
(
"http"
),
"Please provide full url."
logger
.
info_rank0
(
f
"Use reward server
{
finetuning_args
.
reward_model
}
"
)
...
...
@@ -189,10 +184,8 @@ def create_reward_model(
return
reward_model
def
_get_decay_parameter_names
(
model
:
"PreTrainedModel"
)
->
List
[
str
]:
r
"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
"""
def
_get_decay_parameter_names
(
model
:
"PreTrainedModel"
)
->
list
[
str
]:
r
"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
decay_parameters
=
get_parameter_names
(
model
,
ALL_LAYERNORM_LAYERS
)
decay_parameters
=
[
name
for
name
in
decay_parameters
if
"bias"
not
in
name
]
return
decay_parameters
...
...
@@ -208,7 +201,7 @@ def _create_galore_optimizer(
else
:
galore_targets
=
finetuning_args
.
galore_target
galore_params
:
L
ist
[
"
torch.nn.Parameter
"
]
=
[]
galore_params
:
l
ist
[
torch
.
nn
.
Parameter
]
=
[]
for
name
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
torch
.
nn
.
Linear
)
and
any
(
target
in
name
for
target
in
galore_targets
):
for
param
in
module
.
parameters
():
...
...
@@ -224,7 +217,7 @@ def _create_galore_optimizer(
id_galore_params
=
{
id
(
param
)
for
param
in
galore_params
}
decay_params
,
nodecay_params
=
[],
[]
# they are non-galore parameters
trainable_params
:
L
ist
[
"
torch.nn.Parameter
"
]
=
[]
# galore_params + decay_params + nodecay_params
trainable_params
:
l
ist
[
torch
.
nn
.
Parameter
]
=
[]
# galore_params + decay_params + nodecay_params
decay_param_names
=
_get_decay_parameter_names
(
model
)
for
name
,
param
in
model
.
named_parameters
():
if
param
.
requires_grad
:
...
...
@@ -251,7 +244,7 @@ def _create_galore_optimizer(
if
training_args
.
gradient_accumulation_steps
!=
1
:
raise
ValueError
(
"Per-layer GaLore does not support gradient accumulation."
)
optimizer_dict
:
D
ict
[
"
torch.Tensor
"
,
"
torch.optim.Optimizer
"
]
=
{}
optimizer_dict
:
d
ict
[
torch
.
Tensor
,
torch
.
optim
.
Optimizer
]
=
{}
for
param
in
nodecay_params
:
param_groups
=
[
dict
(
params
=
[
param
],
weight_decay
=
0.0
)]
optimizer_dict
[
param
]
=
optim_class
(
param_groups
,
**
optim_kwargs
)
...
...
@@ -296,7 +289,7 @@ def _create_apollo_optimizer(
else
:
apollo_targets
=
finetuning_args
.
apollo_target
apollo_params
:
L
ist
[
"
torch.nn.Parameter
"
]
=
[]
apollo_params
:
l
ist
[
torch
.
nn
.
Parameter
]
=
[]
for
name
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
torch
.
nn
.
Linear
)
and
any
(
target
in
name
for
target
in
apollo_targets
):
for
param
in
module
.
parameters
():
...
...
@@ -315,7 +308,7 @@ def _create_apollo_optimizer(
id_apollo_params
=
{
id
(
param
)
for
param
in
apollo_params
}
decay_params
,
nodecay_params
=
[],
[]
# they are non-apollo parameters
trainable_params
:
L
ist
[
"
torch.nn.Parameter
"
]
=
[]
# apollo_params + decay_params + nodecay_params
trainable_params
:
l
ist
[
torch
.
nn
.
Parameter
]
=
[]
# apollo_params + decay_params + nodecay_params
decay_param_names
=
_get_decay_parameter_names
(
model
)
for
name
,
param
in
model
.
named_parameters
():
if
param
.
requires_grad
:
...
...
@@ -338,7 +331,7 @@ def _create_apollo_optimizer(
if
training_args
.
gradient_accumulation_steps
!=
1
:
raise
ValueError
(
"Per-layer APOLLO does not support gradient accumulation."
)
optimizer_dict
:
D
ict
[
"
torch.Tensor
"
,
"
torch.optim.Optimizer
"
]
=
{}
optimizer_dict
:
d
ict
[
torch
.
Tensor
,
torch
.
optim
.
Optimizer
]
=
{}
for
param
in
nodecay_params
:
param_groups
=
[
dict
(
params
=
[
param
],
weight_decay
=
0.0
)]
optimizer_dict
[
param
]
=
optim_class
(
param_groups
,
**
optim_kwargs
)
...
...
@@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
embedding_lr
=
finetuning_args
.
loraplus_lr_embedding
decay_param_names
=
_get_decay_parameter_names
(
model
)
param_dict
:
D
ict
[
str
,
L
ist
[
"
torch.nn.Parameter
"
]]
=
{
param_dict
:
d
ict
[
str
,
l
ist
[
torch
.
nn
.
Parameter
]]
=
{
"lora_a"
:
[],
"lora_b"
:
[],
"lora_b_nodecay"
:
[],
...
...
@@ -522,9 +515,25 @@ def create_custom_scheduler(
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
,
)
->
None
:
if
training_args
.
lr_scheduler_type
==
"warmup_stable_decay"
:
num_warmup_steps
=
training_args
.
get_warmup_steps
(
num_training_steps
)
remaining_steps
=
num_training_steps
-
num_warmup_steps
num_stable_steps
=
remaining_steps
//
3
# use 1/3 for stable by default
num_decay_steps
=
remaining_steps
-
num_stable_steps
scheduler_kwargs
=
training_args
.
lr_scheduler_kwargs
or
{}
default_kwargs
=
{
"num_stable_steps"
:
num_stable_steps
,
"num_decay_steps"
:
num_decay_steps
,
}
for
key
,
value
in
default_kwargs
.
items
():
if
key
not
in
scheduler_kwargs
:
scheduler_kwargs
[
key
]
=
value
training_args
.
lr_scheduler_kwargs
=
scheduler_kwargs
if
optimizer
is
not
None
and
isinstance
(
optimizer
,
DummyOptimizer
):
optimizer_dict
=
optimizer
.
optimizer_dict
scheduler_dict
:
D
ict
[
"
torch.nn.Parameter
"
,
"
torch.optim.lr_scheduler.LRScheduler
"
]
=
{}
scheduler_dict
:
d
ict
[
torch
.
nn
.
Parameter
,
torch
.
optim
.
lr_scheduler
.
LRScheduler
]
=
{}
for
param
in
optimizer_dict
.
keys
():
scheduler_dict
[
param
]
=
get_scheduler
(
...
...
@@ -544,13 +553,13 @@ def create_custom_scheduler(
def
get_batch_logps
(
logits
:
"torch.Tensor"
,
labels
:
"torch.Tensor"
,
label_pad_token_id
:
int
=
IGNORE_INDEX
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
]:
r
"""
Computes the log probabilities of the given labels under the given logits.
)
->
tuple
[
"torch.Tensor"
,
"torch.Tensor"
]:
r
"""Compute the log probabilities of the given labels under the given logits.
Returns:
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
"""
if
logits
.
shape
[:
-
1
]
!=
labels
.
shape
:
raise
ValueError
(
"Logits (batchsize x seqlen) and labels must have the same shape."
)
...
...
@@ -564,12 +573,10 @@ def get_batch_logps(
def
nested_detach
(
tensors
:
Union
[
"torch.Tensor"
,
L
ist
[
"torch.Tensor"
],
T
uple
[
"torch.Tensor"
],
D
ict
[
str
,
"torch.Tensor"
]],
tensors
:
Union
[
"torch.Tensor"
,
l
ist
[
"torch.Tensor"
],
t
uple
[
"torch.Tensor"
],
d
ict
[
str
,
"torch.Tensor"
]],
clone
:
bool
=
False
,
):
r
"""
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
"""
r
"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
if
isinstance
(
tensors
,
(
list
,
tuple
)):
return
type
(
tensors
)(
nested_detach
(
t
,
clone
=
clone
)
for
t
in
tensors
)
elif
isinstance
(
tensors
,
Mapping
):
...
...
@@ -585,15 +592,22 @@ def nested_detach(
def
get_swanlab_callback
(
finetuning_args
:
"FinetuningArguments"
)
->
"TrainerCallback"
:
r
"""
Gets the callback for logging to SwanLab.
"""
r
"""Get the callback for logging to SwanLab."""
import
swanlab
# type: ignore
from
swanlab.integration.transformers
import
SwanLabCallback
# type: ignore
if
finetuning_args
.
swanlab_api_key
is
not
None
:
swanlab
.
login
(
api_key
=
finetuning_args
.
swanlab_api_key
)
if
finetuning_args
.
swanlab_lark_webhook_url
is
not
None
:
from
swanlab.plugin.notification
import
LarkCallback
# type: ignore
lark_callback
=
LarkCallback
(
webhook_url
=
finetuning_args
.
swanlab_lark_webhook_url
,
secret
=
finetuning_args
.
swanlab_lark_secret
,
)
swanlab
.
register_callbacks
([
lark_callback
])
class
SwanLabCallbackExtension
(
SwanLabCallback
):
def
setup
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
model
:
"PreTrainedModel"
,
**
kwargs
):
if
not
state
.
is_world_process_zero
:
...
...
@@ -624,7 +638,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
def
get_ray_trainer
(
training_function
:
Callable
,
train_loop_config
:
D
ict
[
str
,
Any
],
train_loop_config
:
d
ict
[
str
,
Any
],
ray_args
:
"RayArguments"
,
)
->
"TorchTrainer"
:
if
not
ray_args
.
use_ray
:
...
...
src/llamafactory/train/tuner.py
View file @
7ea81099
...
...
@@ -14,7 +14,7 @@
import
os
import
shutil
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
import
torch.distributed
as
dist
...
...
@@ -38,6 +38,7 @@ from .trainer_utils import get_ray_trainer, get_swanlab_callback
if
is_ray_available
():
import
ray
from
ray.train.huggingface.transformers
import
RayTrainReportCallback
...
...
@@ -48,9 +49,9 @@ if TYPE_CHECKING:
logger
=
logging
.
get_logger
(
__name__
)
def
_training_function
(
config
:
D
ict
[
str
,
Any
])
->
None
:
def
_training_function
(
config
:
d
ict
[
str
,
Any
])
->
None
:
args
=
config
.
get
(
"args"
)
callbacks
:
L
ist
[
Any
]
=
config
.
get
(
"callbacks"
)
callbacks
:
l
ist
[
Any
]
=
config
.
get
(
"callbacks"
)
model_args
,
data_args
,
training_args
,
finetuning_args
,
generating_args
=
get_train_args
(
args
)
callbacks
.
append
(
LogCallback
())
...
...
@@ -77,6 +78,9 @@ def _training_function(config: Dict[str, Any]) -> None:
else
:
raise
ValueError
(
f
"Unknown task:
{
finetuning_args
.
stage
}
."
)
if
is_ray_available
()
and
ray
.
is_initialized
():
return
# if ray is intialized it will destroy the process group on return
try
:
if
dist
.
is_initialized
():
dist
.
destroy_process_group
()
...
...
@@ -84,7 +88,7 @@ def _training_function(config: Dict[str, Any]) -> None:
logger
.
warning
(
f
"Failed to destroy process group:
{
e
}
."
)
def
run_exp
(
args
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
callbacks
:
Optional
[
L
ist
[
"TrainerCallback"
]]
=
None
)
->
None
:
def
run_exp
(
args
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
callbacks
:
Optional
[
l
ist
[
"TrainerCallback"
]]
=
None
)
->
None
:
args
=
read_args
(
args
)
if
"-h"
in
args
or
"--help"
in
args
:
get_train_args
(
args
)
...
...
@@ -103,7 +107,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
_training_function
(
config
=
{
"args"
:
args
,
"callbacks"
:
callbacks
})
def
export_model
(
args
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
)
->
None
:
def
export_model
(
args
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
)
->
None
:
model_args
,
data_args
,
finetuning_args
,
_
=
get_infer_args
(
args
)
if
model_args
.
export_dir
is
None
:
...
...
src/llamafactory/webui/chatter.py
View file @
7ea81099
...
...
@@ -14,7 +14,8 @@
import
json
import
os
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
from
collections.abc
import
Generator
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
transformers.utils
import
is_torch_npu_available
...
...
@@ -37,15 +38,12 @@ if is_gradio_available():
def
_escape_html
(
text
:
str
)
->
str
:
r
"""
Escapes HTML characters.
"""
r
"""Escape HTML characters."""
return
text
.
replace
(
"<"
,
"<"
).
replace
(
">"
,
">"
)
def
_format_response
(
text
:
str
,
lang
:
str
,
escape_html
:
bool
,
thought_words
:
Tuple
[
str
,
str
])
->
str
:
r
"""
Post-processes the response text.
def
_format_response
(
text
:
str
,
lang
:
str
,
escape_html
:
bool
,
thought_words
:
tuple
[
str
,
str
])
->
str
:
r
"""Post-process the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
...
...
@@ -74,7 +72,7 @@ class WebChatModel(ChatModel):
def
__init__
(
self
,
manager
:
"Manager"
,
demo_mode
:
bool
=
False
,
lazy_init
:
bool
=
True
)
->
None
:
self
.
manager
=
manager
self
.
demo_mode
=
demo_mode
self
.
engine
:
Optional
[
"
BaseEngine
"
]
=
None
self
.
engine
:
Optional
[
BaseEngine
]
=
None
if
not
lazy_init
:
# read arguments from command line
super
().
__init__
()
...
...
@@ -124,6 +122,7 @@ class WebChatModel(ChatModel):
enable_liger_kernel
=
(
get
(
"top.booster"
)
==
"liger_kernel"
),
infer_backend
=
get
(
"infer.infer_backend"
),
infer_dtype
=
get
(
"infer.infer_dtype"
),
vllm_enforce_eager
=
True
,
trust_remote_code
=
True
,
)
...
...
@@ -160,14 +159,13 @@ class WebChatModel(ChatModel):
@
staticmethod
def
append
(
chatbot
:
L
ist
[
D
ict
[
str
,
str
]],
messages
:
L
ist
[
D
ict
[
str
,
str
]],
chatbot
:
l
ist
[
d
ict
[
str
,
str
]],
messages
:
l
ist
[
d
ict
[
str
,
str
]],
role
:
str
,
query
:
str
,
escape_html
:
bool
,
)
->
Tuple
[
List
[
Dict
[
str
,
str
]],
List
[
Dict
[
str
,
str
]],
str
]:
r
"""
Adds the user input to chatbot.
)
->
tuple
[
list
[
dict
[
str
,
str
]],
list
[
dict
[
str
,
str
]],
str
]:
r
"""Add the user input to chatbot.
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Output: infer.chatbot, infer.messages, infer.query
...
...
@@ -180,8 +178,8 @@ class WebChatModel(ChatModel):
def
stream
(
self
,
chatbot
:
L
ist
[
D
ict
[
str
,
str
]],
messages
:
L
ist
[
D
ict
[
str
,
str
]],
chatbot
:
l
ist
[
d
ict
[
str
,
str
]],
messages
:
l
ist
[
d
ict
[
str
,
str
]],
lang
:
str
,
system
:
str
,
tools
:
str
,
...
...
@@ -193,9 +191,8 @@ class WebChatModel(ChatModel):
temperature
:
float
,
skip_special_tokens
:
bool
,
escape_html
:
bool
,
)
->
Generator
[
Tuple
[
List
[
Dict
[
str
,
str
]],
List
[
Dict
[
str
,
str
]]],
None
,
None
]:
r
"""
Generates output text in stream.
)
->
Generator
[
tuple
[
list
[
dict
[
str
,
str
]],
list
[
dict
[
str
,
str
]]],
None
,
None
]:
r
"""Generate output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages
...
...
src/llamafactory/webui/common.py
View file @
7ea81099
...
...
@@ -17,7 +17,7 @@ import os
import
signal
from
collections
import
defaultdict
from
datetime
import
datetime
from
typing
import
Any
,
Dict
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
from
psutil
import
Process
from
yaml
import
safe_dump
,
safe_load
...
...
@@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"
def
abort_process
(
pid
:
int
)
->
None
:
r
"""
Aborts the processes recursively in a bottom-up way.
"""
r
"""Abort the processes recursively in a bottom-up way."""
try
:
children
=
Process
(
pid
).
children
()
if
children
:
...
...
@@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:
def
get_save_dir
(
*
paths
:
str
)
->
os
.
PathLike
:
r
"""
Gets the path to saved model checkpoints.
"""
r
"""Get the path to saved model checkpoints."""
if
os
.
path
.
sep
in
paths
[
-
1
]:
logger
.
warning_rank0
(
"Found complex path, some features may be not available."
)
return
paths
[
-
1
]
...
...
@@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:
def
_get_config_path
()
->
os
.
PathLike
:
r
"""
Gets the path to user config.
"""
r
"""Get the path to user config."""
return
os
.
path
.
join
(
DEFAULT_CACHE_DIR
,
USER_CONFIG
)
def
load_config
()
->
Dict
[
str
,
Union
[
str
,
Dict
[
str
,
Any
]]]:
r
"""
Loads user config if exists.
"""
def
load_config
()
->
dict
[
str
,
Union
[
str
,
dict
[
str
,
Any
]]]:
r
"""Load user config if exists."""
try
:
with
open
(
_get_config_path
(),
encoding
=
"utf-8"
)
as
f
:
return
safe_load
(
f
)
...
...
@@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
def
save_config
(
lang
:
str
,
model_name
:
Optional
[
str
]
=
None
,
model_path
:
Optional
[
str
]
=
None
)
->
None
:
r
"""
Saves user config.
"""
r
"""Save user config."""
os
.
makedirs
(
DEFAULT_CACHE_DIR
,
exist_ok
=
True
)
user_config
=
load_config
()
user_config
[
"lang"
]
=
lang
or
user_config
[
"lang"
]
...
...
@@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
def
get_model_path
(
model_name
:
str
)
->
str
:
r
"""
Gets the model path according to the model name.
"""
r
"""Get the model path according to the model name."""
user_config
=
load_config
()
path_dict
:
D
ict
[
"
DownloadSource
"
,
str
]
=
SUPPORTED_MODELS
.
get
(
model_name
,
defaultdict
(
str
))
path_dict
:
d
ict
[
DownloadSource
,
str
]
=
SUPPORTED_MODELS
.
get
(
model_name
,
defaultdict
(
str
))
model_path
=
user_config
[
"path_dict"
].
get
(
model_name
,
""
)
or
path_dict
.
get
(
DownloadSource
.
DEFAULT
,
""
)
if
(
use_modelscope
()
...
...
@@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str:
def
get_template
(
model_name
:
str
)
->
str
:
r
"""
Gets the template name if the model is a chat/distill/instruct model.
"""
r
"""Get the template name if the model is a chat/distill/instruct model."""
return
DEFAULT_TEMPLATE
.
get
(
model_name
,
"default"
)
def
get_time
()
->
str
:
r
"""
Gets current date and time.
"""
r
"""Get current date and time."""
return
datetime
.
now
().
strftime
(
r
"%Y-%m-%d-%H-%M-%S"
)
def
is_multimodal
(
model_name
:
str
)
->
bool
:
r
"""
Judges if the model is a vision language model.
"""
r
"""Judge if the model is a vision language model."""
return
model_name
in
MULTIMODAL_SUPPORTED_MODELS
def
load_dataset_info
(
dataset_dir
:
str
)
->
Dict
[
str
,
Dict
[
str
,
Any
]]:
r
"""
Loads dataset_info.json.
"""
def
load_dataset_info
(
dataset_dir
:
str
)
->
dict
[
str
,
dict
[
str
,
Any
]]:
r
"""Load dataset_info.json."""
if
dataset_dir
==
"ONLINE"
or
dataset_dir
.
startswith
(
"REMOTE:"
):
logger
.
info_rank0
(
f
"dataset_dir is
{
dataset_dir
}
, using online dataset."
)
return
{}
...
...
@@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return
{}
def
load_args
(
config_path
:
str
)
->
Optional
[
Dict
[
str
,
Any
]]:
r
"""
Loads the training configuration from config path.
"""
def
load_args
(
config_path
:
str
)
->
Optional
[
dict
[
str
,
Any
]]:
r
"""Load the training configuration from config path."""
try
:
with
open
(
config_path
,
encoding
=
"utf-8"
)
as
f
:
return
safe_load
(
f
)
...
...
@@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
return
None
def
save_args
(
config_path
:
str
,
config_dict
:
Dict
[
str
,
Any
])
->
None
:
r
"""
Saves the training configuration to config path.
"""
def
save_args
(
config_path
:
str
,
config_dict
:
dict
[
str
,
Any
])
->
None
:
r
"""Save the training configuration to config path."""
with
open
(
config_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
safe_dump
(
config_dict
,
f
)
def
_clean_cmd
(
args
:
Dict
[
str
,
Any
])
->
Dict
[
str
,
Any
]:
r
"""
Removes args with NoneType or False or empty string value.
"""
def
_clean_cmd
(
args
:
dict
[
str
,
Any
])
->
dict
[
str
,
Any
]:
r
"""Remove args with NoneType or False or empty string value."""
no_skip_keys
=
[
"packing"
]
return
{
k
:
v
for
k
,
v
in
args
.
items
()
if
(
k
in
no_skip_keys
)
or
(
v
is
not
None
and
v
is
not
False
and
v
!=
""
)}
def
gen_cmd
(
args
:
Dict
[
str
,
Any
])
->
str
:
r
"""
Generates CLI commands for previewing.
"""
def
gen_cmd
(
args
:
dict
[
str
,
Any
])
->
str
:
r
"""Generate CLI commands for previewing."""
cmd_lines
=
[
"llamafactory-cli train "
]
for
k
,
v
in
_clean_cmd
(
args
).
items
():
if
isinstance
(
v
,
dict
):
...
...
@@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
return
cmd_text
def
save_cmd
(
args
:
Dict
[
str
,
Any
])
->
str
:
r
"""
Saves CLI commands to launch training.
"""
def
save_cmd
(
args
:
dict
[
str
,
Any
])
->
str
:
r
"""Save CLI commands to launch training."""
output_dir
=
args
[
"output_dir"
]
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
output_dir
,
TRAINING_ARGS
),
"w"
,
encoding
=
"utf-8"
)
as
f
:
...
...
@@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:
def
load_eval_results
(
path
:
os
.
PathLike
)
->
str
:
r
"""
Gets scores after evaluation.
"""
r
"""Get scores after evaluation."""
with
open
(
path
,
encoding
=
"utf-8"
)
as
f
:
result
=
json
.
dumps
(
json
.
load
(
f
),
indent
=
4
)
...
...
@@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:
def
create_ds_config
()
->
None
:
r
"""
Creates deepspeed config in the current directory.
"""
r
"""Create deepspeed config in the current directory."""
os
.
makedirs
(
DEFAULT_CACHE_DIR
,
exist_ok
=
True
)
ds_config
=
{
"train_batch_size"
:
"auto"
,
...
...
src/llamafactory/webui/components/chatbot.py
View file @
7ea81099
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
import
json
from
typing
import
TYPE_CHECKING
,
Dict
,
Tuple
from
typing
import
TYPE_CHECKING
from
...data
import
Role
from
...extras.packages
import
is_gradio_available
...
...
@@ -31,9 +31,7 @@ if TYPE_CHECKING:
def
check_json_schema
(
text
:
str
,
lang
:
str
)
->
None
:
r
"""
Checks if the json schema is valid.
"""
r
"""Check if the json schema is valid."""
try
:
tools
=
json
.
loads
(
text
)
if
tools
:
...
...
@@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None:
def
create_chat_box
(
engine
:
"Engine"
,
visible
:
bool
=
False
)
->
T
uple
[
"Component"
,
"Component"
,
D
ict
[
str
,
"Component"
]]:
)
->
t
uple
[
"Component"
,
"Component"
,
d
ict
[
str
,
"Component"
]]:
lang
=
engine
.
manager
.
get_elem_by_id
(
"top.lang"
)
with
gr
.
Column
(
visible
=
visible
)
as
chat_box
:
chatbot
=
gr
.
Chatbot
(
type
=
"messages"
,
show_copy_button
=
True
)
...
...
src/llamafactory/webui/components/data.py
View file @
7ea81099
...
...
@@ -14,7 +14,7 @@
import
json
import
os
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
from
...extras.constants
import
DATA_CONFIG
from
...extras.packages
import
is_gradio_available
...
...
@@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int:
def
can_preview
(
dataset_dir
:
str
,
dataset
:
list
)
->
"gr.Button"
:
r
"""
Checks if the dataset is a local dataset.
"""
r
"""Check if the dataset is a local dataset."""
try
:
with
open
(
os
.
path
.
join
(
dataset_dir
,
DATA_CONFIG
),
encoding
=
"utf-8"
)
as
f
:
dataset_info
=
json
.
load
(
f
)
...
...
@@ -59,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
return
gr
.
Button
(
interactive
=
False
)
def
_load_data_file
(
file_path
:
str
)
->
L
ist
[
Any
]:
def
_load_data_file
(
file_path
:
str
)
->
l
ist
[
Any
]:
with
open
(
file_path
,
encoding
=
"utf-8"
)
as
f
:
if
file_path
.
endswith
(
".json"
):
return
json
.
load
(
f
)
...
...
@@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]:
return
list
(
f
)
def
get_preview
(
dataset_dir
:
str
,
dataset
:
list
,
page_index
:
int
)
->
Tuple
[
int
,
list
,
"gr.Column"
]:
r
"""
Gets the preview samples from the dataset.
"""
def
get_preview
(
dataset_dir
:
str
,
dataset
:
list
,
page_index
:
int
)
->
tuple
[
int
,
list
,
"gr.Column"
]:
r
"""Get the preview samples from the dataset."""
with
open
(
os
.
path
.
join
(
dataset_dir
,
DATA_CONFIG
),
encoding
=
"utf-8"
)
as
f
:
dataset_info
=
json
.
load
(
f
)
...
...
@@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
return
len
(
data
),
data
[
PAGE_SIZE
*
page_index
:
PAGE_SIZE
*
(
page_index
+
1
)],
gr
.
Column
(
visible
=
True
)
def
create_preview_box
(
dataset_dir
:
"gr.Textbox"
,
dataset
:
"gr.Dropdown"
)
->
D
ict
[
str
,
"Component"
]:
def
create_preview_box
(
dataset_dir
:
"gr.Textbox"
,
dataset
:
"gr.Dropdown"
)
->
d
ict
[
str
,
"Component"
]:
data_preview_btn
=
gr
.
Button
(
interactive
=
False
,
scale
=
1
)
with
gr
.
Column
(
visible
=
False
,
elem_classes
=
"modal-box"
)
as
preview_box
:
with
gr
.
Row
():
...
...
src/llamafactory/webui/components/eval.py
View file @
7ea81099
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Dict
from
typing
import
TYPE_CHECKING
from
...extras.packages
import
is_gradio_available
from
..common
import
DEFAULT_DATA_DIR
...
...
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from
..engine
import
Engine
def
create_eval_tab
(
engine
:
"Engine"
)
->
D
ict
[
str
,
"Component"
]:
def
create_eval_tab
(
engine
:
"Engine"
)
->
d
ict
[
str
,
"Component"
]:
input_elems
=
engine
.
manager
.
get_base_elems
()
elem_dict
=
dict
()
...
...
src/llamafactory/webui/components/export.py
View file @
7ea81099
...
...
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Dict
,
Generator
,
List
,
Union
from
collections.abc
import
Generator
from
typing
import
TYPE_CHECKING
,
Union
from
...extras.constants
import
PEFT_METHODS
from
...extras.misc
import
torch_gc
...
...
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
GPTQ_BITS
=
[
"8"
,
"4"
,
"3"
,
"2"
]
def
can_quantize
(
checkpoint_path
:
Union
[
str
,
L
ist
[
str
]])
->
"gr.Dropdown"
:
def
can_quantize
(
checkpoint_path
:
Union
[
str
,
l
ist
[
str
]])
->
"gr.Dropdown"
:
if
isinstance
(
checkpoint_path
,
list
)
and
len
(
checkpoint_path
)
!=
0
:
return
gr
.
Dropdown
(
value
=
"none"
,
interactive
=
False
)
else
:
...
...
@@ -47,7 +48,7 @@ def save_model(
model_name
:
str
,
model_path
:
str
,
finetuning_type
:
str
,
checkpoint_path
:
Union
[
str
,
L
ist
[
str
]],
checkpoint_path
:
Union
[
str
,
l
ist
[
str
]],
template
:
str
,
export_size
:
int
,
export_quantization_bit
:
str
,
...
...
@@ -106,7 +107,7 @@ def save_model(
yield
ALERTS
[
"info_exported"
][
lang
]
def
create_export_tab
(
engine
:
"Engine"
)
->
D
ict
[
str
,
"Component"
]:
def
create_export_tab
(
engine
:
"Engine"
)
->
d
ict
[
str
,
"Component"
]:
with
gr
.
Row
():
export_size
=
gr
.
Slider
(
minimum
=
1
,
maximum
=
100
,
value
=
5
,
step
=
1
)
export_quantization_bit
=
gr
.
Dropdown
(
choices
=
[
"none"
]
+
GPTQ_BITS
,
value
=
"none"
)
...
...
src/llamafactory/webui/components/infer.py
View file @
7ea81099
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Dict
from
typing
import
TYPE_CHECKING
from
...extras.packages
import
is_gradio_available
from
..common
import
is_multimodal
...
...
@@ -29,12 +29,12 @@ if TYPE_CHECKING:
from
..engine
import
Engine
def
create_infer_tab
(
engine
:
"Engine"
)
->
D
ict
[
str
,
"Component"
]:
def
create_infer_tab
(
engine
:
"Engine"
)
->
d
ict
[
str
,
"Component"
]:
input_elems
=
engine
.
manager
.
get_base_elems
()
elem_dict
=
dict
()
with
gr
.
Row
():
infer_backend
=
gr
.
Dropdown
(
choices
=
[
"huggingface"
,
"vllm"
],
value
=
"huggingface"
)
infer_backend
=
gr
.
Dropdown
(
choices
=
[
"huggingface"
,
"vllm"
,
"sglang"
],
value
=
"huggingface"
)
infer_dtype
=
gr
.
Dropdown
(
choices
=
[
"auto"
,
"float16"
,
"bfloat16"
,
"float32"
],
value
=
"auto"
)
with
gr
.
Row
():
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment