OpenDAS / LLaMA-Factory · Commits

Commit 0722acf1 ("Update 0604")
Authored Jun 04, 2025 by chenych
Parent: c4ba4563
Changes: 68

Showing 20 changed files with 228 additions and 107 deletions (+228 -107)
Files changed on this page:

  src/llamafactory/model/model_utils/moe.py             +25  -23
  src/llamafactory/model/model_utils/quantization.py     +6   -5
  src/llamafactory/model/model_utils/rope.py            +29  -19
  src/llamafactory/model/model_utils/visual.py            +3   -2
  src/llamafactory/model/patcher.py                       +3   -3
  src/llamafactory/train/dpo/trainer.py                   +2   -2
  src/llamafactory/train/kto/trainer.py                   +2   -3
  src/llamafactory/train/pt/trainer.py                    +2   -2
  src/llamafactory/train/pt/workflow.py                  +16   -5
  src/llamafactory/train/rm/trainer.py                    +2   -2
  src/llamafactory/train/sft/trainer.py                   +2   -2
  src/llamafactory/train/trainer_utils.py                 +1   -0
  src/llamafactory/webui/chatter.py                      +48  -32
  src/llamafactory/webui/common.py                       +16   -1
  src/llamafactory/webui/components/chatbot.py            +3   -0
  src/llamafactory/webui/components/export.py            +11   -0
  src/llamafactory/webui/components/infer.py              +3   -1
  src/llamafactory/webui/components/top.py                +2   -2
  src/llamafactory/webui/components/train.py             +41   -3
  src/llamafactory/webui/control.py                      +11   -0
src/llamafactory/model/model_utils/moe.py  (+25 -23)

@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.moe_aux_loss_coef:
+        return
+
     model_type = getattr(config, "model_type", None)
-    if model_args.moe_aux_loss_coef is not None:
     if model_type in [
         "dbrx",
         "granitemoe",
         "jamba",
         "jetmoe",
         "llama4",
         "mixtral",
         "olmoe",
         "phimoe",
         "qwen2_moe",
         "qwen3_moe",
     ]:
-        setattr(config, "output_router_logits", True)
+        setattr(config, "output_router_logits", is_trainable)

     if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
     elif model_type == "deepseek":
         setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
     elif model_type == "jetmoe":
         setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
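
Note: the early return makes configure_moe a no-op unless the model is being trained with a non-zero moe_aux_loss_coef. A minimal sketch of the resulting behaviour (illustrative values; SimpleNamespace stands in for the real PretrainedConfig/ModelArguments objects, this is not part of the commit):

    from types import SimpleNamespace

    # stand-ins for the real config / argument objects
    config = SimpleNamespace(model_type="qwen2_moe")
    model_args = SimpleNamespace(moe_aux_loss_coef=0.001)

    configure_moe(config, model_args, is_trainable=True)
    # config.output_router_logits is now set and config.router_aux_loss_coef == 0.001

    infer_config = SimpleNamespace(model_type="qwen2_moe")
    configure_moe(infer_config, model_args, is_trainable=False)
    # early return: infer_config is left untouched at inference time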
src/llamafactory/model/model_utils/quantization.py  (+6 -5)

@@ -97,7 +97,7 @@ def configure_quantization(
     quant_method = quantization_config.get("quant_method", "")

     if quant_method == QuantizationMethod.GPTQ:
-        check_version("auto_gptq>=0.5.0", mandatory=True)
+        check_version("gptqmodel>=2.0.0", mandatory=True)
         quantization_config.pop("disable_exllama", None)  # remove deprecated args
         quantization_config["use_exllama"] = False  # disable exllama

@@ -111,12 +111,12 @@ def configure_quantization(
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")

-    elif model_args.export_quantization_bit is not None:  # auto-gptq
+    elif model_args.export_quantization_bit is not None:  # gptqmodel
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

-        check_version("optimum>=1.17.0", mandatory=True)
-        check_version("auto_gptq>=0.5.0", mandatory=True)
+        check_version("optimum>=1.24.0", mandatory=True)
+        check_version("gptqmodel>=2.0.0", mandatory=True)
         from accelerate.utils import get_max_memory

         if getattr(config, "model_type", None) == "chatglm":

@@ -142,7 +142,8 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+        model_args.compute_dtype = torch.float16  # force fp16 for gptqmodel
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")

     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BNB:
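
Note: the post-training GPTQ export path now depends on gptqmodel (>=2.0.0) and a newer optimum (>=1.24.0) instead of auto_gptq, and forces fp16 compute while quantizing. A sketch of an export configuration that would exercise this path (key names follow LLaMA-Factory's export arguments as referenced in the diff; model name, paths and values are placeholders, not from this commit):

    # illustrative arguments for `llamafactory-cli export`
    export_args = {
        "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
        "template": "qwen",
        "export_dir": "output/qwen2.5-7b-gptq-int4",
        "export_quantization_bit": 4,                      # must be one of 2 / 3 / 4 / 8
        "export_quantization_dataset": "data/c4_demo.json",  # calibration data, placeholder path
    }
    # With this change the export requires optimum>=1.24.0 and gptqmodel>=2.0.0,
    # and compute_dtype is forced to torch.float16 during quantization.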
src/llamafactory/model/model_utils/rope.py  (+29 -19)

@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if model_args.rope_scaling is None:
         return

@@ -40,30 +40,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return

-    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
-    if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
-            logger.warning_rank0(
-                "Dynamic NTK scaling may not work well with fine-tuning. "
-                "See: https://github.com/huggingface/transformers/pull/24653"
-            )
-
-        current_max_length = getattr(config, "max_position_embeddings", None)
-        if (not current_max_length) or model_args.model_max_length <= current_max_length:
-            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
-            return
-
-        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-        setattr(config, "max_position_embeddings", model_args.model_max_length)
-        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-
-        if model_args.rope_scaling == RopeScaling.DYNAMIC:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == RopeScaling.LLAMA3:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-            rope_kwargs["low_freq_factor"] = 1.0
-            rope_kwargs["high_freq_factor"] = 4.0
-    else:
-        rope_kwargs["factor"] = 2.0
+    if hasattr(config, "max_position_embeddings"):
+        old_max_length = getattr(config, "max_position_embeddings", None)
+    else:
+        logger.warning_rank0("Cannot find the max position embeddings in the config.")
+        return
+
+    if model_args.model_max_length is not None:  # training
+        if model_args.model_max_length <= old_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
+
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
+            logger.warning_rank0(
+                "Dynamic NTK scaling may not work well with fine-tuning. "
+                "See: https://github.com/huggingface/transformers/pull/24653"
+            )
+
+        rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
+    else:  # inference
+        rope_factor = 2.0
+
+    rope_kwargs = {
+        "rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling),  # handle enum
+        "factor": rope_factor,
+    }
+    setattr(config, "max_position_embeddings", old_max_length * rope_factor)
+    logger.info_rank0(f"Enlarge max model length from {old_max_length} to {old_max_length * rope_factor}.")
+
+    if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+    elif model_args.rope_scaling == RopeScaling.LLAMA3:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+        rope_kwargs["low_freq_factor"] = 1.0
+        rope_kwargs["high_freq_factor"] = 4.0

     setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
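
Note: the refactored configure_rope derives one rope_factor from the requested cutoff length and the checkpoint's original context window, then enlarges max_position_embeddings by that factor instead of overwriting it with model_max_length. A worked example with illustrative numbers (not part of the commit):

    import math

    old_max_length = 32768        # what the checkpoint's config reports
    model_max_length = 131072     # requested training cutoff length

    rope_factor = float(math.ceil(model_max_length / old_max_length))   # 4.0
    new_max_position_embeddings = old_max_length * rope_factor          # 131072.0

    # For dynamic/yarn scaling the resulting config entry would then look like
    # {"rope_type": "dynamic", "factor": 4.0, "original_max_position_embeddings": 32768}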
src/llamafactory/model/model_utils/visual.py  (+3 -2)

@@ -24,6 +24,7 @@ import transformers.models
 from transformers.activations import ACT2FN

 from ...extras import logging
+from ...extras.packages import is_transformers_version_greater_than


 if TYPE_CHECKING:

@@ -281,7 +282,7 @@ _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )

@@ -290,6 +291,6 @@ _register_composite_model(
     model_type="qwen2_5_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
src/llamafactory/model/patcher.py  (+3 -3)

@@ -85,8 +85,8 @@ def patch_processor(
     setattr(processor, "video_min_pixels", model_args.video_min_pixels)
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
     setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
+    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)


 def patch_config(

@@ -102,8 +102,8 @@ def patch_config(
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

-    configure_attn_implementation(config, model_args, is_trainable)
-    configure_rope(config, model_args, is_trainable)
+    configure_attn_implementation(config, model_args)
+    configure_rope(config, model_args)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
src/llamafactory/train/dpo/trainer.py  (+2 -2)

@@ -121,11 +121,11 @@ class CustomDPOTrainer(DPOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def get_batch_samples(self, *args, **kwargs):
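
Note: accepting *args/**kwargs in _get_train_sampler and forwarding them keeps the override compatible with newer transformers releases, which have started passing extra arguments (such as the dataset) into this hook; the same change repeats in the KTO, PT, RM and SFT trainers below. A condensed sketch of the pattern (not from the commit; it assumes a finetuning_args attribute is attached to the trainer, as LLaMA-Factory does):

    from typing import Optional

    import torch.utils.data
    from transformers import Trainer


    class PatternDemoTrainer(Trainer):
        def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
            if self.finetuning_args.disable_shuffling:
                return torch.utils.data.SequentialSampler(self.train_dataset)

            # forward whatever the installed transformers version passes in
            return super()._get_train_sampler(*args, **kwargs)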
src/llamafactory/train/kto/trainer.py  (+2 -3)

@@ -34,7 +34,6 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, ge
 if TYPE_CHECKING:
-    import torch.utils.data
     from transformers import PreTrainedModel, ProcessorMixin

     from ...hparams import FinetuningArguments

@@ -119,12 +118,12 @@ class CustomKTOTrainer(KTOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return Trainer._get_train_sampler(self)
+        return Trainer._get_train_sampler(self, *args, **kwargs)

     @override
     def get_batch_samples(self, *args, **kwargs):
src/llamafactory/train/pt/trainer.py  (+2 -2)

@@ -70,11 +70,11 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
src/llamafactory/train/pt/workflow.py  (+16 -5)

@@ -77,12 +77,23 @@ def run_pt(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
-        try:
-            perplexity = math.exp(metrics["eval_loss"])
-        except OverflowError:
-            perplexity = float("inf")
-
-        metrics["perplexity"] = perplexity
+        if isinstance(dataset_module.get("eval_dataset"), dict):
+            for key in dataset_module["eval_dataset"].keys():
+                try:
+                    perplexity = math.exp(metrics[f"eval_{key}_loss"])
+                except OverflowError:
+                    perplexity = float("inf")
+
+                metrics[f"eval_{key}_perplexity"] = perplexity
+        else:
+            try:
+                perplexity = math.exp(metrics["eval_loss"])
+            except OverflowError:
+                perplexity = float("inf")
+
+            metrics["eval_perplexity"] = perplexity

         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
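
Note: perplexity is the exponential of the evaluation cross-entropy loss, and it is now reported per split when eval_dataset is a dict of named datasets. A quick worked example with made-up losses (not part of the commit):

    import math

    # single eval set
    eval_loss = 2.0
    print(math.exp(eval_loss))            # ~7.39 -> metrics["eval_perplexity"]

    # multiple named eval sets, e.g. {"wiki": ..., "code": ...}
    losses = {"wiki": 1.8, "code": 2.4}
    for key, loss in losses.items():
        print(f"eval_{key}_perplexity", math.exp(loss))   # ~6.05 and ~11.02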
src/llamafactory/train/rm/trainer.py  (+2 -2)

@@ -78,11 +78,11 @@ class PairwiseTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(
src/llamafactory/train/sft/trainer.py  (+2 -2)

@@ -92,11 +92,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
src/llamafactory/train/trainer_utils.py  (+1 -0)

@@ -665,6 +665,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
             mode=finetuning_args.swanlab_mode,
             config={"Framework": "🦙LlamaFactory"},
             logdir=finetuning_args.swanlab_logdir,
+            tags=["🦙LlamaFactory"],
         )

         return swanlab_callback
src/llamafactory/webui/chatter.py  (+48 -32)

@@ -15,6 +15,7 @@
 import json
 import os
 from collections.abc import Generator
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional

 from transformers.utils import is_torch_npu_available

@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tup
     )


+@contextmanager
+def update_attr(obj: Any, name: str, value: Any):
+    old_value = getattr(obj, name, None)
+    setattr(obj, name, value)
+    yield
+    setattr(obj, name, old_value)
+
+
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager

@@ -105,6 +114,11 @@ class WebChatModel(ChatModel):
         elif self.demo_mode:
             error = ALERTS["err_demo"][lang]

+        try:
+            json.loads(get("infer.extra_args"))
+        except json.JSONDecodeError:
+            error = ALERTS["err_json_schema"][lang]
+
         if error:
             gr.Warning(error)
             yield error

@@ -122,9 +136,9 @@ class WebChatModel(ChatModel):
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
-            vllm_enforce_eager=True,
             trust_remote_code=True,
         )
+        args.update(json.loads(get("infer.extra_args")))

         # checkpoints
         if checkpoint_path:

@@ -191,40 +205,42 @@ class WebChatModel(ChatModel):
         temperature: float,
         skip_special_tokens: bool,
         escape_html: bool,
+        enable_thinking: bool,
     ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
         r"""Generate output text in stream.

         Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
         Output: infer.chatbot, infer.messages
         """
+        with update_attr(self.engine.template, "enable_thinking", enable_thinking):
             chatbot.append({"role": "assistant", "content": ""})
             response = ""
             for new_text in self.stream_chat(
                 messages,
                 system,
                 tools,
                 images=[image] if image else None,
                 videos=[video] if video else None,
                 audios=[audio] if audio else None,
                 max_new_tokens=max_new_tokens,
                 top_p=top_p,
                 temperature=temperature,
                 skip_special_tokens=skip_special_tokens,
             ):
                 response += new_text
                 if tools:
                     result = self.engine.template.extract_tool(response)
                 else:
                     result = response

                 if isinstance(result, list):
                     tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
                     tool_calls = json.dumps(tool_calls, ensure_ascii=False)
                     output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
                     bot_text = "```json\n" + tool_calls + "\n```"
                 else:
                     output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
                     bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)

                 chatbot[-1] = {"role": "assistant", "content": bot_text}
                 yield chatbot, output_messages
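
Note: update_attr temporarily overrides an attribute for the duration of one streaming call and then restores the previous value, so the per-request thinking toggle cannot leak into later requests. A standalone sketch using the same helper (the SimpleNamespace object is only a stand-in for the chat template):

    from contextlib import contextmanager
    from types import SimpleNamespace
    from typing import Any


    @contextmanager
    def update_attr(obj: Any, name: str, value: Any):
        old_value = getattr(obj, name, None)
        setattr(obj, name, value)
        yield
        setattr(obj, name, old_value)


    template = SimpleNamespace(enable_thinking=True)
    with update_attr(template, "enable_thinking", False):
        print(template.enable_thinking)  # False inside the block
    print(template.enable_thinking)      # True again afterwards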
src/llamafactory/webui/common.py  (+16 -1)

@@ -163,7 +163,14 @@ def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
 def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
     r"""Remove args with NoneType or False or empty string value."""
-    no_skip_keys = ["packing"]
+    no_skip_keys = [
+        "packing",
+        "enable_thinking",
+        "use_reentrant_gc",
+        "double_quantization",
+        "freeze_vision_tower",
+        "freeze_multi_modal_projector",
+    ]
     return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}

@@ -205,6 +212,14 @@ def load_eval_results(path: os.PathLike) -> str:
     return f"```json\n{result}\n```\n"


+def calculate_pixels(pixels: str) -> int:
+    r"""Calculate the number of pixels from the expression."""
+    if "*" in pixels:
+        return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
+    else:
+        return int(pixels)
+
+
 def create_ds_config() -> None:
     r"""Create deepspeed config in the current directory."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
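
Note: calculate_pixels lets the new pixel textboxes accept either a product expression or a plain integer. Quick examples (not part of the commit):

    calculate_pixels("768*768")   # 589824
    calculate_pixels("32*32")     # 1024
    calculate_pixels("262144")    # 262144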
src/llamafactory/webui/components/chatbot.py  (+3 -0)

@@ -79,6 +79,7 @@ def create_chat_box(
                 temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
                 skip_special_tokens = gr.Checkbox(value=True)
                 escape_html = gr.Checkbox(value=True)
+                enable_thinking = gr.Checkbox(value=True)
                 clear_btn = gr.Button()

     tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])

@@ -103,6 +104,7 @@ def create_chat_box(
             temperature,
             skip_special_tokens,
             escape_html,
+            enable_thinking,
         ],
         [chatbot, messages],
     )

@@ -127,6 +129,7 @@ def create_chat_box(
             temperature=temperature,
             skip_special_tokens=skip_special_tokens,
             escape_html=escape_html,
+            enable_thinking=enable_thinking,
             clear_btn=clear_btn,
         ),
     )
src/llamafactory/webui/components/export.py  (+11 -0)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 from collections.abc import Generator
 from typing import TYPE_CHECKING, Union

@@ -57,6 +58,7 @@ def save_model(
     export_legacy_format: bool,
     export_dir: str,
     export_hub_model_id: str,
+    extra_args: str,
 ) -> Generator[str, None, None]:
     user_config = load_config()
     error = ""

@@ -73,6 +75,11 @@ def save_model(
     elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list):
         error = ALERTS["err_gptq_lora"][lang]

+    try:
+        json.loads(extra_args)
+    except json.JSONDecodeError:
+        error = ALERTS["err_json_schema"][lang]
+
     if error:
         gr.Warning(error)
         yield error

@@ -92,6 +99,7 @@ def save_model(
         export_legacy_format=export_legacy_format,
         trust_remote_code=True,
     )
+    args.update(json.loads(extra_args))

     if checkpoint_path:
         if finetuning_type in PEFT_METHODS:  # list

@@ -118,6 +126,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
     with gr.Row():
         export_dir = gr.Textbox()
         export_hub_model_id = gr.Textbox()
+        extra_args = gr.Textbox(value="{}")

     checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
     checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)

@@ -141,6 +150,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
             export_legacy_format,
             export_dir,
             export_hub_model_id,
+            extra_args,
         ],
         [info_box],
     )

@@ -153,6 +163,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
         export_legacy_format=export_legacy_format,
         export_dir=export_dir,
         export_hub_model_id=export_hub_model_id,
+        extra_args=extra_args,
         export_btn=export_btn,
         info_box=info_box,
     )
src/llamafactory/webui/components/infer.py  (+3 -1)

@@ -36,6 +36,7 @@ def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
     with gr.Row():
         infer_backend = gr.Dropdown(choices=["huggingface", "vllm", "sglang"], value="huggingface")
         infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
+        extra_args = gr.Textbox(value='{"vllm_enforce_eager": true}')

     with gr.Row():
         load_btn = gr.Button()

@@ -43,11 +44,12 @@ def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
     info_box = gr.Textbox(show_label=False, interactive=False)

-    input_elems.update({infer_backend, infer_dtype})
+    input_elems.update({infer_backend, infer_dtype, extra_args})
     elem_dict.update(
         dict(
             infer_backend=infer_backend,
             infer_dtype=infer_dtype,
+            extra_args=extra_args,
             load_btn=load_btn,
             unload_btn=unload_btn,
             info_box=info_box,
src/llamafactory/webui/components/top.py  (+2 -2)

@@ -18,7 +18,7 @@ from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
 from ..common import save_config
-from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints
+from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints


 if is_gradio_available():

@@ -49,7 +49,7 @@ def create_top() -> dict[str, "Component"]:
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
-    )
+    ).then(check_template, [lang, template])
     model_name.input(save_config, inputs=[lang, model_name], queue=False)
     model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
src/llamafactory/webui/components/train.py  (+41 -3)

@@ -106,11 +106,11 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
                     use_llama_pro = gr.Checkbox()

                 with gr.Column():
+                    enable_thinking = gr.Checkbox(value=True)
                     report_to = gr.Dropdown(
-                        choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
-                        value=["none"],
+                        choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
+                        value="none",
                         allow_custom_value=True,
-                        multiselect=True,
                     )

         input_elems.update(

@@ -126,6 +126,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
                 mask_history,
                 resize_vocab,
                 use_llama_pro,
+                enable_thinking,
                 report_to,
             }
         )

@@ -143,6 +144,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
                 mask_history=mask_history,
                 resize_vocab=resize_vocab,
                 use_llama_pro=use_llama_pro,
+                enable_thinking=enable_thinking,
                 report_to=report_to,
             )
         )

@@ -231,6 +233,42 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             )
         )

+    with gr.Accordion(open=False) as mm_tab:
+        with gr.Row():
+            freeze_vision_tower = gr.Checkbox(value=True)
+            freeze_multi_modal_projector = gr.Checkbox(value=True)
+            freeze_language_model = gr.Checkbox(value=False)
+
+        with gr.Row():
+            image_max_pixels = gr.Textbox(value="768*768")
+            image_min_pixels = gr.Textbox(value="32*32")
+            video_max_pixels = gr.Textbox(value="256*256")
+            video_min_pixels = gr.Textbox(value="16*16")
+
+    input_elems.update(
+        {
+            freeze_vision_tower,
+            freeze_multi_modal_projector,
+            freeze_language_model,
+            image_max_pixels,
+            image_min_pixels,
+            video_max_pixels,
+            video_min_pixels,
+        }
+    )
+    elem_dict.update(
+        dict(
+            mm_tab=mm_tab,
+            freeze_vision_tower=freeze_vision_tower,
+            freeze_multi_modal_projector=freeze_multi_modal_projector,
+            freeze_language_model=freeze_language_model,
+            image_max_pixels=image_max_pixels,
+            image_min_pixels=image_min_pixels,
+            video_max_pixels=video_max_pixels,
+            video_min_pixels=video_min_pixels,
+        )
+    )
+
     with gr.Accordion(open=False) as galore_tab:
         with gr.Row():
             use_galore = gr.Checkbox()
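
Note: the new multimodal accordion maps onto existing model and finetuning arguments of the same names. An illustrative set of the values it produces (field names mirror the components above; the pixel strings are parsed by calculate_pixels from webui/common.py; not part of the commit):

    mm_args = {
        "freeze_vision_tower": True,
        "freeze_multi_modal_projector": True,
        "freeze_language_model": False,
        "image_max_pixels": 768 * 768,   # textbox value "768*768"
        "image_min_pixels": 32 * 32,
        "video_max_pixels": 256 * 256,
        "video_min_pixels": 16 * 16,
    }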
src/llamafactory/webui/control.py  (+11 -0)

@@ -84,6 +84,17 @@ def get_model_info(model_name: str) -> tuple[str, str]:
     return get_model_path(model_name), get_template(model_name)


+def check_template(lang: str, template: str) -> None:
+    r"""Check if an instruct model is used.
+
+    Please use queue=True to show the warning message.
+
+    Inputs: top.lang, top.template
+    """
+    if template == "default":
+        gr.Warning(ALERTS["warn_no_instruct"][lang])
+
+
 def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
     r"""Get training infomation for monitor.
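
Note: check_template only raises a Gradio warning, so it needs queue=True for the message to reach the browser; it is wired into the model_name.change chain in components/top.py above. A usage sketch with hypothetical calls (not part of the commit):

    check_template("en", "default")  # warns via ALERTS["warn_no_instruct"]["en"]
    check_template("en", "qwen")     # instruct template selected, no warning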