OpenDAS / LLaMA-Factory / Commits / 317a82e2

Commit 317a82e2, authored Mar 07, 2025 by chenych
Add QWQ-32B
Parent: 37b0ad9f
Showing 20 changed files with 626 additions and 562 deletions (+626 / -562)
src/llamafactory/train/rm/metric.py            +1   -1
src/llamafactory/train/rm/trainer.py           +1   -5
src/llamafactory/train/sft/__init__.py         +1   -1
src/llamafactory/train/sft/trainer.py          +9   -20
src/llamafactory/train/sft/workflow.py         +7   -6
src/llamafactory/train/test_utils.py           +1   -1
src/llamafactory/train/trainer_utils.py        +23  -4
src/llamafactory/train/tuner.py                +17  -3
src/llamafactory/train/utils.py                +0   -401
src/llamafactory/webui/chatter.py              +86  -23
src/llamafactory/webui/common.py               +163 -56
src/llamafactory/webui/components/__init__.py  +1   -1
src/llamafactory/webui/components/chatbot.py   +49  -8
src/llamafactory/webui/components/data.py      +7   -1
src/llamafactory/webui/components/eval.py      +3   -2
src/llamafactory/webui/components/export.py    +7   -2
src/llamafactory/webui/components/infer.py     +3   -3
src/llamafactory/webui/components/top.py       +11  -12
src/llamafactory/webui/components/train.py     +23  -12
src/llamafactory/webui/control.py              +213 -0
src/llamafactory/train/rm/metric.py

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
src/llamafactory/train/rm/trainer.py

...
@@ -25,7 +25,7 @@ from transformers import Trainer
 from typing_extensions import override

 from ...extras import logging
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
...
@@ -107,10 +107,6 @@ class PairwiseTrainer(Trainer):
         chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0-4.46.1
-
         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
         else:
...
src/llamafactory/train/sft/__init__.py

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
src/llamafactory/train/sft/trainer.py

...
@@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+    from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput

     from ...hparams import FinetuningArguments
...
@@ -49,7 +49,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     """

     def __init__(
-        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+        self,
+        finetuning_args: "FinetuningArguments",
+        processor: Optional["ProcessorMixin"],
+        gen_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> None:
         if is_transformers_version_greater_than("4.46"):
             kwargs["processing_class"] = kwargs.pop("tokenizer")
...
@@ -58,6 +62,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if gen_kwargs is not None:
+            # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
+            self._gen_kwargs = gen_kwargs

         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
...
@@ -88,24 +95,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super()._get_train_sampler()

-    @override
-    def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
-
-        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
-        """
-        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
-
     @override
     def prediction_step(
         self,
...
src/llamafactory/train/sft/workflow.py

...
@@ -78,6 +78,12 @@ def run_sft(
         metric_module["compute_metrics"] = ComputeAccuracy()
         metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

+    # Keyword arguments for `model.generate`
+    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
+
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
         model=model,
...
@@ -85,17 +91,12 @@ def run_sft(
         finetuning_args=finetuning_args,
         data_collator=data_collator,
         callbacks=callbacks,
+        gen_kwargs=gen_kwargs,
         **dataset_module,
         **tokenizer_module,
         **metric_module,
     )

-    # Keyword arguments for `model.generate`
-    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
-    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
-    gen_kwargs["logits_processor"] = get_logits_processor()
-
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
...
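Note on the two hunks above: building gen_kwargs before the trainer and handing it to the new gen_kwargs constructor argument lets CustomSeq2SeqTrainer store it as self._gen_kwargs, which Seq2SeqTrainer consults during evaluate()/predict(). A minimal sketch of the resulting call pattern inside run_sft (variable names follow the surrounding code; this is illustrative, not the verbatim workflow):

    # Sketch: generation settings now travel through the trainer constructor.
    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        gen_kwargs=gen_kwargs,  # stored as self._gen_kwargs in __init__
        **dataset_module,
        **tokenizer_module,
    )
    # Later calls such as trainer.predict(dataset_module["eval_dataset"]) then
    # reuse self._gen_kwargs without passing generation settings explicitly.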
src/llamafactory/train/test_utils.py

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
src/llamafactory/train/trainer_utils.py

...
@@ -17,6 +17,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
+import os
 from collections.abc import Mapping
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
...
@@ -31,7 +33,7 @@ from transformers.trainer_pt_utils import get_parameter_names
 from typing_extensions import override

 from ..extras import logging
-from ..extras.constants import IGNORE_INDEX
+from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
 from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
...
@@ -51,7 +53,7 @@ if is_ray_available():
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, TrainerCallback
+    from transformers import PreTrainedModel, TrainerCallback, TrainerState
     from trl import AutoModelForCausalLMWithValueHead

     from ..hparams import DataArguments, RayArguments, TrainingArguments
...
@@ -592,7 +594,24 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
     if finetuning_args.swanlab_api_key is not None:
         swanlab.login(api_key=finetuning_args.swanlab_api_key)

-    swanlab_callback = SwanLabCallback(
+    class SwanLabCallbackExtension(SwanLabCallback):
+        def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
+            if not state.is_world_process_zero:
+                return
+
+            super().setup(args, state, model, **kwargs)
+            try:
+                if hasattr(self, "_swanlab"):
+                    swanlab_public_config = self._swanlab.get_run().public.json()
+                else:  # swanlab <= 0.4.9
+                    swanlab_public_config = self._experiment.get_run().public.json()
+            except Exception:
+                swanlab_public_config = {}
+
+            with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w") as f:
+                f.write(json.dumps(swanlab_public_config, indent=2))
+
+    swanlab_callback = SwanLabCallbackExtension(
         project=finetuning_args.swanlab_project,
         workspace=finetuning_args.swanlab_workspace,
         experiment_name=finetuning_args.swanlab_run_name,
...
@@ -621,7 +640,7 @@ def get_ray_trainer(
         ),
         run_config=RunConfig(
             name=ray_args.ray_run_name,
-            storage_path=Path("./saves").absolute().as_posix(),
+            storage_path=Path(ray_args.ray_storage_path).absolute().as_posix(),
         ),
     )
     return trainer
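The SwanLabCallbackExtension added above writes the run's public metadata to a JSON file (named by the SWANLAB_CONFIG constant) inside the training output directory. A hedged sketch of reading it back for debugging; the directory and file name below are assumptions, and the available keys depend on the installed swanlab version:

    import json
    import os

    output_dir = "saves/qwq-32b/lora/sft"        # hypothetical training output_dir
    config_file = "swanlab_public_config.json"   # stand-in for the SWANLAB_CONFIG constant

    path = os.path.join(output_dir, config_file)
    if os.path.isfile(path):
        with open(path, encoding="utf-8") as f:
            run_info = json.load(f)
        print(sorted(run_info.keys()))  # empty dict if setup() could not query the run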
src/llamafactory/train/tuner.py

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -17,11 +17,13 @@ import shutil
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import torch
+import torch.distributed as dist
 from transformers import PreTrainedModel

 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_ray_available
 from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
...
@@ -75,6 +77,12 @@ def _training_function(config: Dict[str, Any]) -> None:
     else:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")

+    try:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+    except Exception as e:
+        logger.warning(f"Failed to destroy process group: {e}.")
+

 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
     args = read_args(args)
...
@@ -104,7 +112,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
     processor = tokenizer_module["processor"]
-    get_template_and_fix_tokenizer(tokenizer, data_args)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
     model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab

     if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
...
@@ -117,7 +125,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
             setattr(model.config, "torch_dtype", torch.float16)
     else:
         if model_args.infer_dtype == "auto":
-            output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+            output_dtype = getattr(model.config, "torch_dtype", torch.float32)
+            if output_dtype == torch.float32:  # if infer_dtype is auto, try using half precision first
+                output_dtype = infer_optim_dtype(torch.bfloat16)
         else:
             output_dtype = getattr(torch, model_args.infer_dtype)
...
@@ -171,3 +181,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     except Exception as e:
         logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
+
+    with open(os.path.join(model_args.export_dir, "Modelfile"), "w", encoding="utf-8") as f:
+        f.write(template.get_ollama_modelfile(tokenizer))
+        logger.info_rank0(f"Saved ollama modelfile to {model_args.export_dir}.")
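export_model now also emits an Ollama Modelfile, rendered by template.get_ollama_modelfile(tokenizer), into the export directory. A small sketch for inspecting the generated file; the export path is an assumption:

    import os

    export_dir = "output/qwq-32b-export"  # hypothetical model_args.export_dir
    with open(os.path.join(export_dir, "Modelfile"), encoding="utf-8") as f:
        print(f.read())  # chat template and stop tokens rendered for Ollama

From there, something like `ollama create qwq-32b -f Modelfile` would register the export with a local Ollama instance, though that step happens outside this codebase.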
src/llamafactory/train/utils.py  (deleted, 100644 → 0)

The whole file was removed in this commit. Its previous contents:

from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

import torch
from transformers import Trainer
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params


if is_galore_available():
    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit


if TYPE_CHECKING:
    from accelerate import Accelerator
    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import DataArguments


logger = get_logger(__name__)


class DummyOptimizer(torch.optim.Optimizer):
    r"""
    A dummy optimizer used for the GaLore algorithm.
    """

    def __init__(
        self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
    ) -> None:
        dummy_tensor = torch.randn(1, 1)
        self.optimizer_dict = optimizer_dict
        super().__init__([dummy_tensor], {"lr": lr})

    def zero_grad(self, set_to_none: bool = True) -> None:
        pass

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        pass


def create_modelcard_and_push(
    trainer: "Trainer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> None:
    kwargs = {
        "tasks": "text-generation",
        "finetuned_from": model_args.model_name_or_path,
        "tags": ["llama-factory", finetuning_args.finetuning_type],
    }
    if data_args.dataset is not None:
        kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]

    if model_args.use_unsloth:
        kwargs["tags"] = kwargs["tags"] + ["unsloth"]

    if not training_args.do_train:
        pass
    elif training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(license="other", **kwargs)  # prevent from connecting to hub


def create_ref_model(
    model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
    r"""
    Creates reference model for PPO/DPO training. Evaluation mode is not supported.

    The valuehead parameter is randomly initialized since it is useless for PPO training.
    """
    if finetuning_args.ref_model is not None:
        ref_model_args_dict = model_args.to_dict()
        ref_model_args_dict.update(
            dict(
                model_name_or_path=finetuning_args.ref_model,
                adapter_name_or_path=finetuning_args.ref_model_adapters,
                quantization_bit=finetuning_args.ref_model_quantization_bit,
            )
        )
        ref_model_args = ModelArguments(**ref_model_args_dict)
        ref_finetuning_args = FinetuningArguments()
        tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
        ref_model = load_model(
            tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
        )
        logger.info("Created reference model from {}".format(finetuning_args.ref_model))
    else:
        if finetuning_args.finetuning_type == "lora":
            ref_model = None
        else:
            tokenizer = load_tokenizer(model_args)["tokenizer"]
            ref_model = load_model(
                tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
            )
            logger.info("Created reference model from the model itself.")

    return ref_model


def create_reward_model(
    model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> Optional["AutoModelForCausalLMWithValueHead"]:
    r"""
    Creates reward model for PPO training.
    """
    if finetuning_args.reward_model_type == "api":
        assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
        logger.info("Use reward server {}".format(finetuning_args.reward_model))
        return finetuning_args.reward_model
    elif finetuning_args.reward_model_type == "lora":
        model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
        for name, param in model.named_parameters():  # https://github.com/huggingface/peft/issues/1090
            if "default" in name:
                param.data = param.data.to(torch.float32)  # trainable params should in fp32
        vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
        assert vhead_params is not None, "Reward model is not correctly loaded."
        model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
        model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
        model.register_buffer(
            "default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
        )
        model.register_buffer(
            "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
        )
        logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
        return None
    else:
        reward_model_args_dict = model_args.to_dict()
        reward_model_args_dict.update(
            dict(
                model_name_or_path=finetuning_args.reward_model,
                adapter_name_or_path=finetuning_args.reward_model_adapters,
                quantization_bit=finetuning_args.reward_model_quantization_bit,
            )
        )
        reward_model_args = ModelArguments(**reward_model_args_dict)
        reward_finetuning_args = FinetuningArguments()
        tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
        reward_model = load_model(
            tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
        )
        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
        logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
        return reward_model


@contextmanager
def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
    r"""
    Gets adapter context for the reference model.
    """
    with accelerator.unwrap_model(model).disable_adapter():
        model.eval()
        yield
        model.train()


def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
    r"""
    Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
    """
    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    return decay_parameters


def _create_galore_optimizer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
        galore_targets = find_all_linear_modules(model)
    else:
        galore_targets = finetuning_args.galore_target

    galore_params: List["torch.nn.Parameter"] = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
            for param in module.parameters():
                if param.requires_grad and len(param.shape) > 1:
                    galore_params.append(param)

    galore_kwargs = {
        "rank": finetuning_args.galore_rank,
        "update_proj_gap": finetuning_args.galore_update_interval,
        "scale": finetuning_args.galore_scale,
        "proj_type": finetuning_args.galore_proj_type,
    }

    id_galore_params = {id(param) for param in galore_params}
    decay_params, nodecay_params = [], []  # they are non-galore parameters
    trainable_params: List["torch.nn.Parameter"] = []  # galore_params + decay_params + nodecay_params
    decay_param_names = _get_decay_parameter_names(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(param)
            if id(param) not in id_galore_params:
                if name in decay_param_names:
                    decay_params.append(param)
                else:
                    nodecay_params.append(param)

    _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)

    if training_args.optim == "adamw_torch":
        optim_class = GaLoreAdamW
    elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
        optim_class = GaLoreAdamW8bit
    elif training_args.optim == "adafactor":
        optim_class = GaLoreAdafactor
    else:
        raise NotImplementedError("Unknow optim: {}".format(training_args.optim))

    if finetuning_args.galore_layerwise:
        if training_args.gradient_accumulation_steps != 1:
            raise ValueError("Per-layer GaLore does not support gradient accumulation.")

        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
        for param in nodecay_params:
            param_groups = [dict(params=[param], weight_decay=0.0)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
        for param in decay_params:
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
        for param in galore_params:  # galore params have weight decay
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        def optimizer_hook(param: "torch.nn.Parameter"):
            if param.grad is not None:
                optimizer_dict[param].step()
                optimizer_dict[param].zero_grad()

        for param in trainable_params:
            param.register_post_accumulate_grad_hook(optimizer_hook)

        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
    else:
        param_groups = [
            dict(params=nodecay_params, weight_decay=0.0),
            dict(params=decay_params, weight_decay=training_args.weight_decay),
            dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
        ]
        optimizer = optim_class(param_groups, **optim_kwargs)

    logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
    return optimizer


def _create_loraplus_optimizer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    default_lr = training_args.learning_rate
    loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
    embedding_lr = finetuning_args.loraplus_lr_embedding

    decay_param_names = _get_decay_parameter_names(model)
    param_dict: Dict[str, List["torch.nn.Parameter"]] = {
        "lora_a": [],
        "lora_b": [],
        "lora_b_nodecay": [],
        "embedding": [],
    }
    for name, param in model.named_parameters():
        if param.requires_grad:
            if "lora_embedding_B" in name:
                param_dict["embedding"].append(param)
            elif "lora_B" in name or param.ndim == 1:
                if name in decay_param_names:
                    param_dict["lora_b"].append(param)
                else:
                    param_dict["lora_b_nodecay"].append(param)
            else:
                param_dict["lora_a"].append(param)

    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
    param_groups = [
        dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
        dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
        dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
    ]
    optimizer = optim_class(param_groups, **optim_kwargs)
    logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
    return optimizer


def _create_badam_optimizer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    decay_params, nodecay_params = [], []
    decay_param_names = _get_decay_parameter_names(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            if name in decay_param_names:
                decay_params.append(param)
            else:
                nodecay_params.append(param)

    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
    param_groups = [
        dict(params=nodecay_params, weight_decay=0.0),
        dict(params=decay_params, weight_decay=training_args.weight_decay),
    ]

    if finetuning_args.badam_mode == "layer":
        from badam import BlockOptimizer

        base_optimizer = optim_class(param_groups, **optim_kwargs)
        optimizer = BlockOptimizer(
            base_optimizer=base_optimizer,
            named_parameters_list=list(model.named_parameters()),
            block_prefix_list=None,
            switch_block_every=finetuning_args.badam_switch_interval,
            start_block=finetuning_args.badam_start_block,
            switch_mode=finetuning_args.badam_switch_mode,
            verbose=finetuning_args.badam_verbose,
        )
        logger.info(
            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
            f"switch block every {finetuning_args.badam_switch_interval} steps, "
            f"default start block is {finetuning_args.badam_start_block}"
        )
    elif finetuning_args.badam_mode == "ratio":
        from badam import BlockOptimizerRatio

        assert finetuning_args.badam_update_ratio > 1e-6
        optimizer = BlockOptimizerRatio(
            param_groups=param_groups,
            named_parameters_list=list(model.named_parameters()),
            update_ratio=finetuning_args.badam_update_ratio,
            mask_mode=finetuning_args.badam_mask_mode,
            verbose=finetuning_args.badam_verbose,
            include_embedding=False,
            **optim_kwargs,
        )
        logger.info(
            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
            f"mask mode is {finetuning_args.badam_mask_mode}"
        )

    return optimizer


def create_custom_optimzer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
    if finetuning_args.use_galore:
        return _create_galore_optimizer(model, training_args, finetuning_args)

    if finetuning_args.loraplus_lr_ratio is not None:
        return _create_loraplus_optimizer(model, training_args, finetuning_args)

    if finetuning_args.use_badam:
        return _create_badam_optimizer(model, training_args, finetuning_args)


def create_custom_scheduler(
    training_args: "Seq2SeqTrainingArguments",
    num_training_steps: int,
    optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
    if optimizer is not None and isinstance(optimizer, DummyOptimizer):
        optimizer_dict = optimizer.optimizer_dict
        scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}

        for param in optimizer_dict.keys():
            scheduler_dict[param] = get_scheduler(
                training_args.lr_scheduler_type,
                optimizer=optimizer_dict[param],
                num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
                scheduler_specific_kwargs=training_args.lr_scheduler_kwargs,
            )

        def scheduler_hook(param: "torch.nn.Parameter"):
            scheduler_dict[param].step()

        for param in optimizer_dict.keys():
            param.register_post_accumulate_grad_hook(scheduler_hook)
src/llamafactory/webui/chatter.py

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -14,14 +14,16 @@
 import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+
+from transformers.utils import is_torch_npu_available

 from ..chat import ChatModel
 from ..data import Role
 from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
-from .common import QUANTIZATION_BITS, get_save_dir
+from .common import get_save_dir, load_config
 from .locales import ALERTS
...
@@ -34,6 +36,40 @@ if is_gradio_available():
     import gradio as gr


+def _escape_html(text: str) -> str:
+    r"""
+    Escapes HTML characters.
+    """
+    return text.replace("<", "&lt;").replace(">", "&gt;")
+
+
+def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
+    r"""
+    Post-processes the response text.
+
+    Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
+    """
+    if thought_words[0] not in text:
+        return _escape_html(text) if escape_html else text
+
+    text = text.replace(thought_words[0], "")
+    result = text.split(thought_words[1], maxsplit=1)
+    if len(result) == 1:
+        summary = ALERTS["info_thinking"][lang]
+        thought, answer = text, ""
+    else:
+        summary = ALERTS["info_thought"][lang]
+        thought, answer = result
+
+    if escape_html:
+        thought, answer = _escape_html(thought), _escape_html(answer)
+
+    return (
+        f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
+        f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
+    )
+
+
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
...
@@ -59,6 +95,8 @@ class WebChatModel(ChatModel):
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
+        user_config = load_config()
+
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
...
@@ -74,26 +112,22 @@ class WebChatModel(ChatModel):
             yield error
             return

-        if get("top.quantization_bit") in QUANTIZATION_BITS:
-            quantization_bit = int(get("top.quantization_bit"))
-        else:
-            quantization_bit = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_path,
+            cache_dir=user_config.get("cache_dir", None),
             finetuning_type=finetuning_type,
-            quantization_bit=quantization_bit,
-            quantization_method=get("top.quantization_method"),
             template=get("top.template"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
             trust_remote_code=True,
         )

+        # checkpoints
         if checkpoint_path:
             if finetuning_type in PEFT_METHODS:  # list
                 args["adapter_name_or_path"] = ",".join(
...
@@ -102,6 +136,12 @@ class WebChatModel(ChatModel):
             else:  # str
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)

+        # quantization
+        if get("top.quantization_bit") != "none":
+            args["quantization_bit"] = int(get("top.quantization_bit"))
+            args["quantization_method"] = get("top.quantization_method")
+            args["double_quantization"] = not is_torch_npu_available()
+
         super().__init__(args)
         yield ALERTS["info_loaded"][lang]
...
@@ -118,28 +158,49 @@ class WebChatModel(ChatModel):
         torch_gc()
         yield ALERTS["info_unloaded"][lang]

-    @staticmethod
     def append(
+        self,
-        chatbot: List[List[Optional[str]]],
+        chatbot: List[Dict[str, str]],
-        messages: Sequence[Dict[str, str]],
+        messages: List[Dict[str, str]],
         role: str,
         query: str,
+        escape_html: bool,
-    ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
-        return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
+    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
+        r"""
+        Adds the user input to chatbot.
+
+        Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
+        Output: infer.chatbot, infer.messages, infer.query
+        """
+        return (
+            chatbot + [{"role": "user", "content": _escape_html(query) if escape_html else query}],
+            messages + [{"role": role, "content": query}],
+            "",
+        )

     def stream(
         self,
-        chatbot: List[List[Optional[str]]],
+        chatbot: List[Dict[str, str]],
-        messages: Sequence[Dict[str, str]],
+        messages: List[Dict[str, str]],
+        lang: str,
         system: str,
         tools: str,
         image: Optional[Any],
         video: Optional[Any],
+        audio: Optional[Any],
         max_new_tokens: int,
         top_p: float,
         temperature: float,
-    ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
-        chatbot[-1][1] = ""
+        skip_special_tokens: bool,
+        escape_html: bool,
+    ) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
+        r"""
+        Generates output text in stream.
+
+        Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
+        Output: infer.chatbot, infer.messages
+        """
+        chatbot.append({"role": "assistant", "content": ""})
         response = ""
         for new_text in self.stream_chat(
             messages,
...
@@ -147,9 +208,11 @@ class WebChatModel(ChatModel):
             tools,
             images=[image] if image else None,
             videos=[video] if video else None,
+            audios=[audio] if audio else None,
             max_new_tokens=max_new_tokens,
             top_p=top_p,
             temperature=temperature,
+            skip_special_tokens=skip_special_tokens,
         ):
             response += new_text
             if tools:
...
@@ -159,12 +222,12 @@ class WebChatModel(ChatModel):
                 if isinstance(result, list):
                     tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
-                    tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
+                    tool_calls = json.dumps(tool_calls, ensure_ascii=False)
                     output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
                     bot_text = "```json\n" + tool_calls + "\n```"
                 else:
                     output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
-                    bot_text = result
+                    bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)

-            chatbot[-1][1] = bot_text
+            chatbot[-1] = {"role": "assistant", "content": bot_text}
             yield chatbot, output_messages
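The new _format_response helper folds a model's reasoning span into a collapsible details block before the text reaches the Gradio chatbot. A rough sketch of its behavior, assuming the template's thought_words pair is ("<think>", "</think>") (the real pair comes from self.engine.template.thought_words):

    # Sketch: the reasoning span is split from the final answer and wrapped for display.
    text = "<think>Compare both options first.</think>Option B is better."
    html = _format_response(text, lang="en", escape_html=True, thought_words=("<think>", "</think>"))
    # html now starts with "<details open><summary ...>" containing the thought,
    # followed by the plain answer "Option B is better.".
    # If the closing tag is missing, the whole text is treated as an in-progress thought
    # and labelled with ALERTS["info_thinking"][lang] instead of ALERTS["info_thought"][lang].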
src/llamafactory/webui/common.py

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -14,41 +14,48 @@
 import json
 import os
+import signal
 from collections import defaultdict
-from typing import Any, Dict, Optional, Tuple
+from datetime import datetime
+from typing import Any, Dict, Optional, Union

+from psutil import Process
 from yaml import safe_dump, safe_load

 from ..extras import logging
 from ..extras.constants import (
-    CHECKPOINT_NAMES,
     DATA_CONFIG,
     DEFAULT_TEMPLATE,
-    PEFT_METHODS,
+    MULTIMODAL_SUPPORTED_MODELS,
-    STAGES_USE_PAIR_DATA,
     SUPPORTED_MODELS,
-    TRAINING_STAGES,
+    TRAINING_ARGS,
-    VISION_MODELS,
     DownloadSource,
 )
 from ..extras.misc import use_modelscope, use_openmind
-from ..extras.packages import is_gradio_available
-
-
-if is_gradio_available():
-    import gradio as gr


 logger = logging.get_logger(__name__)


 DEFAULT_CACHE_DIR = "cache"
 DEFAULT_CONFIG_DIR = "config"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user_config.yaml"
-QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
-GPTQ_BITS = ["8", "4", "3", "2"]
+
+
+def abort_process(pid: int) -> None:
+    r"""
+    Aborts the processes recursively in a bottom-up way.
+    """
+    try:
+        children = Process(pid).children()
+        if children:
+            for child in children:
+                abort_process(child.pid)
+
+        os.kill(pid, signal.SIGABRT)
+    except Exception:
+        pass


 def get_save_dir(*paths: str) -> os.PathLike:
...
@@ -63,19 +70,19 @@ def get_save_dir(*paths: str) -> os.PathLike:
     return os.path.join(DEFAULT_SAVE_DIR, *paths)


-def get_config_path() -> os.PathLike:
+def _get_config_path() -> os.PathLike:
     r"""
     Gets the path to user config.
     """
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


-def load_config() -> Dict[str, Any]:
+def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
     r"""
     Loads user config if exists.
     """
     try:
-        with open(get_config_path(), encoding="utf-8") as f:
+        with open(_get_config_path(), encoding="utf-8") as f:
             return safe_load(f)
     except Exception:
         return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
...
@@ -94,7 +101,7 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
     if model_name and model_path:
         user_config["path_dict"][model_name] = model_path

-    with open(get_config_path(), "w", encoding="utf-8") as f:
+    with open(_get_config_path(), "w", encoding="utf-8") as f:
         safe_dump(user_config, f)
...
@@ -122,49 +129,25 @@ def get_model_path(model_name: str) -> str:
     return model_path


-def get_model_info(model_name: str) -> Tuple[str, str]:
-    r"""
-    Gets the necessary information of this model.
-
-    Returns:
-        model_path (str)
-        template (str)
-    """
-    return get_model_path(model_name), get_template(model_name)
-
-
 def get_template(model_name: str) -> str:
     r"""
-    Gets the template name if the model is a chat model.
+    Gets the template name if the model is a chat/distill/instruct model.
     """
     return DEFAULT_TEMPLATE.get(model_name, "default")


-def get_visual(model_name: str) -> bool:
+def get_time() -> str:
     r"""
-    Judges if the model is a vision language model.
+    Gets current date and time.
     """
-    return model_name in VISION_MODELS
+    return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")


-def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
+def is_multimodal(model_name: str) -> bool:
     r"""
-    Lists all available checkpoints.
+    Judges if the model is a vision language model.
     """
-    checkpoints = []
-    if model_name:
-        save_dir = get_save_dir(model_name, finetuning_type)
-        if save_dir and os.path.isdir(save_dir):
-            for checkpoint in os.listdir(save_dir):
-                if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
-                    os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
-                ):
-                    checkpoints.append(checkpoint)
-
-    if finetuning_type in PEFT_METHODS:
-        return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
-    else:
-        return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
+    return model_name in MULTIMODAL_SUPPORTED_MODELS


 def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
...
@@ -183,11 +166,135 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
     return {}


-def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
-    r"""
-    Lists all available datasets in the dataset dir for the training stage.
-    """
-    dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
-    ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
-    datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
-    return gr.Dropdown(choices=datasets)
+def load_args(config_path: str) -> Optional[Dict[str, Any]]:
+    r"""
+    Loads the training configuration from config path.
+    """
+    try:
+        with open(config_path, encoding="utf-8") as f:
+            return safe_load(f)
+    except Exception:
+        return None
+
+
+def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
+    r"""
+    Saves the training configuration to config path.
+    """
+    with open(config_path, "w", encoding="utf-8") as f:
+        safe_dump(config_dict, f)
+
+
+def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
+    r"""
+    Removes args with NoneType or False or empty string value.
+    """
+    no_skip_keys = ["packing"]
+    return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
+
+
+def gen_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Generates CLI commands for previewing.
+    """
+    cmd_lines = ["llamafactory-cli train "]
+    for k, v in _clean_cmd(args).items():
+        if isinstance(v, dict):
+            cmd_lines.append(f"    --{k} {json.dumps(v, ensure_ascii=False)} ")
+        elif isinstance(v, list):
+            cmd_lines.append(f"    --{k} {' '.join(map(str, v))} ")
+        else:
+            cmd_lines.append(f"    --{k} {str(v)} ")
+
+    if os.name == "nt":
+        cmd_text = "`\n".join(cmd_lines)
+    else:
+        cmd_text = "\\\n".join(cmd_lines)
+
+    cmd_text = f"```bash\n{cmd_text}\n```"
+    return cmd_text
+
+
+def save_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Saves CLI commands to launch training.
+    """
+    output_dir = args["output_dir"]
+    os.makedirs(output_dir, exist_ok=True)
+    with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
+        safe_dump(_clean_cmd(args), f)
+
+    return os.path.join(output_dir, TRAINING_ARGS)
+
+
+def load_eval_results(path: os.PathLike) -> str:
+    r"""
+    Gets scores after evaluation.
+    """
+    with open(path, encoding="utf-8") as f:
+        result = json.dumps(json.load(f), indent=4)
+
+    return f"```json\n{result}\n```\n"
+
+
+def create_ds_config() -> None:
+    r"""
+    Creates deepspeed config in the current directory.
+    """
+    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
+    ds_config = {
+        "train_batch_size": "auto",
+        "train_micro_batch_size_per_gpu": "auto",
+        "gradient_accumulation_steps": "auto",
+        "gradient_clipping": "auto",
+        "zero_allow_untested_optimizer": True,
+        "fp16": {
+            "enabled": "auto",
+            "loss_scale": 0,
+            "loss_scale_window": 1000,
+            "initial_scale_power": 16,
+            "hysteresis": 2,
+            "min_loss_scale": 1,
+        },
+        "bf16": {"enabled": "auto"},
+    }
+    offload_config = {
+        "device": "cpu",
+        "pin_memory": True,
+    }
+    ds_config["zero_optimization"] = {
+        "stage": 2,
+        "allgather_partitions": True,
+        "allgather_bucket_size": 5e8,
+        "overlap_comm": True,
+        "reduce_scatter": True,
+        "reduce_bucket_size": 5e8,
+        "contiguous_gradients": True,
+        "round_robin_gradients": True,
+    }
+    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
+        json.dump(ds_config, f, indent=2)
+
+    ds_config["zero_optimization"]["offload_optimizer"] = offload_config
+    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
+        json.dump(ds_config, f, indent=2)
+
+    ds_config["zero_optimization"] = {
+        "stage": 3,
+        "overlap_comm": True,
+        "contiguous_gradients": True,
+        "sub_group_size": 1e9,
+        "reduce_bucket_size": "auto",
+        "stage3_prefetch_bucket_size": "auto",
+        "stage3_param_persistence_threshold": "auto",
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_gather_16bit_weights_on_model_save": True,
+    }
+    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
+        json.dump(ds_config, f, indent=2)
+
+    ds_config["zero_optimization"]["offload_optimizer"] = offload_config
+    ds_config["zero_optimization"]["offload_param"] = offload_config
+    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
+        json.dump(ds_config, f, indent=2)
View file @
317a82e2
# Copyright 202
4
the LlamaFactory team.
# Copyright 202
5
the LlamaFactory team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
src/llamafactory/webui/components/chatbot.py
View file @ 317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 from typing import TYPE_CHECKING, Dict, Tuple

 from ...data import Role
 from ...extras.packages import is_gradio_available
-from ..utils import check_json_schema
+from ..locales import ALERTS

 if is_gradio_available():
...
@@ -29,11 +30,29 @@ if TYPE_CHECKING:
     from ..engine import Engine


+def check_json_schema(text: str, lang: str) -> None:
+    r"""
+    Checks if the json schema is valid.
+    """
+    try:
+        tools = json.loads(text)
+        if tools:
+            assert isinstance(tools, list)
+            for tool in tools:
+                if "name" not in tool:
+                    raise NotImplementedError("Name not found.")
+    except NotImplementedError:
+        gr.Warning(ALERTS["err_tool_name"][lang])
+    except Exception:
+        gr.Warning(ALERTS["err_json_schema"][lang])
+
+
 def create_chat_box(
     engine: "Engine", visible: bool = False
 ) -> Tuple["Component", "Component", Dict[str, "Component"]]:
+    lang = engine.manager.get_elem_by_id("top.lang")
     with gr.Column(visible=visible) as chat_box:
-        chatbot = gr.Chatbot(show_copy_button=True)
+        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
...
@@ -45,29 +64,48 @@ def create_chat_box(
                 with gr.Column() as mm_box:
                     with gr.Tab("Image"):
-                        image = gr.Image(sources=["upload"], type="pil")
+                        image = gr.Image(type="pil")
                     with gr.Tab("Video"):
-                        video = gr.Video(sources=["upload"])
+                        video = gr.Video()
+                    with gr.Tab("Audio"):
+                        audio = gr.Audio(type="filepath")

                 query = gr.Textbox(show_label=False, lines=8)
                 submit_btn = gr.Button(variant="primary")

             with gr.Column(scale=1):
-                max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
+                max_new_tokens = gr.Slider(minimum=8, maximum=8192, value=1024, step=1)
                 top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
                 temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
+                skip_special_tokens = gr.Checkbox(value=True)
+                escape_html = gr.Checkbox(value=True)
                 clear_btn = gr.Button()

     tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])

     submit_btn.click(
         engine.chatter.append,
-        [chatbot, messages, role, query],
+        [chatbot, messages, role, query, escape_html],
         [chatbot, messages, query],
     ).then(
         engine.chatter.stream,
-        [chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature],
+        [
+            chatbot,
+            messages,
+            lang,
+            system,
+            tools,
+            image,
+            video,
+            audio,
+            max_new_tokens,
+            top_p,
+            temperature,
+            skip_special_tokens,
+            escape_html,
+        ],
         [chatbot, messages],
     )
     clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
...
@@ -83,11 +121,14 @@ def create_chat_box(
             mm_box=mm_box,
             image=image,
             video=video,
+            audio=audio,
             query=query,
             submit_btn=submit_btn,
             max_new_tokens=max_new_tokens,
             top_p=top_p,
             temperature=temperature,
+            skip_special_tokens=skip_special_tokens,
+            escape_html=escape_html,
             clear_btn=clear_btn,
         ),
     )
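
A usage note, not part of the commit: check_json_schema above only verifies that the tools textbox holds a JSON list whose items each carry a "name" key. A standalone sketch of the same check against an illustrative (hypothetical) tool definition:

import json

# Illustrative tool definition; fields other than "name" are not required by the check.
sample_tools = json.dumps(
    [
        {
            "name": "get_weather",  # hypothetical tool, for demonstration only
            "description": "Look up the current weather for a city.",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        }
    ]
)

tools = json.loads(sample_tools)
assert isinstance(tools, list) and all("name" in tool for tool in tools)
print("tool schema accepted")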
src/llamafactory/webui/components/data.py
View file @ 317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -40,6 +40,9 @@ def next_page(page_index: int, total_num: int) -> int:
 def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
+    r"""
+    Checks if the dataset is a local dataset.
+    """
     try:
         with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
             dataset_info = json.load(f)
...
@@ -67,6 +70,9 @@ def _load_data_file(file_path: str) -> List[Any]:
 def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
+    r"""
+    Gets the preview samples from the dataset.
+    """
     with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
         dataset_info = json.load(f)
...
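
For context, not part of the commit: DATA_CONFIG names the dataset registry inside the dataset directory (dataset_info.json in this project, assumed here), and can_preview/get_preview simply load it. A small sketch of that lookup over a hypothetical local data directory:

import json
import os

DATA_CONFIG = "dataset_info.json"  # value assumed from the project's constants
dataset_dir = "data"               # hypothetical local dataset directory

with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
    dataset_info = json.load(f)

# Each key is a dataset name; local datasets carry a "file_name" entry that can be previewed.
for name, meta in dataset_info.items():
    if "file_name" in meta:
        print(name, "->", meta["file_name"])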
src/llamafactory/webui/components/eval.py
View file @ 317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -15,7 +15,8 @@
 from typing import TYPE_CHECKING, Dict

 from ...extras.packages import is_gradio_available
-from ..common import DEFAULT_DATA_DIR, list_datasets
+from ..common import DEFAULT_DATA_DIR
+from ..control import list_datasets
 from .data import create_preview_box
...
src/llamafactory/webui/components/export.py
View file @ 317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -18,7 +18,7 @@ from ...extras.constants import PEFT_METHODS
 from ...extras.misc import torch_gc
 from ...extras.packages import is_gradio_available
 from ...train.tuner import export_model
-from ..common import GPTQ_BITS, get_save_dir
+from ..common import get_save_dir, load_config
 from ..locales import ALERTS
...
@@ -32,6 +32,9 @@ if TYPE_CHECKING:
     from ..engine import Engine


+GPTQ_BITS = ["8", "4", "3", "2"]
+
+
 def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
     if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
         return gr.Dropdown(value="none", interactive=False)
...
@@ -54,6 +57,7 @@ def save_model(
     export_dir: str,
     export_hub_model_id: str,
 ) -> Generator[str, None, None]:
+    user_config = load_config()
     error = ""
     if not model_name:
         error = ALERTS["err_no_model"][lang]
...
@@ -75,6 +79,7 @@ def save_model(
     args = dict(
         model_name_or_path=model_path,
+        cache_dir=user_config.get("cache_dir", None),
         finetuning_type=finetuning_type,
         template=template,
         export_dir=export_dir,
...
src/llamafactory/webui/components/infer.py
View file @ 317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -15,7 +15,7 @@
 from typing import TYPE_CHECKING, Dict

 from ...extras.packages import is_gradio_available
-from ..common import get_visual
+from ..common import is_multimodal
 from .chatbot import create_chat_box
...
@@ -66,7 +66,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])

     engine.manager.get_elem_by_id("top.model_name").change(
-        lambda model_name: gr.Column(visible=get_visual(model_name)),
+        lambda model_name: gr.Column(visible=is_multimodal(model_name)),
         [engine.manager.get_elem_by_id("top.model_name")],
         [chat_elems["mm_box"]],
     )
...
src/llamafactory/webui/components/top.py
View file @ 317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
 from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
-from ..common import get_model_info, list_checkpoints, save_config
-from ..utils import can_quantize, can_quantize_to
+from ..common import save_config
+from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints

 if is_gradio_available():
...
@@ -30,11 +30,10 @@ if TYPE_CHECKING:
 def create_top() -> Dict[str, "Component"]:
-    available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-
     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
-        model_name = gr.Dropdown(choices=available_models, scale=3)
+        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
+        available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
+        model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
         model_path = gr.Textbox(scale=3)

     with gr.Row():
...
@@ -42,11 +41,11 @@ def create_top() -> Dict[str, "Component"]:
         checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)

     with gr.Row():
-        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
-        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
-        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
-        rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
-        booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
+        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True)
+        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes")
+        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default")
+        rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic", "yarn", "llama3"], value="none")
+        booster = gr.Dropdown(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto")

     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
...
src/llamafactory/webui/components/train.py
View file @ 317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -19,8 +19,8 @@ from transformers.trainer_utils import SchedulerType
 from ...extras.constants import TRAINING_STAGES
 from ...extras.misc import get_device_count
 from ...extras.packages import is_gradio_available
-from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
-from ..utils import change_stage, list_config_paths, list_output_dirs
+from ..common import DEFAULT_DATA_DIR
+from ..control import change_stage, list_checkpoints, list_config_paths, list_datasets, list_output_dirs
 from .data import create_preview_box
...
@@ -39,9 +39,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict = dict()

     with gr.Row():
-        training_stage = gr.Dropdown(
-            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
-        )
+        stages = list(TRAINING_STAGES.keys())
+        training_stage = gr.Dropdown(choices=stages, value=stages[0], scale=1)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
         dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)
...
@@ -107,8 +106,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             use_llama_pro = gr.Checkbox()

         with gr.Column():
-            shift_attn = gr.Checkbox()
-            report_to = gr.Checkbox()
+            report_to = gr.Dropdown(
+                choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
+                value=["none"],
+                allow_custom_value=True,
+                multiselect=True,
+            )

     input_elems.update(
         {
...
@@ -123,7 +126,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             mask_history,
             resize_vocab,
             use_llama_pro,
-            shift_attn,
             report_to,
         }
     )
...
@@ -141,7 +143,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             mask_history=mask_history,
             resize_vocab=resize_vocab,
             use_llama_pro=use_llama_pro,
-            shift_attn=shift_attn,
             report_to=report_to,
         )
     )
...
@@ -298,9 +299,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             swanlab_workspace = gr.Textbox()
             swanlab_api_key = gr.Textbox()
             swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
+            swanlab_link = gr.Markdown(visible=False)

     input_elems.update(
-        {use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
+        {
+            use_swanlab,
+            swanlab_project,
+            swanlab_run_name,
+            swanlab_workspace,
+            swanlab_api_key,
+            swanlab_mode,
+            swanlab_link,
+        }
     )
     elem_dict.update(
         dict(
...
@@ -311,6 +321,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             swanlab_workspace=swanlab_workspace,
             swanlab_api_key=swanlab_api_key,
             swanlab_mode=swanlab_mode,
+            swanlab_link=swanlab_link,
         )
     )
...
@@ -363,7 +374,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             loss_viewer=loss_viewer,
         )
     )

-    output_elems = [output_box, progress_bar, loss_viewer]
+    output_elems = [output_box, progress_bar, loss_viewer, swanlab_link]

     cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
     start_btn.click(engine.runner.run_train, input_elems, output_elems)
...
src/llamafactory/webui/utils.py → src/llamafactory/webui/control.py
View file @ 317a82e2
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -14,19 +14,23 @@
 import json
 import os
-import signal
-from datetime import datetime
 from typing import Any, Dict, List, Optional, Tuple

-import psutil
 from transformers.trainer_utils import get_last_checkpoint
-from yaml import safe_dump, safe_load

-from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
+from ..extras.constants import (
+    CHECKPOINT_NAMES,
+    PEFT_METHODS,
+    RUNNING_LOG,
+    STAGES_USE_PAIR_DATA,
+    SWANLAB_CONFIG,
+    TRAINER_LOG,
+    TRAINING_STAGES,
+)
 from ..extras.packages import is_gradio_available, is_matplotlib_available
 from ..extras.ploting import gen_loss_plot
 from ..model import QuantizationMethod
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
+from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
 from .locales import ALERTS
...
@@ -34,24 +38,12 @@ if is_gradio_available():
     import gradio as gr


-def abort_process(pid: int) -> None:
-    r"""
-    Aborts the processes recursively in a bottom-up way.
-    """
-    try:
-        children = psutil.Process(pid).children()
-        if children:
-            for child in children:
-                abort_process(child.pid)
-
-        os.kill(pid, signal.SIGABRT)
-    except Exception:
-        pass
-
-
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
     r"""
     Judges if the quantization is available in this finetuning type.
+
+    Inputs: top.finetuning_type
+    Outputs: top.quantization_bit
     """
     if finetuning_type not in PEFT_METHODS:
         return gr.Dropdown(value="none", interactive=False)
...
@@ -61,7 +53,10 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":
 def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
     r"""
-    Returns the available quantization bits.
+    Gets the available quantization bits.
+
+    Inputs: top.quantization_method
+    Outputs: top.quantization_bit
     """
     if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
         available_bits = ["none", "8", "4"]
...
@@ -76,93 +71,42 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
 def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
     r"""
     Modifys states after changing the training stage.
+
+    Inputs: train.training_stage
+    Outputs: train.dataset, train.packing
     """
     return [], TRAINING_STAGES[training_stage] == "pt"


-def check_json_schema(text: str, lang: str) -> None:
-    r"""
-    Checks if the json schema is valid.
-    """
-    try:
-        tools = json.loads(text)
-        if tools:
-            assert isinstance(tools, list)
-            for tool in tools:
-                if "name" not in tool:
-                    raise NotImplementedError("Name not found.")
-    except NotImplementedError:
-        gr.Warning(ALERTS["err_tool_name"][lang])
-    except Exception:
-        gr.Warning(ALERTS["err_json_schema"][lang])
-
-
-def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
-    r"""
-    Removes args with NoneType or False or empty string value.
-    """
-    no_skip_keys = ["packing"]
-    return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
-
-
-def gen_cmd(args: Dict[str, Any]) -> str:
-    r"""
-    Generates arguments for previewing.
-    """
-    cmd_lines = ["llamafactory-cli train "]
-    for k, v in clean_cmd(args).items():
-        cmd_lines.append(f"    --{k} {str(v)} ")
-
-    if os.name == "nt":
-        cmd_text = "`\n".join(cmd_lines)
-    else:
-        cmd_text = "\\\n".join(cmd_lines)
-
-    cmd_text = f"```bash\n{cmd_text}\n```"
-    return cmd_text
-
-
-def save_cmd(args: Dict[str, Any]) -> str:
-    r"""
-    Saves arguments to launch training.
-    """
-    output_dir = args["output_dir"]
-    os.makedirs(output_dir, exist_ok=True)
-
-    with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
-        safe_dump(clean_cmd(args), f)
-
-    return os.path.join(output_dir, TRAINING_ARGS)
-
-
-def get_eval_results(path: os.PathLike) -> str:
-    r"""
-    Gets scores after evaluation.
-    """
-    with open(path, encoding="utf-8") as f:
-        result = json.dumps(json.load(f), indent=4)
-
-    return f"```json\n{result}\n```\n"
-
-
-def get_time() -> str:
-    r"""
-    Gets current date and time.
-    """
-    return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
+def get_model_info(model_name: str) -> Tuple[str, str]:
+    r"""
+    Gets the necessary information of this model.
+
+    Inputs: top.model_name
+    Outputs: top.model_path, top.template
+    """
+    return get_model_path(model_name), get_template(model_name)
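
To make the relocated gen_cmd/clean_cmd pair concrete (a sketch, not part of the commit; the functions mirror the removed code above and the argument values are illustrative only):

import os
from typing import Any, Dict


def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
    # Same filtering rule as above: always keep "packing", drop None / False / empty strings.
    no_skip_keys = ["packing"]
    return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}


def gen_cmd(args: Dict[str, Any]) -> str:
    # Builds the markdown preview of the equivalent CLI command, as in the function above.
    cmd_lines = ["llamafactory-cli train "]
    for k, v in clean_cmd(args).items():
        cmd_lines.append(f"    --{k} {str(v)} ")
    joiner = "`\n" if os.name == "nt" else "\\\n"
    return f"```bash\n{joiner.join(cmd_lines)}\n```"


# Illustrative arguments; real values come from the web UI form.
print(gen_cmd({"stage": "sft", "model_name_or_path": "Qwen/QwQ-32B", "packing": False, "output_dir": "saves/test"}))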
-def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
+def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
     r"""
     Gets training infomation for monitor.
+
+    If do_train is True:
+        Inputs: top.lang, train.output_path
+        Outputs: train.output_box, train.progress_bar, train.loss_viewer, train.swanlab_link
+    If do_train is False:
+        Inputs: top.lang, eval.output_path
+        Outputs: eval.output_box, eval.progress_bar, None, None
     """
     running_log = ""
     running_progress = gr.Slider(visible=False)
-    running_loss = None
+    running_info = {}

     running_log_path = os.path.join(output_path, RUNNING_LOG)
     if os.path.isfile(running_log_path):
         with open(running_log_path, encoding="utf-8") as f:
-            running_log = f.read()
+            running_log = f.read()[-20000:]  # avoid lengthy log

     trainer_log_path = os.path.join(output_path, TRAINER_LOG)
     if os.path.isfile(trainer_log_path):
...
@@ -183,33 +127,50 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
             running_progress = gr.Slider(label=label, value=percentage, visible=True)

             if do_train and is_matplotlib_available():
-                running_loss = gr.Plot(gen_loss_plot(trainer_log))
+                running_info["loss_viewer"] = gr.Plot(gen_loss_plot(trainer_log))

-    return running_log, running_progress, running_loss
+    swanlab_config_path = os.path.join(output_path, SWANLAB_CONFIG)
+    if os.path.isfile(swanlab_config_path):
+        with open(swanlab_config_path, encoding="utf-8") as f:
+            swanlab_public_config = json.load(f)
+            swanlab_link = swanlab_public_config["cloud"]["experiment_url"]
+            if swanlab_link is not None:
+                running_info["swanlab_link"] = gr.Markdown(ALERTS["info_swanlab_link"][lang] + swanlab_link, visible=True)
+
+    return running_log, running_progress, running_info
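
A sketch (not part of the commit) of the SwanLab lookup that the new swanlab_link branch performs: it reads the run's public config JSON from the output directory and surfaces the experiment URL. The file name and run directory below are assumptions for illustration:

import json
import os

output_path = "saves/Qwen2.5-7B/lora/train_2025-03-07"  # hypothetical run directory
SWANLAB_CONFIG = "swanlab_public_config.json"            # assumed value of the constant

swanlab_config_path = os.path.join(output_path, SWANLAB_CONFIG)
if os.path.isfile(swanlab_config_path):
    with open(swanlab_config_path, encoding="utf-8") as f:
        swanlab_public_config = json.load(f)
    experiment_url = swanlab_public_config["cloud"]["experiment_url"]
    if experiment_url is not None:
        print("SwanLab experiment:", experiment_url)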

-def load_args(config_path: str) -> Optional[Dict[str, Any]]:
-    r"""
-    Loads saved arguments.
-    """
-    try:
-        with open(config_path, encoding="utf-8") as f:
-            return safe_load(f)
-    except Exception:
-        return None
-
-
-def save_args(config_path: str, config_dict: Dict[str, Any]):
-    r"""
-    Saves arguments.
-    """
-    with open(config_path, "w", encoding="utf-8") as f:
-        safe_dump(config_dict, f)
+def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
+    r"""
+    Lists all available checkpoints.
+
+    Inputs: top.model_name, top.finetuning_type
+    Outputs: top.checkpoint_path
+    """
+    checkpoints = []
+    if model_name:
+        save_dir = get_save_dir(model_name, finetuning_type)
+        if save_dir and os.path.isdir(save_dir):
+            for checkpoint in os.listdir(save_dir):
+                if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
+                    os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
+                ):
+                    checkpoints.append(checkpoint)
+
+    if finetuning_type in PEFT_METHODS:
+        return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
+    else:
+        return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)


 def list_config_paths(current_time: str) -> "gr.Dropdown":
     r"""
     Lists all the saved configuration files.
+
+    Inputs: train.current_time
+    Outputs: train.config_path
     """
     config_files = [f"{current_time}.yaml"]
     if os.path.isdir(DEFAULT_CONFIG_DIR):
...
@@ -220,9 +181,25 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
     return gr.Dropdown(choices=config_files)


+def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
+    r"""
+    Lists all available datasets in the dataset dir for the training stage.
+
+    Inputs: *.dataset_dir, *.training_stage
+    Outputs: *.dataset
+    """
+    dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
+    ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
+    datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
+    return gr.Dropdown(choices=datasets)
+
+
 def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
     r"""
     Lists all the directories that can resume from.
+
+    Inputs: top.model_name, top.finetuning_type, train.current_time
+    Outputs: train.output_dir
     """
     output_dirs = [f"train_{current_time}"]
     if model_name:
...
@@ -234,66 +211,3 @@ def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_ti
             output_dirs.append(folder)

     return gr.Dropdown(choices=output_dirs)
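
A usage note, not part of the commit: list_datasets filters the registry by the "ranking" flag, so pairwise stages (rm, dpo) see only ranking datasets and the other stages see the rest. A minimal sketch of the same filter over an illustrative registry dict:

# Illustrative registry in the shape of dataset_info.json entries.
dataset_info = {
    "alpaca_en_demo": {"file_name": "alpaca_en_demo.json"},
    "dpo_en_demo": {"file_name": "dpo_en_demo.json", "ranking": True},
}


def select(ranking: bool):
    # Same comprehension as in list_datasets above.
    return [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]


print(select(False))  # datasets offered to sft/pt-style stages
print(select(True))   # datasets offered to pairwise stages such as rm or dpo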

-def create_ds_config() -> None:
-    r"""
-    Creates deepspeed config.
-    """
-    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
-    ds_config = {
-        "train_batch_size": "auto",
-        "train_micro_batch_size_per_gpu": "auto",
-        "gradient_accumulation_steps": "auto",
-        "gradient_clipping": "auto",
-        "zero_allow_untested_optimizer": True,
-        "fp16": {
-            "enabled": "auto",
-            "loss_scale": 0,
-            "loss_scale_window": 1000,
-            "initial_scale_power": 16,
-            "hysteresis": 2,
-            "min_loss_scale": 1,
-        },
-        "bf16": {"enabled": "auto"},
-    }
-    offload_config = {
-        "device": "cpu",
-        "pin_memory": True,
-    }
-    ds_config["zero_optimization"] = {
-        "stage": 2,
-        "allgather_partitions": True,
-        "allgather_bucket_size": 5e8,
-        "overlap_comm": True,
-        "reduce_scatter": True,
-        "reduce_bucket_size": 5e8,
-        "contiguous_gradients": True,
-        "round_robin_gradients": True,
-    }
-    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
-        json.dump(ds_config, f, indent=2)
-
-    ds_config["zero_optimization"]["offload_optimizer"] = offload_config
-    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
-        json.dump(ds_config, f, indent=2)
-
-    ds_config["zero_optimization"] = {
-        "stage": 3,
-        "overlap_comm": True,
-        "contiguous_gradients": True,
-        "sub_group_size": 1e9,
-        "reduce_bucket_size": "auto",
-        "stage3_prefetch_bucket_size": "auto",
-        "stage3_param_persistence_threshold": "auto",
-        "stage3_max_live_parameters": 1e9,
-        "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_16bit_weights_on_model_save": True,
-    }
-    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
-        json.dump(ds_config, f, indent=2)
-
-    ds_config["zero_optimization"]["offload_optimizer"] = offload_config
-    ds_config["zero_optimization"]["offload_param"] = offload_config
-    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
-        json.dump(ds_config, f, indent=2)