chenych / llama-grpo

Commit c7c477c7, authored Sep 24, 2025 by chenych

    add grpo

Pipeline #2942 failed with stages in 0 seconds.
Showing 20 changed files with 2989 additions and 0 deletions (+2989, -0).
src/llamafactory/train/pt/workflow.py (+101, -0)
src/llamafactory/train/rm/__init__.py (+18, -0)
src/llamafactory/train/rm/metric.py (+51, -0)
src/llamafactory/train/rm/trainer.py (+129, -0)
src/llamafactory/train/rm/workflow.py (+98, -0)
src/llamafactory/train/sft/__init__.py (+18, -0)
src/llamafactory/train/sft/metric.py (+134, -0)
src/llamafactory/train/sft/trainer.py (+165, -0)
src/llamafactory/train/sft/workflow.py (+135, -0)
src/llamafactory/train/test_utils.py (+119, -0)
src/llamafactory/train/trainer_utils.py (+731, -0)
src/llamafactory/train/tuner.py (+198, -0)
src/llamafactory/webui/__init__.py (+0, -0)
src/llamafactory/webui/chatter.py (+246, -0)
src/llamafactory/webui/common.py (+286, -0)
src/llamafactory/webui/components/__init__.py (+32, -0)
src/llamafactory/webui/components/chatbot.py (+143, -0)
src/llamafactory/webui/components/data.py (+122, -0)
src/llamafactory/webui/components/eval.py (+94, -0)
src/llamafactory/webui/components/export.py (+169, -0)
src/llamafactory/train/pt/workflow.py (new file, mode 100644)

# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Optional

from transformers import DataCollatorForLanguageModeling

from ...data import get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .trainer import CustomTrainer


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, ModelArguments


def run_pt(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Initialize our Trainer
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
        **tokenizer_module,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
            else:
                keys += ["eval_loss"]

            plot_loss(training_args.output_dir, keys=keys)

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if isinstance(dataset_module.get("eval_dataset"), dict):
            for key in dataset_module["eval_dataset"].keys():
                try:
                    perplexity = math.exp(metrics[f"eval_{key}_loss"])
                except OverflowError:
                    perplexity = float("inf")

                metrics[f"eval_{key}_perplexity"] = perplexity
        else:
            try:
                perplexity = math.exp(metrics["eval_loss"])
            except OverflowError:
                perplexity = float("inf")

            metrics["eval_perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
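For orientation, a minimal sketch of how `run_pt` is typically driven. The argument values below are purely illustrative, and `get_train_args` is the existing LlamaFactory helper that parses a dict of hyperparameters into the dataclasses expected above:

from llamafactory.hparams import get_train_args
from llamafactory.train.pt.workflow import run_pt

# Hypothetical configuration; a real run supplies its own model, dataset and output_dir.
args = dict(
    stage="pt",
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    dataset="c4_demo",
    finetuning_type="lora",
    output_dir="saves/pt-demo",
    do_train=True,
)
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_pt(model_args, data_args, training_args, finetuning_args)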
src/llamafactory/train/rm/__init__.py (new file, mode 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_rm


__all__ = ["run_rm"]
src/llamafactory/train/rm/metric.py (new file, mode 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import numpy as np

from ...extras.misc import numpify


if TYPE_CHECKING:
    from transformers import EvalPrediction


@dataclass
class ComputeAccuracy:
    r"""Compute reward accuracy and support `batch_eval_metrics`."""

    def _dump(self) -> Optional[dict[str, float]]:
        result = None
        if hasattr(self, "score_dict"):
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}

        self.score_dict = {"accuracy": []}
        return result

    def __post_init__(self):
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
        chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
        if not chosen_scores.shape:
            self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
        else:
            for i in range(len(chosen_scores)):
                self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])

        if compute_result:
            return self._dump()
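A toy check of the metric above with hypothetical scores; the stand-in class mimics only the `predictions` field of `transformers.EvalPrediction`. The reported accuracy is the fraction of pairs whose chosen score exceeds the rejected score:

import numpy as np

metric = ComputeAccuracy()

class FakePreds:  # minimal stand-in for transformers.EvalPrediction
    predictions = (np.array([1.2, 0.3, 2.0]), np.array([0.5, 0.9, 1.0]))  # (chosen, rejected)
    label_ids = None

print(metric(FakePreds()))  # {'accuracy': 0.666...}: two of the three pairs are ranked correctly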
src/llamafactory/train/rm/trainer.py (new file, mode 100644)

# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Optional, Union

import torch
from transformers import Trainer
from typing_extensions import override

from ...extras import logging
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from transformers import PreTrainedModel, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments


logger = logging.get_logger(__name__)


class PairwiseTrainer(Trainer):
    r"""Inherits Trainer to compute pairwise loss."""

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        self.finetuning_args = finetuning_args
        self.can_return_loss = True  # override property to return eval_loss
        self.add_callback(FixValueHeadModelCallback)

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
        r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.

        Subclass and override to inject custom behavior.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
        """
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
        batch_size = inputs["input_ids"].size(0) // 2
        chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
        chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
        chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
        rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
        chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

        loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
        if return_outputs:
            return loss, (loss, chosen_scores, rejected_scores)
        else:
            return loss

    def save_predictions(self, predict_results: "PredictionOutput") -> None:
        r"""Save model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
        chosen_scores, rejected_scores = predict_results.predictions

        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: list[str] = []
            for c_score, r_score in zip(chosen_scores, rejected_scores):
                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))

            writer.write("\n".join(res))
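The `compute_loss` above is the standard Bradley-Terry pairwise objective, -log sigmoid(r_chosen - r_rejected) averaged over the batch. A quick numeric illustration with made-up end-of-sequence scores:

import torch

chosen = torch.tensor([1.2, 0.3])
rejected = torch.tensor([0.5, 0.9])
loss = -torch.nn.functional.logsigmoid(chosen - rejected).mean()
print(loss.item())  # ~0.72; the mis-ranked second pair contributes most of the loss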
src/llamafactory/train/rm/workflow.py (new file, mode 100644)

# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy
from .trainer import PairwiseTrainer


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, ModelArguments


def run_rm(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
    data_collator = PairwiseDataCollatorWithPadding(
        template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
    )

    # Initialize our Trainer
    trainer = PairwiseTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=ComputeAccuracy(),
        **dataset_module,
        **tokenizer_module,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        if training_args.should_save:
            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)

        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += sum(
                    [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()],
                    [],
                )
            else:
                keys += ["eval_loss", "eval_accuracy"]

            plot_loss(training_args.output_dir, keys=keys)

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict")
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
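The `sum(..., [])` construction used for `keys` above only flattens the per-split metric names; with two hypothetical eval splits it expands as follows:

eval_names = ["alpaca", "sharegpt"]  # hypothetical split names
keys = ["loss"] + sum([[f"eval_{k}_loss", f"eval_{k}_accuracy"] for k in eval_names], [])
print(keys)
# ['loss', 'eval_alpaca_loss', 'eval_alpaca_accuracy', 'eval_sharegpt_loss', 'eval_sharegpt_accuracy']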
src/llamafactory/train/sft/__init__.py (new file, mode 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_sft


__all__ = ["run_sft"]
src/llamafactory/train/sft/metric.py (new file, mode 100644)

# Copyright 2025 HuggingFace Inc., THUDM, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import numpy as np
import torch
from transformers.utils import is_jieba_available, is_nltk_available

from ...extras.constants import IGNORE_INDEX
from ...extras.misc import numpify
from ...extras.packages import is_rouge_available


if TYPE_CHECKING:
    from transformers import EvalPrediction, PreTrainedTokenizer


if is_jieba_available():
    import jieba  # type: ignore


if is_nltk_available():
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu  # type: ignore


if is_rouge_available():
    from rouge_chinese import Rouge  # type: ignore


def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    r"""Compute the token with the largest likelihood to reduce memory footprint."""
    if isinstance(logits, (list, tuple)):
        if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
            logits = logits[0]
        else:  # moe models have aux loss
            logits = logits[1]

    if logits.dim() != 3:
        raise ValueError("Cannot process the logits.")

    return torch.argmax(logits, dim=-1)


@dataclass
class ComputeAccuracy:
    r"""Compute accuracy and support `batch_eval_metrics`."""

    def _dump(self) -> Optional[dict[str, float]]:
        result = None
        if hasattr(self, "score_dict"):
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}

        self.score_dict = {"accuracy": []}
        return result

    def __post_init__(self):
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
        preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
        for i in range(len(preds)):
            pred, label = preds[i, :-1], labels[i, 1:]
            label_mask = label != IGNORE_INDEX
            self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask]))

        if compute_result:
            return self._dump()


@dataclass
class ComputeSimilarity:
    r"""Compute text similarity scores and support `batch_eval_metrics`.

    Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
    """

    tokenizer: "PreTrainedTokenizer"

    def _dump(self) -> Optional[dict[str, float]]:
        result = None
        if hasattr(self, "score_dict"):
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}

        self.score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
        return result

    def __post_init__(self):
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
        preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)

        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)

        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        for pred, label in zip(decoded_preds, decoded_labels):
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))

            if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
                result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
            else:
                rouge = Rouge()
                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
                result = scores[0]

            for k, v in result.items():
                self.score_dict[k].append(round(v["f"] * 100, 4))

            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            self.score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        if compute_result:
            return self._dump()
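A quick sanity check of `eval_logit_processor` with a toy tensor: full `(batch, seq_len, vocab)` logits are reduced to token ids before they reach the accuracy metric, which keeps evaluation memory low. The shapes below are made up for illustration:

import torch

logits = torch.randn(2, 5, 32000)  # hypothetical batch of 2, sequence length 5, 32k vocab
token_ids = eval_logit_processor(logits, labels=None)
print(token_ids.shape)  # torch.Size([2, 5])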
src/llamafactory/train/sft/trainer.py (new file, mode 100644)

# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from torch.utils.data import Dataset
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments


logger = logging.get_logger(__name__)


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""

    def __init__(
        self,
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        gen_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        else:
            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args
        if gen_kwargs is not None:
            # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
            self._gen_kwargs = gen_kwargs

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
        return super().compute_loss(model, inputs, *args, **kwargs)

    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
        inputs: dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""Remove the prompt part in the generated tokens.

        Subclass and override to inject custom behavior.
        """
        if self.args.predict_with_generate:  # do not pass labels to model when generate
            labels = inputs.pop("labels", None)
        else:
            labels = inputs.get("labels")

        loss, generated_tokens, _ = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
        if generated_tokens is not None and self.args.predict_with_generate:
            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels

    def save_predictions(
        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
    ) -> None:
        r"""Save model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

        labels = np.where(
            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
        )
        preds = np.where(
            predict_results.predictions != IGNORE_INDEX,
            predict_results.predictions,
            self.processing_class.pad_token_id,
        )

        for i in range(len(preds)):
            pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
            if len(pad_len):  # move pad token to last
                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

        with open(output_prediction_file, "w", encoding="utf-8") as f:
            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
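A toy illustration of the prompt masking done in `prediction_step`: positions covered by the prompt are overwritten with the pad token, so only the newly generated tokens survive decoding. The ids below are made up:

import torch

pad_token_id = 0
input_ids = torch.tensor([[11, 12, 13]])           # prompt of length 3
generated = torch.tensor([[11, 12, 13, 21, 22]])   # prompt + 2 generated tokens
generated[:, : input_ids.size(-1)] = pad_token_id
print(generated)  # tensor([[ 0,  0,  0, 21, 22]])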
src/llamafactory/train/sft/workflow.py (new file, mode 100644)

# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    if getattr(model, "is_quantized", False) and not training_args.do_train:
        setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction

    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        model=model if not training_args.predict_with_generate else None,
        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        block_diag_attn=model_args.block_diag_attn,
        attn_implementation=getattr(model.config, "_attn_implementation", None),
        compute_dtype=model_args.compute_dtype,
        **tokenizer_module,
    )

    # Metric utils
    metric_module = {}
    if training_args.predict_with_generate:
        metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
    elif finetuning_args.compute_accuracy:
        metric_module["compute_metrics"] = ComputeAccuracy()
        metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id

    # Initialize our Trainer
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        gen_kwargs=gen_kwargs,
        **dataset_module,
        **tokenizer_module,
        **metric_module,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
                dataset_module["train_dataset"], train_result.metrics, stage="sft"
            )

        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += sum(
                    [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()],
                    [],
                )
            else:
                keys += ["eval_loss", "eval_accuracy"]

            plot_loss(training_args.output_dir, keys=keys)

    if training_args.predict_with_generate:
        tokenizer.padding_side = "left"  # use left-padding in generation

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
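The `eos_token_id` list built above simply merges the tokenizer's end-of-sequence id with any additional special-token ids, so generation stops on either. With hypothetical ids:

eos_token_id = 2
additional_special_tokens_ids = [32000, 32001]  # e.g. chat-template stop tokens
print([eos_token_id] + additional_special_tokens_ids)  # [2, 32000, 32001]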
src/llamafactory/train/test_utils.py (new file, mode 100644)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional, Union

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

from ..data import get_dataset, get_template_and_fix_tokenizer
from ..extras.misc import get_current_device
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer


if TYPE_CHECKING:
    from peft import LoraModel
    from transformers import PreTrainedModel

    from ..data.data_utils import DatasetModule


def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: list[str] = []) -> None:
    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()
    assert set(state_dict_a.keys()) == set(state_dict_b.keys())
    for name in state_dict_a.keys():
        if any(key in name for key in diff_keys):
            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
        else:
            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True


def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
    linear_modules, extra_modules = set(), set()
    for name, param in model.named_parameters():
        if any(module in name for module in ["lora_A", "lora_B"]):
            linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
            assert param.requires_grad is True
            assert param.dtype == torch.float32
        elif "modules_to_save" in name:
            extra_modules.add(name.split(".modules_to_save", maxsplit=1)[0].split(".")[-1])
            assert param.requires_grad is True
            assert param.dtype == torch.float32
        else:
            assert param.requires_grad is False
            assert param.dtype == torch.float16

    return linear_modules, extra_modules


def load_train_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
    model_args, _, _, finetuning_args, _ = get_train_args(kwargs)
    tokenizer = load_tokenizer(model_args)["tokenizer"]
    return load_model(tokenizer, model_args, finetuning_args, is_trainable=True, add_valuehead=add_valuehead)


def load_infer_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
    model_args, _, finetuning_args, _ = get_infer_args(kwargs)
    tokenizer = load_tokenizer(model_args)["tokenizer"]
    return load_model(tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead)


def load_reference_model(
    model_path: str,
    lora_path: Optional[str] = None,
    use_lora: bool = False,
    use_pissa: bool = False,
    is_trainable: bool = False,
    add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]:
    current_device = get_current_device()
    if add_valuehead:
        model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map=current_device
        )
        if not is_trainable:
            model.v_head = model.v_head.to(torch.float16)

        return model

    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
    if use_lora or use_pissa:
        model = PeftModel.from_pretrained(
            model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
        )
        for param in filter(lambda p: p.requires_grad, model.parameters()):
            param.data = param.data.to(torch.float32)

    return model


def load_dataset_module(**kwargs) -> "DatasetModule":
    model_args, data_args, training_args, _, _ = get_train_args(kwargs)
    tokenizer_module = load_tokenizer(model_args)
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, kwargs["stage"], **tokenizer_module)
    return dataset_module


def patch_valuehead_model() -> None:
    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None:
        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

    AutoModelForCausalLMWithValueHead.post_init = post_init
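A minimal self-check of `compare_model` with two toy modules that share weights (illustrative only; the real tests compare fine-tuned checkpoints):

import torch

m1 = torch.nn.Linear(4, 4)
m2 = torch.nn.Linear(4, 4)
m2.load_state_dict(m1.state_dict())
compare_model(m1, m2)  # passes: every tensor pair is allclose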
src/llamafactory/train/trainer_utils.py (new file, mode 100644)

# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
# and the original BAdam's implementation: https://github.com/Ledzy/BAdam
# and the HuggingFace's TRL library: https://github.com/huggingface/trl
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from transformers import Trainer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override

from ..extras import logging
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params


if is_galore_available():
    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore


if is_apollo_available():
    from apollo_torch import APOLLOAdamW  # type: ignore


if is_ray_available():
    import ray
    from ray.train import RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer


if TYPE_CHECKING:
    from transformers import PreTrainedModel, TrainerCallback, TrainerState
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import DataArguments, RayArguments, TrainingArguments


logger = logging.get_logger(__name__)


class DummyOptimizer(torch.optim.Optimizer):
    r"""A dummy optimizer used for the GaLore or APOLLO algorithm."""

    def __init__(
        self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
    ) -> None:
        dummy_tensor = torch.randn(1, 1)
        self.optimizer_dict = optimizer_dict
        super().__init__([dummy_tensor], {"lr": lr})

    @override
    def zero_grad(self, set_to_none: bool = True) -> None:
        pass

    @override
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        pass


def create_modelcard_and_push(
    trainer: "Trainer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> None:
    kwargs = {
        "tasks": "text-generation",
        "finetuned_from": model_args.model_name_or_path,
        "tags": ["llama-factory", finetuning_args.finetuning_type],
    }
    if data_args.dataset is not None:
        kwargs["dataset"] = data_args.dataset

    if model_args.use_unsloth:
        kwargs["tags"] = kwargs["tags"] + ["unsloth"]

    if not training_args.do_train:
        pass
    elif training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(license="other", **kwargs)  # prevent from connecting to hub


def create_ref_model(
    model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
    r"""Create reference model for PPO/DPO training. Evaluation mode is not supported.

    The valuehead parameter is randomly initialized since it is useless for PPO training.
    """
    if finetuning_args.ref_model is not None:
        ref_model_args = ModelArguments.copyfrom(
            model_args,
            model_name_or_path=finetuning_args.ref_model,
            adapter_name_or_path=finetuning_args.ref_model_adapters,
            quantization_bit=finetuning_args.ref_model_quantization_bit,
        )
        ref_finetuning_args = FinetuningArguments()
        tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
        ref_model = load_model(
            tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
        )
        logger.info_rank0(f"Created reference model from {finetuning_args.ref_model}")
    else:
        if finetuning_args.finetuning_type == "lora":
            ref_model = None
        else:
            ref_model_args = ModelArguments.copyfrom(model_args)
            ref_finetuning_args = FinetuningArguments()
            tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
            ref_model = load_model(
                tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
            )
            logger.info_rank0("Created reference model from the model itself.")

    return ref_model


def create_reward_model(
    model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> Optional["AutoModelForCausalLMWithValueHead"]:
    r"""Create reward model for PPO training."""
    if finetuning_args.reward_model_type == "api":
        assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
        logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
        return finetuning_args.reward_model
    elif finetuning_args.reward_model_type == "lora":
        model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
        for name, param in model.named_parameters():  # https://github.com/huggingface/peft/issues/1090
            if "default" in name:
                param.data = param.data.to(torch.float32)  # trainable params should be in fp32

        vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
        assert vhead_params is not None, "Reward model is not correctly loaded."
        model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
        model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
        model.register_buffer(
            "default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
        )
        model.register_buffer(
            "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
        )
        logger.info_rank0(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}")
        return None
    else:
        reward_model_args = ModelArguments.copyfrom(
            model_args,
            model_name_or_path=finetuning_args.reward_model,
            adapter_name_or_path=finetuning_args.reward_model_adapters,
            quantization_bit=finetuning_args.reward_model_quantization_bit,
        )
        reward_finetuning_args = FinetuningArguments()
        tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
        reward_model = load_model(
            tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
        )
        logger.info_rank0(f"Loaded full weights of reward model from {finetuning_args.reward_model}")
        logger.warning_rank0("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
        return reward_model


def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]:
    r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    return decay_parameters


def _create_galore_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
        galore_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
    else:
        galore_targets = finetuning_args.galore_target

    galore_params: list[torch.nn.Parameter] = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
            for param in module.parameters():
                if param.requires_grad and len(param.shape) > 1:
                    galore_params.append(param)

    galore_kwargs = {
        "rank": finetuning_args.galore_rank,
        "update_proj_gap": finetuning_args.galore_update_interval,
        "scale": finetuning_args.galore_scale,
        "proj_type": finetuning_args.galore_proj_type,
    }

    id_galore_params = {id(param) for param in galore_params}
    decay_params, nodecay_params = [], []  # they are non-galore parameters
    trainable_params: list[torch.nn.Parameter] = []  # galore_params + decay_params + nodecay_params
    decay_param_names = _get_decay_parameter_names(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(param)
            if id(param) not in id_galore_params:
                if name in decay_param_names:
                    decay_params.append(param)
                else:
                    nodecay_params.append(param)

    _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)

    if training_args.optim == "adamw_torch":
        optim_class = GaLoreAdamW
    elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
        optim_class = GaLoreAdamW8bit
    elif training_args.optim == "adafactor":
        optim_class = GaLoreAdafactor
    else:
        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")

    if finetuning_args.galore_layerwise:
        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
        if training_args.gradient_accumulation_steps != 1:
            raise ValueError("Per-layer GaLore does not support gradient accumulation.")

        optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
        for param in nodecay_params:
            param_groups = [dict(params=[param], weight_decay=0.0)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        for param in decay_params:
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        for param in galore_params:  # galore params have weight decay
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        def optimizer_hook(param: "torch.nn.Parameter"):
            if param.grad is not None:
                optimizer_dict[param].step()
                optimizer_dict[param].zero_grad()

        for param in trainable_params:
            param.register_post_accumulate_grad_hook(optimizer_hook)

        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
    else:
        param_groups = [
            dict(params=nodecay_params, weight_decay=0.0),
            dict(params=decay_params, weight_decay=training_args.weight_decay),
            dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
        ]
        optimizer = optim_class(param_groups, **optim_kwargs)

    logger.info_rank0(
        f"Using GaLore optimizer with args: {galore_kwargs}. "
        "It may cause hanging at the start of training, wait patiently."
    )
    return optimizer
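# Note on the layer-wise branch above: instead of one global optimizer step, every parameter
# gets its own tiny optimizer that fires from a post-accumulate-grad hook (PyTorch >= 2.1),
# and DummyOptimizer exists only so the Trainer still holds an optimizer object.
# A minimal standalone sketch of the same pattern with plain SGD (illustrative, not part of this file):
#
#     model = torch.nn.Linear(4, 2)
#     optimizer_dict = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()}
#
#     def optimizer_hook(param):
#         if param.grad is not None:
#             optimizer_dict[param].step()
#             optimizer_dict[param].zero_grad()
#
#     for param in model.parameters():
#         param.register_post_accumulate_grad_hook(optimizer_hook)
#
#     model(torch.randn(3, 4)).sum().backward()  # each parameter updates as soon as its grad is ready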
def _create_apollo_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
        apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
    else:
        apollo_targets = finetuning_args.apollo_target

    apollo_params: list[torch.nn.Parameter] = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
            for param in module.parameters():
                if param.requires_grad and len(param.shape) > 1:
                    apollo_params.append(param)

    apollo_kwargs = {
        "rank": finetuning_args.apollo_rank,
        "proj": finetuning_args.apollo_proj,
        "proj_type": finetuning_args.apollo_proj_type,
        "update_proj_gap": finetuning_args.apollo_update_interval,
        "scale": finetuning_args.apollo_scale,
        "scale_type": finetuning_args.apollo_scale_type,
        "scale_front": finetuning_args.apollo_scale_front,
    }

    id_apollo_params = {id(param) for param in apollo_params}
    decay_params, nodecay_params = [], []  # they are non-apollo parameters
    trainable_params: list[torch.nn.Parameter] = []  # apollo_params + decay_params + nodecay_params
    decay_param_names = _get_decay_parameter_names(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(param)
            if id(param) not in id_apollo_params:
                if name in decay_param_names:
                    decay_params.append(param)
                else:
                    nodecay_params.append(param)

    _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)

    if training_args.optim == "adamw_torch":
        optim_class = APOLLOAdamW
    else:
        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")

    if finetuning_args.apollo_layerwise:
        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
        if training_args.gradient_accumulation_steps != 1:
            raise ValueError("Per-layer APOLLO does not support gradient accumulation.")

        optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
        for param in nodecay_params:
            param_groups = [dict(params=[param], weight_decay=0.0)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        for param in decay_params:
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        for param in apollo_params:  # apollo params have weight decay
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        def optimizer_hook(param: "torch.nn.Parameter"):
            if param.grad is not None:
                optimizer_dict[param].step()
                optimizer_dict[param].zero_grad()

        for param in trainable_params:
            param.register_post_accumulate_grad_hook(optimizer_hook)

        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
    else:
        param_groups = [
            dict(params=nodecay_params, weight_decay=0.0),
            dict(params=decay_params, weight_decay=training_args.weight_decay),
            dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
        ]
        optimizer = optim_class(param_groups, **optim_kwargs)

    logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
    return optimizer


def _create_loraplus_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    default_lr = training_args.learning_rate
    loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
    embedding_lr = finetuning_args.loraplus_lr_embedding

    decay_param_names = _get_decay_parameter_names(model)
    param_dict: dict[str, list[torch.nn.Parameter]] = {
        "lora_a": [],
        "lora_b": [],
        "lora_b_nodecay": [],
        "embedding": [],
    }
    for name, param in model.named_parameters():
        if param.requires_grad:
            if "lora_embedding_B" in name:
                param_dict["embedding"].append(param)
            elif "lora_B" in name or param.ndim == 1:
                if name in decay_param_names:
                    param_dict["lora_b"].append(param)
                else:
                    param_dict["lora_b_nodecay"].append(param)
            else:
                param_dict["lora_a"].append(param)

    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
    param_groups = [
        dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
        dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
        dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
    ]
    optimizer = optim_class(param_groups, **optim_kwargs)
    logger.info_rank0(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.")
    return optimizer
def _create_badam_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    decay_params, nodecay_params = [], []
    decay_param_names = _get_decay_parameter_names(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            if name in decay_param_names:
                decay_params.append(param)
            else:
                nodecay_params.append(param)

    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
    param_groups = [
        dict(params=nodecay_params, weight_decay=0.0),
        dict(params=decay_params, weight_decay=training_args.weight_decay),
    ]

    if finetuning_args.badam_mode == "layer":
        from badam import BlockOptimizer  # type: ignore

        base_optimizer = optim_class(param_groups, **optim_kwargs)
        optimizer = BlockOptimizer(
            base_optimizer=base_optimizer,
            named_parameters_list=list(model.named_parameters()),
            block_prefix_list=None,
            switch_block_every=finetuning_args.badam_switch_interval,
            start_block=finetuning_args.badam_start_block,
            switch_mode=finetuning_args.badam_switch_mode,
            verbose=finetuning_args.badam_verbose,
            ds_zero3_enabled=is_deepspeed_zero3_enabled(),
        )
        logger.info_rank0(
            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
            f"switch block every {finetuning_args.badam_switch_interval} steps, "
            f"default start block is {finetuning_args.badam_start_block}"
        )

    elif finetuning_args.badam_mode == "ratio":
        from badam import BlockOptimizerRatio  # type: ignore

        assert finetuning_args.badam_update_ratio > 1e-6
        optimizer = BlockOptimizerRatio(
            param_groups=param_groups,
            named_parameters_list=list(model.named_parameters()),
            update_ratio=finetuning_args.badam_update_ratio,
            mask_mode=finetuning_args.badam_mask_mode,
            verbose=finetuning_args.badam_verbose,
            include_embedding=False,
            **optim_kwargs,
        )
        logger.info_rank0(
            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
            f"mask mode is {finetuning_args.badam_mask_mode}"
        )

    return optimizer
def _create_adam_mini_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
    from adam_mini import Adam_mini  # type: ignore

    hidden_size = getattr(model.config, "hidden_size", None)
    num_q_head = getattr(model.config, "num_attention_heads", None)
    num_kv_head = getattr(model.config, "num_key_value_heads", None)

    optimizer = Adam_mini(
        named_parameters=model.named_parameters(),
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        model_sharding=is_fsdp_enabled() or is_deepspeed_zero3_enabled(),
        dim=hidden_size,
        n_heads=num_q_head,
        n_kv_heads=num_kv_head,
    )
    logger.info_rank0("Using Adam-mini optimizer.")
    return optimizer
def _create_muon_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
    from ..third_party.muon import Muon

    muon_params, adamw_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Use Muon for 2D parameters that aren't embeddings or heads
            if param.ndim == 2 and "embed" not in name and "lm_head" not in name:
                muon_params.append(param)
            else:
                adamw_params.append(param)

    optimizer = Muon(
        lr=training_args.learning_rate,
        wd=training_args.weight_decay,
        muon_params=muon_params,
        adamw_params=adamw_params,
        adamw_betas=(training_args.adam_beta1, training_args.adam_beta2),
        adamw_eps=training_args.adam_epsilon,
    )
    logger.info_rank0(f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params.")
    return optimizer
def create_custom_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
    if finetuning_args.use_galore:
        return _create_galore_optimizer(model, training_args, finetuning_args)

    if finetuning_args.use_apollo:
        return _create_apollo_optimizer(model, training_args, finetuning_args)

    if finetuning_args.loraplus_lr_ratio is not None:
        return _create_loraplus_optimizer(model, training_args, finetuning_args)

    if finetuning_args.use_badam:
        return _create_badam_optimizer(model, training_args, finetuning_args)

    if finetuning_args.use_adam_mini:
        return _create_adam_mini_optimizer(model, training_args)

    if finetuning_args.use_muon:
        return _create_muon_optimizer(model, training_args)
def create_custom_scheduler(
    training_args: "TrainingArguments",
    num_training_steps: int,
    optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
    if training_args.lr_scheduler_type == "warmup_stable_decay":
        num_warmup_steps = training_args.get_warmup_steps(num_training_steps)
        remaining_steps = num_training_steps - num_warmup_steps
        num_stable_steps = remaining_steps // 3  # use 1/3 for stable by default
        num_decay_steps = remaining_steps - num_stable_steps
        scheduler_kwargs = training_args.lr_scheduler_kwargs or {}
        default_kwargs = {
            "num_stable_steps": num_stable_steps,
            "num_decay_steps": num_decay_steps,
        }
        for key, value in default_kwargs.items():
            if key not in scheduler_kwargs:
                scheduler_kwargs[key] = value

        training_args.lr_scheduler_kwargs = scheduler_kwargs

    if optimizer is not None and isinstance(optimizer, DummyOptimizer):
        optimizer_dict = optimizer.optimizer_dict
        scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}

        for param in optimizer_dict.keys():
            scheduler_dict[param] = get_scheduler(
                training_args.lr_scheduler_type,
                optimizer=optimizer_dict[param],
                num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
                scheduler_specific_kwargs=training_args.lr_scheduler_kwargs,
            )

        def scheduler_hook(param: "torch.nn.Parameter"):
            scheduler_dict[param].step()

        for param in optimizer_dict.keys():
            param.register_post_accumulate_grad_hook(scheduler_hook)
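A minimal sketch of how these two helpers are typically wired into a HuggingFace `Trainer` subclass. The `OptimizerPatchedTrainer` class and its `finetuning_args` plumbing are illustrative assumptions, not code from this commit; the repository's own trainers live in the per-stage modules.

from typing import Optional

import torch
from transformers import Trainer


class OptimizerPatchedTrainer(Trainer):
    # Hypothetical subclass for illustration only.
    def __init__(self, finetuning_args, **kwargs):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args

    def create_optimizer(self) -> "torch.optim.Optimizer":
        # Prefer the custom optimizer if one is configured; otherwise fall back to HF's default.
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        # Patches scheduler kwargs (and per-parameter schedulers for DummyOptimizer) before delegating.
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)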
def get_batch_logps(
    logits: "torch.Tensor",
    labels: "torch.Tensor",
    label_pad_token_id: int = IGNORE_INDEX,
    ld_alpha: Optional[float] = None,
) -> tuple["torch.Tensor", "torch.Tensor"]:
    r"""Compute the log probabilities of the given labels under the given logits.

    Returns:
        logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
        valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")

    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id
    labels[labels == label_pad_token_id] = 0  # dummy token
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    valid_length = loss_mask.sum(-1)

    if ld_alpha is not None:
        num_examples = labels.shape[0] // 2
        chosen_lengths = valid_length[:num_examples]
        rejected_lengths = valid_length[num_examples:]
        min_lengths = torch.min(chosen_lengths, rejected_lengths)
        start_positions = torch.argmax(loss_mask.int(), dim=1)
        public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)

        seq_len = labels.shape[-1]
        position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)

        ld_mask = position_ids < public_lengths.unsqueeze(1)
        front_mask = (ld_mask * loss_mask).float()
        rear_mask = (~ld_mask * loss_mask).float()
        front_logps = (per_token_logps * front_mask).sum(-1)
        rear_logps = (per_token_logps * rear_mask).sum(-1)
        logps = front_logps + ld_alpha * rear_logps
    else:
        logps = (per_token_logps * loss_mask).sum(-1)

    return logps, valid_length
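A minimal usage sketch for `get_batch_logps` on dummy tensors, assuming the common convention that `IGNORE_INDEX` is -100 (the shapes and values below are illustrative, not from this commit).

import torch

batch_size, seq_len, vocab_size = 2, 6, 32
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, :2] = -100  # mask the (hypothetical) prompt tokens

logps, valid_length = get_batch_logps(logits, labels, label_pad_token_id=-100)
print(logps.shape)     # torch.Size([2]) -- one summed log-prob per sequence
print(valid_length)    # tensor([4, 4]) -- non-masked tokens after the one-position shift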
def nested_detach(
    tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
    clone: bool = False,
):
    r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})

    if isinstance(tensors, torch.Tensor):
        if clone:
            return tensors.detach().clone()
        else:
            return tensors.detach()
    else:
        return tensors
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
    r"""Get the callback for logging to SwanLab."""
    import swanlab  # type: ignore
    from swanlab.integration.transformers import SwanLabCallback  # type: ignore

    if finetuning_args.swanlab_api_key is not None:
        swanlab.login(api_key=finetuning_args.swanlab_api_key)

    if finetuning_args.swanlab_lark_webhook_url is not None:
        from swanlab.plugin.notification import LarkCallback  # type: ignore

        lark_callback = LarkCallback(
            webhook_url=finetuning_args.swanlab_lark_webhook_url,
            secret=finetuning_args.swanlab_lark_secret,
        )
        swanlab.register_callbacks([lark_callback])

    class SwanLabCallbackExtension(SwanLabCallback):
        def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
            if not state.is_world_process_zero:
                return

            super().setup(args, state, model, **kwargs)
            try:
                if hasattr(self, "_swanlab"):
                    swanlab_public_config = self._swanlab.get_run().public.json()
                else:  # swanlab <= 0.4.9
                    swanlab_public_config = self._experiment.get_run().public.json()
            except Exception:
                swanlab_public_config = {}

            with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w") as f:
                f.write(json.dumps(swanlab_public_config, indent=2))

    swanlab_callback = SwanLabCallbackExtension(
        project=finetuning_args.swanlab_project,
        workspace=finetuning_args.swanlab_workspace,
        experiment_name=finetuning_args.swanlab_run_name,
        mode=finetuning_args.swanlab_mode,
        config={"Framework": "🦙LlamaFactory"},
        logdir=finetuning_args.swanlab_logdir,
        tags=["🦙LlamaFactory"],
    )
    return swanlab_callback
def get_ray_trainer(
    training_function: Callable,
    train_loop_config: dict[str, Any],
    ray_args: "RayArguments",
) -> "TorchTrainer":
    if not ray_args.use_ray:
        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")

    if ray_args.ray_init_kwargs is not None:
        ray.init(**ray_args.ray_init_kwargs)

    if ray_args.ray_storage_filesystem is not None:
        # this means we are using s3/gcs
        storage_path = ray_args.ray_storage_path
    else:
        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()

    trainer = TorchTrainer(
        training_function,
        train_loop_config=train_loop_config,
        scaling_config=ScalingConfig(
            num_workers=ray_args.ray_num_workers,
            resources_per_worker=ray_args.resources_per_worker,
            placement_strategy=ray_args.placement_strategy,
            use_gpu=True,
        ),
        run_config=RunConfig(
            name=ray_args.ray_run_name,
            storage_filesystem=ray_args.ray_storage_filesystem,
            storage_path=storage_path,
        ),
    )
    return trainer
src/llamafactory/train/tuner.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from typing import TYPE_CHECKING, Any, Optional

import torch
import torch.distributed as dist
from transformers import EarlyStoppingCallback, PreTrainedModel

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
from .kto import run_kto
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .grpo import run_grpo
from .trainer_utils import get_ray_trainer, get_swanlab_callback


if is_ray_available():
    import ray
    from ray.train.huggingface.transformers import RayTrainReportCallback


if TYPE_CHECKING:
    from transformers import TrainerCallback


logger = logging.get_logger(__name__)
def _training_function(config: dict[str, Any]) -> None:
    args = config.get("args")
    callbacks: list[Any] = config.get("callbacks")
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

    callbacks.append(LogCallback())
    if finetuning_args.pissa_convert:
        callbacks.append(PissaConvertCallback())

    if finetuning_args.use_swanlab:
        callbacks.append(get_swanlab_callback(finetuning_args))

    if finetuning_args.early_stopping_steps is not None:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))

    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last

    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "rm":
        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "grpo":
        run_grpo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "dpo":
        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "kto":
        run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
    else:
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")

    if is_ray_available() and ray.is_initialized():
        return  # if ray is initialized it will destroy the process group on return

    try:
        if dist.is_initialized():
            dist.destroy_process_group()
    except Exception as e:
        logger.warning(f"Failed to destroy process group: {e}.")
def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
    args = read_args(args)
    if "-h" in args or "--help" in args:
        get_train_args(args)

    ray_args = get_ray_args(args)
    callbacks = callbacks or []
    if ray_args.use_ray:
        callbacks.append(RayTrainReportCallback())
        trainer = get_ray_trainer(
            training_function=_training_function,
            train_loop_config={"args": args, "callbacks": callbacks},
            ray_args=ray_args,
        )
        trainer.fit()
    else:
        _training_function(config={"args": args, "callbacks": callbacks})
def export_model(args: Optional[dict[str, Any]] = None) -> None:
    model_args, data_args, finetuning_args, _ = get_infer_args(args)

    if model_args.export_dir is None:
        raise ValueError("Please specify `export_dir` to save model.")

    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
        raise ValueError("Please merge adapters before quantizing the model.")

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab

    if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
        raise ValueError("Cannot merge adapters to a quantized model.")

    if not isinstance(model, PreTrainedModel):
        raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

    if getattr(model, "quantization_method", None) is not None:  # quantized model adopts float16 type
        setattr(model.config, "torch_dtype", torch.float16)
    else:
        if model_args.infer_dtype == "auto":
            output_dtype = getattr(model.config, "torch_dtype", torch.float32)
            if output_dtype == torch.float32:  # if infer_dtype is auto, try using half precision first
                output_dtype = infer_optim_dtype(torch.bfloat16)
        else:
            output_dtype = getattr(torch, model_args.infer_dtype)

        setattr(model.config, "torch_dtype", output_dtype)
        model = model.to(output_dtype)
        logger.info_rank0(f"Convert model dtype to: {output_dtype}.")

    model.save_pretrained(
        save_directory=model_args.export_dir,
        max_shard_size=f"{model_args.export_size}GB",
        safe_serialization=(not model_args.export_legacy_format),
    )
    if model_args.export_hub_model_id is not None:
        model.push_to_hub(
            model_args.export_hub_model_id,
            token=model_args.hf_hub_token,
            max_shard_size=f"{model_args.export_size}GB",
            safe_serialization=(not model_args.export_legacy_format),
        )

    if finetuning_args.stage == "rm":
        if model_args.adapter_name_or_path is not None:
            vhead_path = model_args.adapter_name_or_path[-1]
        else:
            vhead_path = model_args.model_name_or_path

        if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
            shutil.copy(
                os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
                os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
            )
            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
        elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
            shutil.copy(
                os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
                os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
            )
            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")

    try:
        tokenizer.padding_side = "left"  # restore padding side
        tokenizer.init_kwargs["padding_side"] = "left"
        tokenizer.save_pretrained(model_args.export_dir)
        if model_args.export_hub_model_id is not None:
            tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

        if processor is not None:
            processor.save_pretrained(model_args.export_dir)
            if model_args.export_hub_model_id is not None:
                processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

    except Exception as e:
        logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")

    ollama_modelfile = os.path.join(model_args.export_dir, "Modelfile")
    with open(ollama_modelfile, "w", encoding="utf-8") as f:
        f.write(template.get_ollama_modelfile(tokenizer))
        logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
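A minimal sketch of calling `export_model` with a plain argument dict to merge a LoRA adapter into its base model. All paths and model names below are hypothetical placeholders, not values from this commit.

export_model(
    {
        "model_name_or_path": "Qwen/Qwen2-1.5B-Instruct",      # hypothetical base model
        "adapter_name_or_path": "saves/qwen2-1.5b/lora/sft",   # hypothetical LoRA checkpoint
        "template": "qwen",
        "finetuning_type": "lora",
        "export_dir": "output/qwen2-1.5b-merged",
        "export_size": 5,                                       # max shard size in GB
    }
)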
src/llamafactory/webui/__init__.py
0 → 100644
View file @
c7c477c7
src/llamafactory/webui/chatter.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional

from transformers.utils import is_torch_npu_available

from ..chat import ChatModel
from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir, load_config
from .locales import ALERTS


if TYPE_CHECKING:
    from ..chat import BaseEngine
    from .manager import Manager


if is_gradio_available():
    import gradio as gr


def _escape_html(text: str) -> str:
    r"""Escape HTML characters."""
    return text.replace("<", "&lt;").replace(">", "&gt;")


def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
    r"""Post-process the response text.

    Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
    """
    if thought_words[0] not in text:
        return _escape_html(text) if escape_html else text

    text = text.replace(thought_words[0], "")
    result = text.split(thought_words[1], maxsplit=1)
    if len(result) == 1:
        summary = ALERTS["info_thinking"][lang]
        thought, answer = text, ""
    else:
        summary = ALERTS["info_thought"][lang]
        thought, answer = result

    if escape_html:
        thought, answer = _escape_html(thought), _escape_html(answer)

    return (
        f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
        f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
    )


@contextmanager
def update_attr(obj: Any, name: str, value: Any):
    old_value = getattr(obj, name, None)
    setattr(obj, name, value)
    yield
    setattr(obj, name, old_value)


class WebChatModel(ChatModel):
    def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
        self.manager = manager
        self.demo_mode = demo_mode
        self.engine: Optional[BaseEngine] = None

        if not lazy_init:  # read arguments from command line
            super().__init__()

        if demo_mode and os.getenv("DEMO_MODEL") and os.getenv("DEMO_TEMPLATE"):  # load demo model
            model_name_or_path = os.getenv("DEMO_MODEL")
            template = os.getenv("DEMO_TEMPLATE")
            infer_backend = os.getenv("DEMO_BACKEND", "huggingface")
            super().__init__(
                dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
            )

    @property
    def loaded(self) -> bool:
        return self.engine is not None

    def load_model(self, data) -> Generator[str, None, None]:
        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
        finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
        user_config = load_config()

        error = ""
        if self.loaded:
            error = ALERTS["err_exists"][lang]
        elif not model_name:
            error = ALERTS["err_no_model"][lang]
        elif not model_path:
            error = ALERTS["err_no_path"][lang]
        elif self.demo_mode:
            error = ALERTS["err_demo"][lang]

        try:
            json.loads(get("infer.extra_args"))
        except json.JSONDecodeError:
            error = ALERTS["err_json_schema"][lang]

        if error:
            gr.Warning(error)
            yield error
            return

        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=model_path,
            cache_dir=user_config.get("cache_dir", None),
            finetuning_type=finetuning_type,
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
            use_unsloth=(get("top.booster") == "unsloth"),
            enable_liger_kernel=(get("top.booster") == "liger_kernel"),
            infer_backend=get("infer.infer_backend"),
            infer_dtype=get("infer.infer_dtype"),
            trust_remote_code=True,
        )
        args.update(json.loads(get("infer.extra_args")))

        # checkpoints
        if checkpoint_path:
            if finetuning_type in PEFT_METHODS:  # list
                args["adapter_name_or_path"] = ",".join(
                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
                )
            else:  # str
                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)

        # quantization
        if get("top.quantization_bit") != "none":
            args["quantization_bit"] = int(get("top.quantization_bit"))
            args["quantization_method"] = get("top.quantization_method")
            args["double_quantization"] = not is_torch_npu_available()

        super().__init__(args)
        yield ALERTS["info_loaded"][lang]

    def unload_model(self, data) -> Generator[str, None, None]:
        lang = data[self.manager.get_elem_by_id("top.lang")]

        if self.demo_mode:
            gr.Warning(ALERTS["err_demo"][lang])
            yield ALERTS["err_demo"][lang]
            return

        yield ALERTS["info_unloading"][lang]
        self.engine = None
        torch_gc()
        yield ALERTS["info_unloaded"][lang]

    @staticmethod
    def append(
        chatbot: list[dict[str, str]],
        messages: list[dict[str, str]],
        role: str,
        query: str,
        escape_html: bool,
    ) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
        r"""Add the user input to chatbot.

        Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
        Output: infer.chatbot, infer.messages, infer.query
        """
        return (
            chatbot + [{"role": "user", "content": _escape_html(query) if escape_html else query}],
            messages + [{"role": role, "content": query}],
            "",
        )

    def stream(
        self,
        chatbot: list[dict[str, str]],
        messages: list[dict[str, str]],
        lang: str,
        system: str,
        tools: str,
        image: Optional[Any],
        video: Optional[Any],
        audio: Optional[Any],
        max_new_tokens: int,
        top_p: float,
        temperature: float,
        skip_special_tokens: bool,
        escape_html: bool,
        enable_thinking: bool,
    ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
        r"""Generate output text in stream.

        Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
        Output: infer.chatbot, infer.messages
        """
        with update_attr(self.engine.template, "enable_thinking", enable_thinking):
            chatbot.append({"role": "assistant", "content": ""})
            response = ""
            for new_text in self.stream_chat(
                messages,
                system,
                tools,
                images=[image] if image else None,
                videos=[video] if video else None,
                audios=[audio] if audio else None,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                temperature=temperature,
                skip_special_tokens=skip_special_tokens,
            ):
                response += new_text
                if tools:
                    result = self.engine.template.extract_tool(response)
                else:
                    result = response

                if isinstance(result, list):
                    tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
                    tool_calls = json.dumps(tool_calls, ensure_ascii=False)
                    output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
                    bot_text = "```json\n" + tool_calls + "\n```"
                else:
                    output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
                    bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)

                chatbot[-1] = {"role": "assistant", "content": bot_text}
                yield chatbot, output_messages
src/llamafactory/webui/common.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import signal
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Union

from psutil import Process
from yaml import safe_dump, safe_load

from ..extras import logging
from ..extras.constants import (
    DATA_CONFIG,
    DEFAULT_TEMPLATE,
    MULTIMODAL_SUPPORTED_MODELS,
    SUPPORTED_MODELS,
    TRAINING_ARGS,
    DownloadSource,
)
from ..extras.misc import use_modelscope, use_openmind


logger = logging.get_logger(__name__)


DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"


def abort_process(pid: int) -> None:
    r"""Abort the processes recursively in a bottom-up way."""
    try:
        children = Process(pid).children()
        if children:
            for child in children:
                abort_process(child.pid)

        os.kill(pid, signal.SIGABRT)
    except Exception:
        pass


def get_save_dir(*paths: str) -> os.PathLike:
    r"""Get the path to saved model checkpoints."""
    if os.path.sep in paths[-1]:
        logger.warning_rank0("Found complex path, some features may be not available.")
        return paths[-1]

    paths = (path.replace(" ", "").strip() for path in paths)
    return os.path.join(DEFAULT_SAVE_DIR, *paths)


def _get_config_path() -> os.PathLike:
    r"""Get the path to user config."""
    return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


def load_config() -> dict[str, Union[str, dict[str, Any]]]:
    r"""Load user config if exists."""
    try:
        with open(_get_config_path(), encoding="utf-8") as f:
            return safe_load(f)
    except Exception:
        return {"lang": None, "hub_name": None, "last_model": None, "path_dict": {}, "cache_dir": None}


def save_config(
    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
) -> None:
    r"""Save user config."""
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
    user_config = load_config()
    user_config["lang"] = lang or user_config["lang"]
    if hub_name:
        user_config["hub_name"] = hub_name

    if model_name:
        user_config["last_model"] = model_name

    if model_name and model_path:
        user_config["path_dict"][model_name] = model_path

    with open(_get_config_path(), "w", encoding="utf-8") as f:
        safe_dump(user_config, f)


def get_model_path(model_name: str) -> str:
    r"""Get the model path according to the model name."""
    user_config = load_config()
    path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
    model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
    if (
        use_modelscope()
        and path_dict.get(DownloadSource.MODELSCOPE)
        and model_path == path_dict.get(DownloadSource.DEFAULT)
    ):  # replace hf path with ms path
        model_path = path_dict.get(DownloadSource.MODELSCOPE)

    if (
        use_openmind()
        and path_dict.get(DownloadSource.OPENMIND)
        and model_path == path_dict.get(DownloadSource.DEFAULT)
    ):  # replace hf path with om path
        model_path = path_dict.get(DownloadSource.OPENMIND)

    return model_path


def get_template(model_name: str) -> str:
    r"""Get the template name if the model is a chat/distill/instruct model."""
    return DEFAULT_TEMPLATE.get(model_name, "default")


def get_time() -> str:
    r"""Get current date and time."""
    return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")


def is_multimodal(model_name: str) -> bool:
    r"""Judge if the model is a vision language model."""
    return model_name in MULTIMODAL_SUPPORTED_MODELS


def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
    r"""Load dataset_info.json."""
    if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
        logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
        return {}

    try:
        with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
            return json.load(f)
    except Exception as err:
        logger.warning_rank0(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
        return {}


def load_args(config_path: str) -> Optional[dict[str, Any]]:
    r"""Load the training configuration from config path."""
    try:
        with open(config_path, encoding="utf-8") as f:
            return safe_load(f)
    except Exception:
        return None


def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
    r"""Save the training configuration to config path."""
    with open(config_path, "w", encoding="utf-8") as f:
        safe_dump(config_dict, f)


def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
    r"""Remove args with NoneType or False or empty string value."""
    no_skip_keys = [
        "packing",
        "enable_thinking",
        "use_reentrant_gc",
        "double_quantization",
        "freeze_vision_tower",
        "freeze_multi_modal_projector",
    ]
    return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}


def gen_cmd(args: dict[str, Any]) -> str:
    r"""Generate CLI commands for previewing."""
    cmd_lines = ["llamafactory-cli train "]
    for k, v in _clean_cmd(args).items():
        if isinstance(v, dict):
            cmd_lines.append(f"    --{k} {json.dumps(v, ensure_ascii=False)} ")
        elif isinstance(v, list):
            cmd_lines.append(f"    --{k} {' '.join(map(str, v))} ")
        else:
            cmd_lines.append(f"    --{k} {str(v)} ")

    if os.name == "nt":
        cmd_text = "`\n".join(cmd_lines)
    else:
        cmd_text = "\\\n".join(cmd_lines)

    cmd_text = f"```bash\n{cmd_text}\n```"
    return cmd_text


def save_cmd(args: dict[str, Any]) -> str:
    r"""Save CLI commands to launch training."""
    output_dir = args["output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
        safe_dump(_clean_cmd(args), f)

    return os.path.join(output_dir, TRAINING_ARGS)


def load_eval_results(path: os.PathLike) -> str:
    r"""Get scores after evaluation."""
    with open(path, encoding="utf-8") as f:
        result = json.dumps(json.load(f), indent=4)

    return f"```json\n{result}\n```\n"


def calculate_pixels(pixels: str) -> int:
    r"""Calculate the number of pixels from the expression."""
    if "*" in pixels:
        return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
    else:
        return int(pixels)


def create_ds_config() -> None:
    r"""Create deepspeed config in the current directory."""
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
    ds_config = {
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "zero_allow_untested_optimizer": True,
        "fp16": {
            "enabled": "auto",
            "loss_scale": 0,
            "loss_scale_window": 1000,
            "initial_scale_power": 16,
            "hysteresis": 2,
            "min_loss_scale": 1,
        },
        "bf16": {"enabled": "auto"},
    }
    offload_config = {
        "device": "cpu",
        "pin_memory": True,
    }
    ds_config["zero_optimization"] = {
        "stage": 2,
        "allgather_partitions": True,
        "allgather_bucket_size": 5e8,
        "overlap_comm": False,
        "reduce_scatter": True,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": True,
        "round_robin_gradients": True,
    }
    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
        json.dump(ds_config, f, indent=2)

    ds_config["zero_optimization"]["offload_optimizer"] = offload_config
    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
        json.dump(ds_config, f, indent=2)

    ds_config["zero_optimization"] = {
        "stage": 3,
        "overlap_comm": False,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": True,
    }
    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
        json.dump(ds_config, f, indent=2)

    ds_config["zero_optimization"]["offload_optimizer"] = offload_config
    ds_config["zero_optimization"]["offload_param"] = offload_config
    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
        json.dump(ds_config, f, indent=2)
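A minimal sketch of what `gen_cmd` produces for a small argument dict; the model name and paths are illustrative placeholders, and the exact whitespace in the preview may differ slightly.

preview = gen_cmd(
    {"stage": "sft", "model_name_or_path": "Qwen/Qwen2-1.5B", "do_train": True, "output_dir": "saves/demo"}
)
print(preview)
# Expected shape of the output on Linux/macOS (wrapped in a ```bash fence):
#   llamafactory-cli train \
#       --stage sft \
#       --model_name_or_path Qwen/Qwen2-1.5B \
#       --do_train True \
#       --output_dir saves/demo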
src/llamafactory/webui/components/__init__.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
from .footer import create_footer
from .infer import create_infer_tab
from .top import create_top
from .train import create_train_tab


__all__ = [
    "create_chat_box",
    "create_eval_tab",
    "create_export_tab",
    "create_footer",
    "create_infer_tab",
    "create_top",
    "create_train_tab",
]
src/llamafactory/webui/components/chatbot.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
from typing import TYPE_CHECKING

from ...data import Role
from ...extras.packages import is_gradio_available
from ..locales import ALERTS


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component

    from ..engine import Engine


def check_json_schema(text: str, lang: str) -> None:
    r"""Check if the json schema is valid."""
    try:
        tools = json.loads(text)
        if tools:
            assert isinstance(tools, list)
            for tool in tools:
                if "name" not in tool:
                    raise NotImplementedError("Name not found.")
    except NotImplementedError:
        gr.Warning(ALERTS["err_tool_name"][lang])
    except Exception:
        gr.Warning(ALERTS["err_json_schema"][lang])


def create_chat_box(
    engine: "Engine", visible: bool = False
) -> tuple["Component", "Component", dict[str, "Component"]]:
    lang = engine.manager.get_elem_by_id("top.lang")
    with gr.Column(visible=visible) as chat_box:
        kwargs = {}
        if "show_copy_button" in inspect.signature(gr.Chatbot.__init__).parameters:
            kwargs["show_copy_button"] = True

        if "resizable" in inspect.signature(gr.Chatbot.__init__).parameters:
            kwargs["resizable"] = True

        chatbot = gr.Chatbot(type="messages", **kwargs)
        messages = gr.State([])
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Row():
                    with gr.Column():
                        role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
                        system = gr.Textbox(show_label=False)
                        tools = gr.Textbox(show_label=False, lines=3)

                    with gr.Column() as mm_box:
                        with gr.Tab("Image"):
                            image = gr.Image(type="pil")

                        with gr.Tab("Video"):
                            video = gr.Video()

                        with gr.Tab("Audio"):
                            audio = gr.Audio(type="filepath")

                query = gr.Textbox(show_label=False, lines=8)
                submit_btn = gr.Button(variant="primary")

            with gr.Column(scale=1):
                max_new_tokens = gr.Slider(minimum=8, maximum=8192, value=1024, step=1)
                top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
                temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
                skip_special_tokens = gr.Checkbox(value=True)
                escape_html = gr.Checkbox(value=True)
                enable_thinking = gr.Checkbox(value=True)
                clear_btn = gr.Button()

    tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])

    submit_btn.click(
        engine.chatter.append,
        [chatbot, messages, role, query, escape_html],
        [chatbot, messages, query],
    ).then(
        engine.chatter.stream,
        [
            chatbot,
            messages,
            lang,
            system,
            tools,
            image,
            video,
            audio,
            max_new_tokens,
            top_p,
            temperature,
            skip_special_tokens,
            escape_html,
            enable_thinking,
        ],
        [chatbot, messages],
    )
    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])

    return (
        chatbot,
        messages,
        dict(
            chat_box=chat_box,
            role=role,
            system=system,
            tools=tools,
            mm_box=mm_box,
            image=image,
            video=video,
            audio=audio,
            query=query,
            submit_btn=submit_btn,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            temperature=temperature,
            skip_special_tokens=skip_special_tokens,
            escape_html=escape_html,
            enable_thinking=enable_thinking,
            clear_btn=clear_btn,
        ),
    )
src/llamafactory/webui/components/data.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import TYPE_CHECKING, Any

from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component


PAGE_SIZE = 2


def prev_page(page_index: int) -> int:
    return page_index - 1 if page_index > 0 else page_index


def next_page(page_index: int, total_num: int) -> int:
    return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index


def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
    r"""Check if the dataset is a local dataset."""
    try:
        with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
            dataset_info = json.load(f)
    except Exception:
        return gr.Button(interactive=False)

    if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
        return gr.Button(interactive=False)

    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
    if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
        return gr.Button(interactive=True)
    else:
        return gr.Button(interactive=False)


def _load_data_file(file_path: str) -> list[Any]:
    with open(file_path, encoding="utf-8") as f:
        if file_path.endswith(".json"):
            return json.load(f)
        elif file_path.endswith(".jsonl"):
            return [json.loads(line) for line in f]
        else:
            return list(f)


def get_preview(dataset_dir: str, dataset: list, page_index: int) -> tuple[int, list, "gr.Column"]:
    r"""Get the preview samples from the dataset."""
    with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
        dataset_info = json.load(f)

    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
    if os.path.isfile(data_path):
        data = _load_data_file(data_path)
    else:
        data = []
        for file_name in os.listdir(data_path):
            data.extend(_load_data_file(os.path.join(data_path, file_name)))

    return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)


def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> dict[str, "Component"]:
    data_preview_btn = gr.Button(interactive=False, scale=1)
    with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
        with gr.Row():
            preview_count = gr.Number(value=0, interactive=False, precision=0)
            page_index = gr.Number(value=0, interactive=False, precision=0)

        with gr.Row():
            prev_btn = gr.Button()
            next_btn = gr.Button()
            close_btn = gr.Button()

        with gr.Row():
            preview_samples = gr.JSON()

    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
        lambda: 0, outputs=[page_index], queue=False
    )
    data_preview_btn.click(
        get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
    )
    prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
        get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
    )
    next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
        get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
    )
    close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)
    return dict(
        data_preview_btn=data_preview_btn,
        preview_count=preview_count,
        page_index=page_index,
        prev_btn=prev_btn,
        next_btn=next_btn,
        close_btn=close_btn,
        preview_samples=preview_samples,
    )
src/llamafactory/webui/components/eval.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR
from ..control import list_datasets
from .data import create_preview_box


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component

    from ..engine import Engine


def create_eval_tab(engine: "Engine") -> dict[str, "Component"]:
    input_elems = engine.manager.get_base_elems()
    elem_dict = dict()

    with gr.Row():
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
        preview_elems = create_preview_box(dataset_dir, dataset)

    input_elems.update({dataset_dir, dataset})
    elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

    with gr.Row():
        cutoff_len = gr.Slider(minimum=4, maximum=131072, value=1024, step=1)
        max_samples = gr.Textbox(value="100000")
        batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
        predict = gr.Checkbox(value=True)

    input_elems.update({cutoff_len, max_samples, batch_size, predict})
    elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))

    with gr.Row():
        max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
        top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
        temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
        output_dir = gr.Textbox()

    input_elems.update({max_new_tokens, top_p, temperature, output_dir})
    elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))

    with gr.Row():
        cmd_preview_btn = gr.Button()
        start_btn = gr.Button(variant="primary")
        stop_btn = gr.Button(variant="stop")

    with gr.Row():
        resume_btn = gr.Checkbox(visible=False, interactive=False)
        progress_bar = gr.Slider(visible=False, interactive=False)

    with gr.Row():
        output_box = gr.Markdown()

    elem_dict.update(
        dict(
            cmd_preview_btn=cmd_preview_btn,
            start_btn=start_btn,
            stop_btn=stop_btn,
            resume_btn=resume_btn,
            progress_bar=progress_bar,
            output_box=output_box,
        )
    )
    output_elems = [output_box, progress_bar]

    cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
    start_btn.click(engine.runner.run_eval, input_elems, output_elems)
    stop_btn.click(engine.runner.set_abort)
    resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)

    dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)

    return elem_dict
src/llamafactory/webui/components/export.py
0 → 100644
View file @
c7c477c7
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Union

from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
from ..common import get_save_dir, load_config
from ..locales import ALERTS


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component

    from ..engine import Engine


GPTQ_BITS = ["8", "4", "3", "2"]


def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
    if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
        return gr.Dropdown(value="none", interactive=False)
    else:
        return gr.Dropdown(interactive=True)


def save_model(
    lang: str,
    model_name: str,
    model_path: str,
    finetuning_type: str,
    checkpoint_path: Union[str, list[str]],
    template: str,
    export_size: int,
    export_quantization_bit: str,
    export_quantization_dataset: str,
    export_device: str,
    export_legacy_format: bool,
    export_dir: str,
    export_hub_model_id: str,
    extra_args: str,
) -> Generator[str, None, None]:
    user_config = load_config()
    error = ""
    if not model_name:
        error = ALERTS["err_no_model"][lang]
    elif not model_path:
        error = ALERTS["err_no_path"][lang]
    elif not export_dir:
        error = ALERTS["err_no_export_dir"][lang]
    elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
        error = ALERTS["err_no_dataset"][lang]
    elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
        error = ALERTS["err_no_adapter"][lang]
    elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list):
        error = ALERTS["err_gptq_lora"][lang]

    try:
        json.loads(extra_args)
    except json.JSONDecodeError:
        error = ALERTS["err_json_schema"][lang]

    if error:
        gr.Warning(error)
        yield error
        return

    args = dict(
        model_name_or_path=model_path,
        cache_dir=user_config.get("cache_dir", None),
        finetuning_type=finetuning_type,
        template=template,
        export_dir=export_dir,
        export_hub_model_id=export_hub_model_id or None,
        export_size=export_size,
        export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
        export_quantization_dataset=export_quantization_dataset,
        export_device=export_device,
        export_legacy_format=export_legacy_format,
        trust_remote_code=True,
    )
    args.update(json.loads(extra_args))

    if checkpoint_path:
        if finetuning_type in PEFT_METHODS:  # list
            args["adapter_name_or_path"] = ",".join(
                [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
            )
        else:  # str
            args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)

    yield ALERTS["info_exporting"][lang]
    export_model(args)
    torch_gc()
    yield ALERTS["info_exported"][lang]


def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
    with gr.Row():
        export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
        export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
        export_quantization_dataset = gr.Textbox(value="data/c4_demo.jsonl")
        export_device = gr.Radio(choices=["cpu", "auto"], value="cpu")
        export_legacy_format = gr.Checkbox()

    with gr.Row():
        export_dir = gr.Textbox()
        export_hub_model_id = gr.Textbox()

    extra_args = gr.Textbox(value="{}")

    checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
    checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)

    export_btn = gr.Button()
    info_box = gr.Textbox(show_label=False, interactive=False)

    export_btn.click(
        save_model,
        [
            engine.manager.get_elem_by_id("top.lang"),
            engine.manager.get_elem_by_id("top.model_name"),
            engine.manager.get_elem_by_id("top.model_path"),
            engine.manager.get_elem_by_id("top.finetuning_type"),
            engine.manager.get_elem_by_id("top.checkpoint_path"),
            engine.manager.get_elem_by_id("top.template"),
            export_size,
            export_quantization_bit,
            export_quantization_dataset,
            export_device,
            export_legacy_format,
            export_dir,
            export_hub_model_id,
            extra_args,
        ],
        [info_box],
    )

    return dict(
        export_size=export_size,
        export_quantization_bit=export_quantization_bit,
        export_quantization_dataset=export_quantization_dataset,
        export_device=export_device,
        export_legacy_format=export_legacy_format,
        export_dir=export_dir,
        export_hub_model_id=export_hub_model_id,
        extra_args=extra_args,
        export_btn=export_btn,
        info_box=info_box,
    )