OpenDAS / LLaMA-Factory · Commits

Commit d1588ee7, authored Jul 18, 2025 by chenych
Commit message: update 0718
Parent: 358bd2a0

Changes: 43 files in this commit. This page shows 20 changed files with 301 additions and 84 deletions (+301 −84).
Changed files shown on this page:

src/llamafactory/extras/misc.py                    +43  -22
src/llamafactory/hparams/parser.py                  +8   -4
src/llamafactory/model/adapter.py                   +1   -1
src/llamafactory/model/loader.py                    +7  -12
src/llamafactory/model/model_utils/unsloth.py       +5   -2
src/llamafactory/model/model_utils/valuehead.py     +1   -1
src/llamafactory/model/model_utils/visual.py       +17   -0
src/llamafactory/model/patcher.py                   +3   -0
src/llamafactory/train/callbacks.py                 +1   -1
src/llamafactory/webui/common.py                    +9   -4
src/llamafactory/webui/components/__init__.py       +2   -0
src/llamafactory/webui/components/chatbot.py        +9   -1
src/llamafactory/webui/components/footer.py        +45   -0
src/llamafactory/webui/components/top.py           +15   -5
src/llamafactory/webui/control.py                  +10   -1
src/llamafactory/webui/engine.py                    +3   -1
src/llamafactory/webui/interface.py                 +6   -6
src/llamafactory/webui/locales.py                  +93   -0
src/llamafactory/webui/runner.py                   +21  -21
tests/data/test_formatter.py                        +2   -2
src/llamafactory/extras/misc.py

```diff
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Literal, Union
 import torch
 import torch.distributed as dist
 import transformers.dynamic_module_utils
+from huggingface_hub.utils import WeakFileLock
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.dynamic_module_utils import get_relative_imports
 from transformers.utils import (
@@ -35,7 +36,6 @@ from transformers.utils import (
 from transformers.utils.versions import require_version
 
 from . import logging
-from .packages import is_transformers_version_greater_than
 
 
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -94,15 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version(
-        "transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
-    )
+    check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0")
     check_version("datasets>=2.16.0,<=3.6.0")
-    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("accelerate>=1.3.0,<=1.7.0")
     check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
-    if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
-        logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
 
 
 def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
@@ -182,8 +178,22 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor
 
 
+def get_current_memory() -> tuple[int, int]:
+    r"""Get the available and total memory for the current device (in Bytes)."""
+    if is_torch_xpu_available():
+        return torch.xpu.mem_get_info()
+    elif is_torch_npu_available():
+        return torch.npu.mem_get_info()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), torch.mps.recommended_max_memory()
+    elif is_torch_cuda_available():
+        return torch.cuda.mem_get_info()
+    else:
+        return 0, -1
+
+
 def get_peak_memory() -> tuple[int, int]:
-    r"""Get the peak memory usage for the current device (in Bytes)."""
+    r"""Get the peak memory usage (allocated, reserved) for the current device (in Bytes)."""
     if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     elif is_torch_npu_available():
@@ -193,7 +203,7 @@ def get_peak_memory() -> tuple[int, int]:
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
-        return 0, 0
+        return 0, -1
 
 
 def has_tokenized_data(path: "os.PathLike") -> bool:
@@ -259,25 +269,36 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         return model_args.model_name_or_path
 
     if use_modelscope():
-        check_version("modelscope>=1.11.0", mandatory=True)
+        check_version("modelscope>=1.14.0", mandatory=True)
         from modelscope import snapshot_download  # type: ignore
+        from modelscope.hub.api import HubApi  # type: ignore
+
+        if model_args.ms_hub_token:
+            api = HubApi()
+            api.login(model_args.ms_hub_token)
 
         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
-        return snapshot_download(
-            model_args.model_name_or_path,
-            revision=revision,
-            cache_dir=model_args.cache_dir,
-        )
+        with WeakFileLock(os.path.abspath(os.path.expanduser("~/.cache/llamafactory/modelscope.lock"))):
+            model_path = snapshot_download(
+                model_args.model_name_or_path,
+                revision=revision,
+                cache_dir=model_args.cache_dir,
+            )
+
+        return model_path
 
     if use_openmind():
         check_version("openmind>=0.8.0", mandatory=True)
         from openmind.utils.hub import snapshot_download  # type: ignore
 
-        return snapshot_download(
-            model_args.model_name_or_path,
-            revision=model_args.model_revision,
-            cache_dir=model_args.cache_dir,
-        )
+        with WeakFileLock(os.path.abspath(os.path.expanduser("~/.cache/llamafactory/openmind.lock"))):
+            model_path = snapshot_download(
+                model_args.model_name_or_path,
+                revision=model_args.model_revision,
+                cache_dir=model_args.cache_dir,
+            )
+
+        return model_path
 
 
 def use_modelscope() -> bool:
@@ -305,5 +326,5 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
     r"""Fix proxy settings for gradio ui."""
     os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     if ipv6_enabled:
-        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
-            os.environ.pop(name, None)
+        os.environ.pop("http_proxy", None)
+        os.environ.pop("HTTP_PROXY", None)
```
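The new `get_current_memory()` returns raw byte counts and a `(0, -1)` sentinel when no supported accelerator is found. A minimal sketch (not part of this commit, assuming `llamafactory` is importable) of turning that into a GB reading, the same arithmetic the new web UI footer uses:

```python
# Sketch only: convert get_current_memory() byte counts into a GB summary,
# handling the (0, -1) "no accelerator" sentinel returned by the fallback branch.
from llamafactory.extras.misc import get_current_memory


def describe_device_memory() -> str:
    free, total = get_current_memory()
    if total == -1:  # no XPU/NPU/MPS/CUDA device was detected
        return "no accelerator available"

    used_gb = (total - free) / 1024**3
    total_gb = total / 1024**3
    return f"{used_gb:.2f} GB used / {total_gb:.2f} GB total"


print(describe_device_memory())
```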
src/llamafactory/hparams/parser.py

```diff
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import os
 import sys
 from pathlib import Path
@@ -23,7 +22,6 @@ from typing import Any, Optional, Union
 import torch
 import transformers
-import yaml
 from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -62,11 +60,11 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
         if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
             override_config = OmegaConf.from_cli(sys.argv[2:])
-            dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+            dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
             return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
         elif sys.argv[1].endswith(".json"):
             override_config = OmegaConf.from_cli(sys.argv[2:])
-            dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+            dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
             return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
         else:
             return sys.argv[1:]
@@ -166,6 +164,9 @@ def _check_extra_dependencies(
     if finetuning_args.use_adam_mini:
         check_version("adam-mini", mandatory=True)
 
+    if finetuning_args.use_swanlab:
+        check_version("swanlab", mandatory=True)
+
     if finetuning_args.plot_loss:
         check_version("matplotlib", mandatory=True)
@@ -348,6 +349,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
     training_args.label_names = training_args.label_names or ["labels"]
 
+    if "swanlab" in training_args.report_to and finetuning_args.use_swanlab:
+        training_args.report_to.remove("swanlab")
+
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
```
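With this change, both `.yaml`/`.yml` and `.json` config files go through `OmegaConf.load` and are merged with dot-list overrides from the remaining CLI arguments. A minimal sketch (illustrative file name and keys, assuming `omegaconf` is installed) of that merge behavior:

```python
# Sketch only: the OmegaConf pattern read_args now relies on.
# A config file is loaded, then dot-list CLI arguments override its values.
from pathlib import Path

from omegaconf import OmegaConf

Path("demo.yaml").write_text("output_dir: saves/demo\nnum_train_epochs: 3.0\n")

dict_config = OmegaConf.load(Path("demo.yaml").absolute())       # JSON also works: JSON is valid YAML
override_config = OmegaConf.from_cli(["num_train_epochs=6.0"])   # read_args builds this from sys.argv[2:]
merged = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
print(merged)  # {'output_dir': 'saves/demo', 'num_train_epochs': 6.0}
```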
src/llamafactory/model/adapter.py

```diff
@@ -188,7 +188,7 @@ def _setup_lora_tuning(
     if adapter_to_resume is not None:  # resume lora training
         if model_args.use_unsloth:
-            model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
+            model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
         else:
             model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
```
src/llamafactory/model/loader.py

```diff
@@ -19,6 +19,7 @@ import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
     AutoModelForVision2Seq,
@@ -29,7 +30,6 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
-from ..extras.packages import is_transformers_version_greater_than
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
@@ -39,10 +39,6 @@ from .model_utils.valuehead import load_valuehead_params
 from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
 
-if is_transformers_version_greater_than("4.46.0"):
-    from transformers import AutoModelForImageTextToText
-
-
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
@@ -111,9 +107,8 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             **init_kwargs,
         )
     except Exception as e:
-        raise OSError("Failed to load processor.") from e
-
-    patch_processor(processor, tokenizer, model_args)
+        logger.info_rank0(f"Failed to load processor: {e}.")
+        processor = None
 
     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
@@ -121,6 +116,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
         processor = None
 
+    if processor is not None:
+        patch_processor(processor, tokenizer, model_args)
+
     return {"tokenizer": tokenizer, "processor": processor}
@@ -160,10 +158,7 @@ def load_model(
         else:
             if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
                 load_class = AutoModelForVision2Seq
-            elif (
-                is_transformers_version_greater_than("4.46.0")
-                and type(config) in AutoModelForImageTextToText._model_mapping.keys()
-            ):  # image-text
+            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
                 load_class = AutoModelForImageTextToText
             elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
```
src/llamafactory/model/model_utils/unsloth.py

```diff
@@ -80,12 +80,15 @@ def get_unsloth_peft_model(
 def load_unsloth_peft_model(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+    config: "PretrainedConfig",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool,
 ) -> "PreTrainedModel":
     r"""Load peft model with unsloth. Used in both training and inference."""
     from unsloth import FastLanguageModel  # type: ignore
 
-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
     try:
         if not is_trainable:
             unsloth_kwargs["use_gradient_checkpointing"] = False
```
src/llamafactory/model/model_utils/valuehead.py

```diff
@@ -49,7 +49,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
     try:
         vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
-        return torch.load(vhead_file, map_location="cpu")
+        return torch.load(vhead_file, map_location="cpu", weights_only=True)
     except Exception as err:
         err_text = str(err)
```
src/llamafactory/model/model_utils/visual.py

```diff
@@ -204,6 +204,23 @@ _register_composite_model(
 )
 
 
+_register_composite_model(
+    model_type="gemma3n",
+    vision_model_keys=["vision_tower", "audio_tower"],
+    lora_conflict_keys=["timm_model", "subsample_conv_projection"],
+)
+
+
+# copied from qwen2vl
+_register_composite_model(
+    model_type="glm4v",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
 _register_composite_model(
     model_type="internvl",
 )
```
src/llamafactory/model/patcher.py

```diff
@@ -178,6 +178,9 @@ def patch_model(
             resize_embedding_layer(model, tokenizer)
 
     if is_trainable:
+        if getattr(model.config, "model_type", None) == "gemma3n":
+            setattr(model_args, "disable_gradient_checkpointing", True)
+
         prepare_model_for_training(model, model_args)
         autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)
```
src/llamafactory/train/callbacks.py

```diff
@@ -76,7 +76,7 @@ def fix_valuehead_checkpoint(
         state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)
 
     os.remove(path_to_checkpoint)
     decoder_state_dict, v_head_state_dict = {}, {}
```
src/llamafactory/webui/common.py

```diff
@@ -77,14 +77,19 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:
         with open(_get_config_path(), encoding="utf-8") as f:
             return safe_load(f)
     except Exception:
-        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
+        return {"lang": None, "hub_name": None, "last_model": None, "path_dict": {}, "cache_dir": None}
 
 
-def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+def save_config(
+    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+) -> None:
     r"""Save user config."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
+    if hub_name:
+        user_config["hub_name"] = hub_name
+
     if model_name:
         user_config["last_model"] = model_name
@@ -247,7 +252,7 @@ def create_ds_config() -> None:
         "stage": 2,
         "allgather_partitions": True,
         "allgather_bucket_size": 5e8,
-        "overlap_comm": True,
+        "overlap_comm": False,
         "reduce_scatter": True,
         "reduce_bucket_size": 5e8,
         "contiguous_gradients": True,
@@ -262,7 +267,7 @@ def create_ds_config() -> None:
     ds_config["zero_optimization"] = {
         "stage": 3,
-        "overlap_comm": True,
+        "overlap_comm": False,
         "contiguous_gradients": True,
         "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
```
src/llamafactory/webui/components/__init__.py

```diff
@@ -15,6 +15,7 @@
 from .chatbot import create_chat_box
 from .eval import create_eval_tab
 from .export import create_export_tab
+from .footer import create_footer
 from .infer import create_infer_tab
 from .top import create_top
 from .train import create_train_tab
@@ -24,6 +25,7 @@ __all__ = [
     "create_chat_box",
     "create_eval_tab",
     "create_export_tab",
+    "create_footer",
     "create_infer_tab",
     "create_top",
     "create_train_tab",
```
src/llamafactory/webui/components/chatbot.py

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import json
 from typing import TYPE_CHECKING
@@ -50,7 +51,14 @@ def create_chat_box(
 ) -> tuple["Component", "Component", dict[str, "Component"]]:
     lang = engine.manager.get_elem_by_id("top.lang")
     with gr.Column(visible=visible) as chat_box:
-        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
+        kwargs = {}
+        if "show_copy_button" in inspect.signature(gr.Chatbot.__init__).parameters:
+            kwargs["show_copy_button"] = True
+
+        if "resizable" in inspect.signature(gr.Chatbot.__init__).parameters:
+            kwargs["resizable"] = True
+
+        chatbot = gr.Chatbot(type="messages", **kwargs)
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
```
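The chat box now probes the installed Gradio version before passing optional keyword arguments. A minimal standalone sketch (illustrative only, assuming a Gradio version that supports `type="messages"`) of that feature-detection pattern:

```python
# Sketch only: pass a keyword argument to gr.Chatbot only when the installed
# gradio build actually declares it, so older versions do not raise TypeError.
import inspect

import gradio as gr

kwargs = {}
if "show_copy_button" in inspect.signature(gr.Chatbot.__init__).parameters:
    kwargs["show_copy_button"] = True  # silently skipped on builds without this kwarg

chatbot = gr.Chatbot(type="messages", **kwargs)
```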
src/llamafactory/webui/components/footer.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...extras.misc import get_current_memory
from ...extras.packages import is_gradio_available


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component


def get_device_memory() -> "gr.Slider":
    free, total = get_current_memory()
    if total != -1:
        used = round((total - free) / (1024**3), 2)
        total = round(total / (1024**3), 2)
        return gr.Slider(minimum=0, maximum=total, value=used, step=0.01, visible=True)
    else:
        return gr.Slider(visible=False)


def create_footer() -> dict[str, "Component"]:
    with gr.Row():
        device_memory = gr.Slider(visible=False, interactive=False)
        timer = gr.Timer(value=5)

    timer.tick(get_device_memory, outputs=[device_memory], queue=False)
    return dict(device_memory=device_memory)
```
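The footer polls device memory by re-rendering a slider from a timer tick. A standalone sketch of the same pattern (assumes a Gradio release that ships `gr.Timer`; the random reading stands in for `get_current_memory()`):

```python
# Sketch only: a Timer ticks every 5 seconds and re-renders a Slider with the latest value,
# mirroring how create_footer() surfaces device memory in the web UI.
import random

import gradio as gr


def poll_memory() -> gr.Slider:
    used = random.uniform(0, 24)  # stand-in for a real memory reading in GB
    return gr.Slider(minimum=0, maximum=24, value=round(used, 2), step=0.01, visible=True)


with gr.Blocks() as demo:
    memory_bar = gr.Slider(visible=False, interactive=False)
    timer = gr.Timer(value=5)
    timer.tick(poll_memory, outputs=[memory_bar], queue=False)

demo.launch()
```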
src/llamafactory/webui/components/top.py

```diff
@@ -16,9 +16,10 @@ from typing import TYPE_CHECKING
 from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ...extras.misc import use_modelscope, use_openmind
 from ...extras.packages import is_gradio_available
 from ..common import save_config
-from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints
+from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints, switch_hub
 
 
 if is_gradio_available():
@@ -33,8 +34,10 @@ def create_top() -> dict[str, "Component"]:
     with gr.Row():
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
         available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-        model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
-        model_path = gr.Textbox(scale=3)
+        model_name = gr.Dropdown(choices=available_models, value=None, scale=2)
+        model_path = gr.Textbox(scale=2)
+        default_hub = "modelscope" if use_modelscope() else "openmind" if use_openmind() else "huggingface"
+        hub_name = gr.Dropdown(choices=["huggingface", "modelscope", "openmind"], value=default_hub, scale=2)
 
     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
@@ -50,18 +53,25 @@ def create_top() -> dict[str, "Component"]:
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     ).then(check_template, [lang, template])
-    model_name.input(save_config, inputs=[lang, model_name], queue=False)
-    model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
+    model_name.input(save_config, inputs=[lang, hub_name, model_name], queue=False)
+    model_path.input(save_config, inputs=[lang, hub_name, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
     checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
     quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
+    hub_name.change(switch_hub, inputs=[hub_name], queue=False).then(
+        get_model_info, [model_name], [model_path, template], queue=False
+    ).then(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False).then(
+        check_template, [lang, template]
+    )
+    hub_name.input(save_config, inputs=[lang, hub_name], queue=False)
 
     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
+        hub_name=hub_name,
         finetuning_type=finetuning_type,
         checkpoint_path=checkpoint_path,
         quantization_bit=quantization_bit,
```
src/llamafactory/webui/control.py

````diff
@@ -38,6 +38,15 @@ if is_gradio_available():
     import gradio as gr
 
 
+def switch_hub(hub_name: str) -> None:
+    r"""Switch model hub.
+
+    Inputs: top.hub_name
+    """
+    os.environ["USE_MODELSCOPE_HUB"] = "1" if hub_name == "modelscope" else "0"
+    os.environ["USE_OPENMIND_HUB"] = "1" if hub_name == "openmind" else "0"
+
+
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
     r"""Judge if the quantization is available in this finetuning type.
@@ -112,7 +121,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tup
     running_log_path = os.path.join(output_path, RUNNING_LOG)
     if os.path.isfile(running_log_path):
         with open(running_log_path, encoding="utf-8") as f:
-            running_log = f.read()[-20000:]  # avoid lengthy log
+            running_log = "```\n" + f.read()[-20000:] + "\n```\n"  # avoid lengthy log
 
     trainer_log_path = os.path.join(output_path, TRAINER_LOG)
     if os.path.isfile(trainer_log_path):
````
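`switch_hub` only writes the `USE_MODELSCOPE_HUB` / `USE_OPENMIND_HUB` environment flags; the hub helpers in `extras.misc` (imported by the new `top.py`) read them back when deciding where to download from. A minimal sketch (not part of the commit, assuming `llamafactory` is importable):

```python
# Sketch only: the env flags written by switch_hub() feed the hub checks in extras.misc,
# which is what makes the new "hub_name" dropdown affect later model downloads.
from llamafactory.extras.misc import use_modelscope, use_openmind
from llamafactory.webui.control import switch_hub

switch_hub("modelscope")
print(use_modelscope(), use_openmind())  # expected: True False

switch_hub("huggingface")
print(use_modelscope(), use_openmind())  # expected: False False
```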
src/llamafactory/webui/engine.py

```diff
@@ -49,11 +49,13 @@ class Engine:
     def resume(self):
         r"""Get the initial value of gradio components and restores training status if necessary."""
         user_config = load_config() if not self.demo_mode else {}  # do not use config in demo mode
-        lang = user_config.get("lang", None) or "en"
+        lang = user_config.get("lang") or "en"
         init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
 
         if not self.pure_chat:
             current_time = get_time()
+            hub_name = user_config.get("hub_name") or "huggingface"
+            init_dict["top.hub_name"] = {"value": hub_name}
             init_dict["train.current_time"] = {"value": current_time}
             init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
             init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
```
src/llamafactory/webui/interface.py

```diff
@@ -22,6 +22,7 @@ from .components import (
     create_chat_box,
     create_eval_tab,
     create_export_tab,
+    create_footer,
     create_infer_tab,
     create_top,
     create_train_tab,
@@ -38,15 +39,13 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
     engine = Engine(demo_mode=demo_mode, pure_chat=False)
     hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]
 
-    with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Factory ({hostname})", css=CSS) as demo:
+        title = gr.HTML()
+        subtitle = gr.HTML()
         if demo_mode:
-            gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
-            gr.HTML(
-                '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
-                "LLaMA Factory</a> for details.</center></h3>"
-            )
             gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
 
+        engine.manager.add_elems("head", {"title": title, "subtitle": subtitle})
         engine.manager.add_elems("top", create_top())
         lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")
@@ -63,6 +62,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
         with gr.Tab("Export"):
             engine.manager.add_elems("export", create_export_tab(engine))
 
+        engine.manager.add_elems("footer", create_footer())
         demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
         lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
         lang.input(save_config, inputs=[lang], queue=False)
```
src/llamafactory/webui/locales.py

```diff
@@ -13,6 +13,55 @@
 # limitations under the License.
 
 LOCALES = {
+    "title": {
+        "en": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: Unified Efficient Fine-Tuning of 100+ LLMs</center></h1>",
+        },
+        "ru": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: Унифицированная эффективная тонкая настройка 100+ LLMs</center></h1>",
+        },
+        "zh": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 一站式大模型高效微调平台</center></h1>",
+        },
+        "ko": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 100+ LLMs를 위한 통합 효율적인 튜닝</center></h1>",
+        },
+        "ja": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 100+ LLMs の統合効率的なチューニング</center></h1>",
+        },
+    },
+    "subtitle": {
+        "en": {
+            "value": (
+                "<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub Page</a></center></h3>"
+            ),
+        },
+        "ru": {
+            "value": (
+                "<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "страницу GitHub</a></center></h3>"
+            ),
+        },
+        "zh": {
+            "value": (
+                "<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 主页</a></center></h3>"
+            ),
+        },
+        "ko": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 페이지</a>를 방문하세요.</center></h3>"
+            ),
+        },
+        "ja": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub ページ</a>にアクセスする</center></h3>"
+            ),
+        },
+    },
     "lang": {
         "en": {
             "label": "Language",
@@ -74,6 +123,28 @@ LOCALES = {
             "info": "事前学習済みモデルへのパス、または Hugging Face のモデル識別子。",
         },
     },
+    "hub_name": {
+        "en": {
+            "label": "Hub name",
+            "info": "Choose the model download source.",
+        },
+        "ru": {
+            "label": "Имя хаба",
+            "info": "Выберите источник загрузки модели.",
+        },
+        "zh": {
+            "label": "模型下载源",
+            "info": "选择模型下载源。(网络受限环境推荐使用 ModelScope)",
+        },
+        "ko": {
+            "label": "모델 다운로드 소스",
+            "info": "모델 다운로드 소스를 선택하세요.",
+        },
+        "ja": {
+            "label": "モデルダウンロードソース",
+            "info": "モデルをダウンロードするためのソースを選択してください。",
+        },
+    },
     "finetuning_type": {
         "en": {
             "label": "Finetuning method",
@@ -2849,6 +2920,28 @@ LOCALES = {
             "value": "エクスポート",
         },
     },
+    "device_memory": {
+        "en": {
+            "label": "Device memory",
+            "info": "Current memory usage of the device (GB).",
+        },
+        "ru": {
+            "label": "Память устройства",
+            "info": "Текущая память на устройстве (GB).",
+        },
+        "zh": {
+            "label": "设备显存",
+            "info": "当前设备的显存(GB)。",
+        },
+        "ko": {
+            "label": "디바이스 메모리",
+            "info": "지금 사용 중인 기기 메모리 (GB).",
+        },
+        "ja": {
+            "label": "デバイスメモリ",
+            "info": "現在のデバイスのメモリ(GB)。",
+        },
+    },
 }
```
src/llamafactory/webui/runner.py

````diff
@@ -16,14 +16,13 @@ import json
 import os
 from collections.abc import Generator
 from copy import deepcopy
-from subprocess import Popen, TimeoutExpired
+from subprocess import PIPE, Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Optional
 
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_npu_available
 
 from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_accelerator_available, torch_gc, use_ray
+from ..extras.misc import is_accelerator_available, torch_gc
 from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
@@ -114,7 +113,7 @@ class Runner:
             return ""
 
-    def _finalize(self, lang: str, finish_info: str) -> str:
+    def _finalize(self, lang: str, finish_info: str) -> None:
         r"""Clean the cached memory and resets the runner."""
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
         gr.Info(finish_info)
@@ -123,7 +122,6 @@ class Runner:
         self.running = False
         self.running_data = None
         torch_gc()
-        return finish_info
 
     def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
         r"""Build and validate the training arguments."""
@@ -314,11 +312,13 @@ class Runner:
             max_samples=int(get("eval.max_samples")),
             per_device_eval_batch_size=get("eval.batch_size"),
             predict_with_generate=True,
+            report_to="none",
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
             output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
             trust_remote_code=True,
+            ddp_timeout=180000000,
         )
 
         if get("eval.predict"):
@@ -375,7 +375,7 @@ class Runner:
             env["FORCE_TORCHRUN"] = "1"
 
         # NOTE: DO NOT USE shell=True to avoid security risk
-        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
+        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env, stderr=PIPE, text=True)
         yield from self.monitor()
 
     def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
@@ -417,7 +417,8 @@ class Runner:
         swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None
 
         running_log = ""
-        while self.trainer is not None:
+        return_code = -1
+        while return_code == -1:
             if self.aborted:
                 yield {
                     output_box: ALERTS["info_aborting"][lang],
@@ -436,27 +437,26 @@ class Runner:
                     return_dict[swanlab_link] = running_info["swanlab_link"]
 
                 yield return_dict
 
             try:
-                self.trainer.wait(2)
-                self.trainer = None
+                stderr = self.trainer.communicate(timeout=2)[1]
+                return_code = self.trainer.returncode
             except TimeoutExpired:
                 continue
 
-        if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
-                finish_info = ALERTS["info_finished"][lang]
-            else:
-                finish_info = ALERTS["err_failed"][lang]
-        else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
-                finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
-            else:
-                finish_info = ALERTS["err_failed"][lang]
+        if return_code == 0 or self.aborted:
+            finish_info = ALERTS["info_finished"][lang]
+            if self.do_train:
+                finish_log = ALERTS["info_finished"][lang] + "\n\n" + running_log
+            else:
+                finish_log = load_eval_results(os.path.join(output_path, "all_results.json")) + "\n\n" + running_log
+        else:
+            print(stderr)
+            finish_info = ALERTS["err_failed"][lang]
+            finish_log = ALERTS["err_failed"][lang] + f" Exit code: {return_code}\n\n```\n{stderr}\n```\n"
 
-        return_dict = {
-            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
-            progress_bar: gr.Slider(visible=False),
-        }
+        self._finalize(lang, finish_info)
+        return_dict = {output_box: finish_log, progress_bar: gr.Slider(visible=False)}
         yield return_dict
 
     def save_args(self, data):
````
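The monitor loop now captures the subprocess's stderr and branches on its exit code instead of waiting for the `Popen` handle to be cleared. A standalone sketch of that pattern (illustrative command, not the real `llamafactory-cli` invocation):

```python
# Sketch only: capture stderr via PIPE, poll with communicate(timeout=...), and
# branch on the exit code, mirroring the new Runner.monitor() control flow.
from subprocess import PIPE, Popen, TimeoutExpired

trainer = Popen(["python", "-c", "import sys; sys.exit(1)"], stderr=PIPE, text=True)

return_code = -1
while return_code == -1:
    try:
        stderr = trainer.communicate(timeout=2)[1]
        return_code = trainer.returncode
    except TimeoutExpired:
        continue  # still running; the real loop yields progress updates here

if return_code == 0:
    print("finished")
else:
    print(f"failed with exit code {return_code}\n{stderr}")
```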
tests/data/test_formatter.py

```diff
@@ -110,8 +110,8 @@ def test_glm4_function_formatter():
 def test_glm4_tool_formatter():
     formatter = ToolFormatter(tool_format="glm4")
     assert formatter.apply(content=json.dumps(TOOLS)) == [
-        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-        "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
+        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI公司训练的语言模型 GLM-4 模型开发的,"
+        "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具\n\n"
         f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n"
         "在调用上述函数时,请使用 Json 格式表示调用的参数。"
     ]
```