OpenDAS / LLaMA-Factory · Commits

Commit d1588ee7 ("update 0718"), authored Jul 18, 2025 by chenych. Parent: 358bd2a0. Changes: 43 files.
This page shows 20 of the 43 changed files, with 301 additions and 84 deletions (+301 −84).
src/llamafactory/extras/misc.py                    +43 −22
src/llamafactory/hparams/parser.py                  +8  −4
src/llamafactory/model/adapter.py                   +1  −1
src/llamafactory/model/loader.py                    +7 −12
src/llamafactory/model/model_utils/unsloth.py       +5  −2
src/llamafactory/model/model_utils/valuehead.py     +1  −1
src/llamafactory/model/model_utils/visual.py       +17  −0
src/llamafactory/model/patcher.py                   +3  −0
src/llamafactory/train/callbacks.py                 +1  −1
src/llamafactory/webui/common.py                    +9  −4
src/llamafactory/webui/components/__init__.py       +2  −0
src/llamafactory/webui/components/chatbot.py        +9  −1
src/llamafactory/webui/components/footer.py        +45  −0  (new file)
src/llamafactory/webui/components/top.py           +15  −5
src/llamafactory/webui/control.py                  +10  −1
src/llamafactory/webui/engine.py                    +3  −1
src/llamafactory/webui/interface.py                 +6  −6
src/llamafactory/webui/locales.py                  +93  −0
src/llamafactory/webui/runner.py                   +21 −21
tests/data/test_formatter.py                        +2  −2
src/llamafactory/extras/misc.py (+43 −22)

```diff
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Literal, Union
 import torch
 import torch.distributed as dist
 import transformers.dynamic_module_utils
+from huggingface_hub.utils import WeakFileLock
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.dynamic_module_utils import get_relative_imports
 from transformers.utils import (
@@ -35,7 +36,6 @@ from transformers.utils import (
 from transformers.utils.versions import require_version

 from . import logging
-from .packages import is_transformers_version_greater_than


 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -94,15 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version(
-        "transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
-    )
+    check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0")
     check_version("datasets>=2.16.0,<=3.6.0")
-    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("accelerate>=1.3.0,<=1.7.0")
     check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
-    if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
-        logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


 def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
@@ -182,8 +178,22 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor


+def get_current_memory() -> tuple[int, int]:
+    r"""Get the available and total memory for the current device (in Bytes)."""
+    if is_torch_xpu_available():
+        return torch.xpu.mem_get_info()
+    elif is_torch_npu_available():
+        return torch.npu.mem_get_info()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), torch.mps.recommended_max_memory()
+    elif is_torch_cuda_available():
+        return torch.cuda.mem_get_info()
+    else:
+        return 0, -1
+
+
 def get_peak_memory() -> tuple[int, int]:
-    r"""Get the peak memory usage for the current device (in Bytes)."""
+    r"""Get the peak memory usage (allocated, reserved) for the current device (in Bytes)."""
     if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     elif is_torch_npu_available():
@@ -193,7 +203,7 @@ def get_peak_memory() -> tuple[int, int]:
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
-        return 0, 0
+        return 0, -1


 def has_tokenized_data(path: "os.PathLike") -> bool:
@@ -259,26 +269,37 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         return model_args.model_name_or_path

     if use_modelscope():
-        check_version("modelscope>=1.11.0", mandatory=True)
+        check_version("modelscope>=1.14.0", mandatory=True)
         from modelscope import snapshot_download  # type: ignore
+        from modelscope.hub.api import HubApi  # type: ignore
+
+        if model_args.ms_hub_token:
+            api = HubApi()
+            api.login(model_args.ms_hub_token)

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
-        return snapshot_download(
-            model_args.model_name_or_path,
-            revision=revision,
-            cache_dir=model_args.cache_dir,
-        )
+        with WeakFileLock(os.path.abspath(os.path.expanduser("~/.cache/llamafactory/modelscope.lock"))):
+            model_path = snapshot_download(
+                model_args.model_name_or_path,
+                revision=revision,
+                cache_dir=model_args.cache_dir,
+            )
+
+        return model_path

     if use_openmind():
         check_version("openmind>=0.8.0", mandatory=True)
         from openmind.utils.hub import snapshot_download  # type: ignore

-        return snapshot_download(
-            model_args.model_name_or_path,
-            revision=model_args.model_revision,
-            cache_dir=model_args.cache_dir,
-        )
+        with WeakFileLock(os.path.abspath(os.path.expanduser("~/.cache/llamafactory/openmind.lock"))):
+            model_path = snapshot_download(
+                model_args.model_name_or_path,
+                revision=model_args.model_revision,
+                cache_dir=model_args.cache_dir,
+            )
+
+        return model_path
@@ -305,5 +326,5 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
     r"""Fix proxy settings for gradio ui."""
     os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     if ipv6_enabled:
-        os.environ.pop("http_proxy", None)
-        os.environ.pop("HTTP_PROXY", None)
+        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
+            os.environ.pop(name, None)
```
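Note on the `WeakFileLock` change: wrapping `snapshot_download` in a file lock serializes concurrent downloads from several local processes (for example a multi-GPU launch) so they do not race on the same cache directory. A minimal sketch of the same pattern outside LLaMA-Factory, with an invented lock path and repo id:

```python
import os

from huggingface_hub.utils import WeakFileLock  # same helper the diff imports


def locked_download(repo_id: str) -> None:
    # hypothetical lock path for illustration; the diff uses ~/.cache/llamafactory/*.lock
    lock_path = os.path.abspath(os.path.expanduser("~/.cache/demo/download.lock"))
    os.makedirs(os.path.dirname(lock_path), exist_ok=True)
    with WeakFileLock(lock_path):  # blocks until any other holder releases it
        print(f"downloading {repo_id} while holding the lock")
        # snapshot_download(repo_id) would go here
```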
src/llamafactory/hparams/parser.py (+8 −4)

```diff
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
 import os
 import sys
 from pathlib import Path
@@ -23,7 +22,6 @@ from typing import Any, Optional, Union
 import torch
 import transformers
-import yaml
 from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -62,11 +60,11 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
     if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
         override_config = OmegaConf.from_cli(sys.argv[2:])
-        dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+        dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
         return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     elif sys.argv[1].endswith(".json"):
         override_config = OmegaConf.from_cli(sys.argv[2:])
-        dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+        dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
         return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]
@@ -166,6 +164,9 @@ def _check_extra_dependencies(
     if finetuning_args.use_adam_mini:
         check_version("adam-mini", mandatory=True)

+    if finetuning_args.use_swanlab:
+        check_version("swanlab", mandatory=True)
+
     if finetuning_args.plot_loss:
         check_version("matplotlib", mandatory=True)
@@ -348,6 +349,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
     training_args.label_names = training_args.label_names or ["labels"]

+    if "swanlab" in training_args.report_to and finetuning_args.use_swanlab:
+        training_args.report_to.remove("swanlab")
+
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
```
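Replacing `yaml.safe_load`/`json.loads` with `OmegaConf.load` routes both config formats through one loader that already returns an OmegaConf object, so it merges directly with the `OmegaConf.from_cli` overrides (JSON parses as valid YAML). A small sketch of that load-then-override flow, with an invented file and keys:

```python
from pathlib import Path

from omegaconf import OmegaConf

# invented config file, for illustration only
Path("demo.yaml").write_text("model: llama3\nlearning_rate: 1.0e-4\n")

dict_config = OmegaConf.load(Path("demo.yaml").absolute())  # also accepts .yml/.json
override_config = OmegaConf.from_cli(["learning_rate=5e-5"])  # later values win on merge
merged = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
print(merged["learning_rate"])  # reflects the CLI override
```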
src/llamafactory/model/adapter.py (+1 −1)

```diff
@@ -188,7 +188,7 @@ def _setup_lora_tuning(
     if adapter_to_resume is not None:  # resume lora training
         if model_args.use_unsloth:
-            model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
+            model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
         else:
             model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
```
src/llamafactory/model/loader.py (+7 −12)

```diff
@@ -19,6 +19,7 @@ import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
     AutoModelForVision2Seq,
@@ -29,7 +30,6 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
-from ..extras.packages import is_transformers_version_greater_than
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
@@ -39,10 +39,6 @@ from .model_utils.valuehead import load_valuehead_params
 from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model

-if is_transformers_version_greater_than("4.46.0"):
-    from transformers import AutoModelForImageTextToText
-
-
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
@@ -111,9 +107,8 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             **init_kwargs,
         )
     except Exception as e:
-        raise OSError("Failed to load processor.") from e
+        logger.info_rank0(f"Failed to load processor: {e}.")
+        processor = None

-    patch_processor(processor, tokenizer, model_args)
     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
@@ -121,6 +116,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
         processor = None

+    if processor is not None:
+        patch_processor(processor, tokenizer, model_args)
+
     return {"tokenizer": tokenizer, "processor": processor}
@@ -160,10 +158,7 @@ def load_model(
         else:
             if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
                 load_class = AutoModelForVision2Seq
-            elif (
-                is_transformers_version_greater_than("4.46.0")
-                and type(config) in AutoModelForImageTextToText._model_mapping.keys()
-            ):  # image-text
+            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
                 load_class = AutoModelForImageTextToText
             elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
```
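Since the minimum transformers version is now above 4.46.0 (see the `misc.py` change), `AutoModelForImageTextToText` can be imported unconditionally and the version guard in the dispatch disappears. The dispatch keys on whether the config class appears in an auto class's model mapping; a rough sketch of the idea (`_model_mapping` is a private transformers attribute that the diff also relies on, and the helper name is invented):

```python
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
    AutoModelForVision2Seq,
)


def pick_load_class(config) -> type:
    """Return the first auto class whose mapping knows this config type."""
    for candidate in (AutoModelForVision2Seq, AutoModelForImageTextToText, AutoModelForSeq2SeqLM):
        if type(config) in candidate._model_mapping.keys():
            return candidate

    return AutoModelForCausalLM  # text-only fallback


config = AutoConfig.from_pretrained("gpt2")  # any text-only checkpoint works here
print(pick_load_class(config).__name__)  # AutoModelForCausalLM
```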
src/llamafactory/model/model_utils/unsloth.py (+5 −2)

```diff
@@ -80,12 +80,15 @@ def get_unsloth_peft_model(
 def load_unsloth_peft_model(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+    config: "PretrainedConfig", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
 ) -> "PreTrainedModel":
     r"""Load peft model with unsloth. Used in both training and inference."""
     from unsloth import FastLanguageModel  # type: ignore

-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
     try:
         if not is_trainable:
             unsloth_kwargs["use_gradient_checkpointing"] = False
```
src/llamafactory/model/model_utils/valuehead.py (+1 −1)

```diff
@@ -49,7 +49,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
     try:
         vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
-        return torch.load(vhead_file, map_location="cpu")
+        return torch.load(vhead_file, map_location="cpu", weights_only=True)
     except Exception as err:
         err_text = str(err)
```
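This change (and the matching one in `train/callbacks.py` below) passes `weights_only=True` to `torch.load`, which restricts unpickling to tensors and plain containers so a tampered checkpoint cannot execute code at load time; the flag has existed since PyTorch 1.13 and became the default in 2.6. A throwaway sketch of the difference:

```python
import tempfile

import torch

# save a plain tensor state dict, then reload it with restricted unpickling
path = tempfile.mkstemp(suffix=".bin")[1]
torch.save({"v_head.weight": torch.zeros(2, 2)}, path)

# weights_only=True refuses arbitrary pickled objects, so a malicious or
# corrupted checkpoint cannot run code during load
state_dict = torch.load(path, map_location="cpu", weights_only=True)
print(state_dict["v_head.weight"].shape)  # torch.Size([2, 2])
```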
src/llamafactory/model/model_utils/visual.py (+17 −0)

```diff
@@ -204,6 +204,23 @@ _register_composite_model(
 )


+_register_composite_model(
+    model_type="gemma3n",
+    vision_model_keys=["vision_tower", "audio_tower"],
+    lora_conflict_keys=["timm_model", "subsample_conv_projection"],
+)
+
+
+# copied from qwen2vl
+_register_composite_model(
+    model_type="glm4v",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
 _register_composite_model(
     model_type="internvl",
 )
```
src/llamafactory/model/patcher.py (+3 −0)

```diff
@@ -178,6 +178,9 @@ def patch_model(
         resize_embedding_layer(model, tokenizer)

     if is_trainable:
+        if getattr(model.config, "model_type", None) == "gemma3n":
+            setattr(model_args, "disable_gradient_checkpointing", True)
+
         prepare_model_for_training(model, model_args)
         autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)
```
src/llamafactory/train/callbacks.py (+1 −1)

```diff
@@ -76,7 +76,7 @@ def fix_valuehead_checkpoint(
         state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)

     os.remove(path_to_checkpoint)
     decoder_state_dict, v_head_state_dict = {}, {}
```
src/llamafactory/webui/common.py (+9 −4)

```diff
@@ -77,14 +77,19 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:
         with open(_get_config_path(), encoding="utf-8") as f:
             return safe_load(f)
     except Exception:
-        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
+        return {"lang": None, "hub_name": None, "last_model": None, "path_dict": {}, "cache_dir": None}


-def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+def save_config(
+    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+) -> None:
     r"""Save user config."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
+    if hub_name:
+        user_config["hub_name"] = hub_name
+
     if model_name:
         user_config["last_model"] = model_name
@@ -247,7 +252,7 @@ def create_ds_config() -> None:
             "stage": 2,
             "allgather_partitions": True,
             "allgather_bucket_size": 5e8,
-            "overlap_comm": True,
+            "overlap_comm": False,
             "reduce_scatter": True,
             "reduce_bucket_size": 5e8,
             "contiguous_gradients": True,
@@ -262,7 +267,7 @@ def create_ds_config() -> None:
         ds_config["zero_optimization"] = {
             "stage": 3,
-            "overlap_comm": True,
+            "overlap_comm": False,
             "contiguous_gradients": True,
             "sub_group_size": 1e9,
             "reduce_bucket_size": "auto",
```
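On the `overlap_comm` flip: DeepSpeed's overlap of gradient reduction with backward computation keeps extra communication buffers alive, so disabling it trades some throughput for lower GPU memory pressure. The generated ZeRO-2 section now looks roughly like this (a sketch limited to the keys visible in the diff):

```python
# abbreviated "zero_optimization" section written by create_ds_config() after this change
zero2_section = {
    "stage": 2,
    "allgather_partitions": True,
    "allgather_bucket_size": 5e8,
    "overlap_comm": False,  # was True; frees the buffers used to overlap comm with backward
    "reduce_scatter": True,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": True,
}
```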
src/llamafactory/webui/components/__init__.py (+2 −0)

```diff
@@ -15,6 +15,7 @@
 from .chatbot import create_chat_box
 from .eval import create_eval_tab
 from .export import create_export_tab
+from .footer import create_footer
 from .infer import create_infer_tab
 from .top import create_top
 from .train import create_train_tab
@@ -24,6 +25,7 @@ __all__ = [
     "create_chat_box",
     "create_eval_tab",
     "create_export_tab",
+    "create_footer",
     "create_infer_tab",
     "create_top",
     "create_train_tab",
```
src/llamafactory/webui/components/chatbot.py (+9 −1)

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import json
 from typing import TYPE_CHECKING
@@ -50,7 +51,14 @@ def create_chat_box(
 ) -> tuple["Component", "Component", dict[str, "Component"]]:
     lang = engine.manager.get_elem_by_id("top.lang")
     with gr.Column(visible=visible) as chat_box:
-        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
+        kwargs = {}
+        if "show_copy_button" in inspect.signature(gr.Chatbot.__init__).parameters:
+            kwargs["show_copy_button"] = True
+
+        if "resizable" in inspect.signature(gr.Chatbot.__init__).parameters:
+            kwargs["resizable"] = True
+
+        chatbot = gr.Chatbot(type="messages", **kwargs)
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
```
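The `inspect.signature` check makes the UI tolerate both old and new Gradio releases: a keyword is forwarded only if the installed `gr.Chatbot` accepts it. The same feature-detection pattern works for any callable; a small self-contained sketch (the `greet` function is invented for illustration):

```python
import inspect


def accepts_kwarg(func, name: str) -> bool:
    """Check whether `func` declares a parameter called `name`."""
    return name in inspect.signature(func).parameters


def greet(text: str, loud: bool = False) -> str:  # stand-in for gr.Chatbot.__init__
    return text.upper() if loud else text


kwargs = {}
if accepts_kwarg(greet, "loud"):  # only forward options the target supports
    kwargs["loud"] = True

print(greet("hello", **kwargs))  # HELLO
```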
src/llamafactory/webui/components/footer.py (+45 −0, new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...extras.misc import get_current_memory
from ...extras.packages import is_gradio_available


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component


def get_device_memory() -> "gr.Slider":
    free, total = get_current_memory()
    if total != -1:
        used = round((total - free) / (1024**3), 2)
        total = round(total / (1024**3), 2)
        return gr.Slider(minimum=0, maximum=total, value=used, step=0.01, visible=True)
    else:
        return gr.Slider(visible=False)


def create_footer() -> dict[str, "Component"]:
    with gr.Row():
        device_memory = gr.Slider(visible=False, interactive=False)
        timer = gr.Timer(value=5)

    timer.tick(get_device_memory, outputs=[device_memory], queue=False)
    return dict(device_memory=device_memory)
```
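The new footer polls device memory every five seconds: each `gr.Timer` tick re-runs `get_device_memory`, and the returned `gr.Slider(...)` update either shows usage or hides the widget when no accelerator is present (`total == -1` from `get_current_memory`). A minimal standalone sketch of the same timer-driven refresh, assuming a Gradio version that ships `gr.Timer` (the counter is invented):

```python
import itertools

import gradio as gr

counter = itertools.count()


def refresh() -> gr.Slider:
    # return an updated component; here we just move the slider on every tick
    return gr.Slider(minimum=0, maximum=100, value=next(counter) % 100, visible=True)


with gr.Blocks() as demo:
    slider = gr.Slider(visible=False, interactive=False)
    timer = gr.Timer(value=5)  # fires every 5 seconds while the page is open
    timer.tick(refresh, outputs=[slider], queue=False)

demo.launch()
```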
src/llamafactory/webui/components/top.py (+15 −5)

```diff
@@ -16,9 +16,10 @@ from typing import TYPE_CHECKING
 from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ...extras.misc import use_modelscope, use_openmind
 from ...extras.packages import is_gradio_available
 from ..common import save_config
-from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints
+from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints, switch_hub


 if is_gradio_available():
@@ -33,8 +34,10 @@ def create_top() -> dict[str, "Component"]:
     with gr.Row():
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
         available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-        model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
-        model_path = gr.Textbox(scale=3)
+        model_name = gr.Dropdown(choices=available_models, value=None, scale=2)
+        model_path = gr.Textbox(scale=2)
+        default_hub = "modelscope" if use_modelscope() else "openmind" if use_openmind() else "huggingface"
+        hub_name = gr.Dropdown(choices=["huggingface", "modelscope", "openmind"], value=default_hub, scale=2)

     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
@@ -50,18 +53,25 @@ def create_top() -> dict[str, "Component"]:
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     ).then(check_template, [lang, template])
-    model_name.input(save_config, inputs=[lang, model_name], queue=False)
-    model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
+    model_name.input(save_config, inputs=[lang, hub_name, model_name], queue=False)
+    model_path.input(save_config, inputs=[lang, hub_name, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
     checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
     quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
+    hub_name.change(switch_hub, inputs=[hub_name], queue=False).then(
+        get_model_info, [model_name], [model_path, template], queue=False
+    ).then(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False).then(
+        check_template, [lang, template]
+    )
+    hub_name.input(save_config, inputs=[lang, hub_name], queue=False)

     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
+        hub_name=hub_name,
         finetuning_type=finetuning_type,
         checkpoint_path=checkpoint_path,
         quantization_bit=quantization_bit,
```
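The `hub_name` wiring follows the convention of the other dropdowns: `.change(...)` chains its follow-up callbacks with `.then(...)` so the hub switch, model-info refresh, and checkpoint listing run in order, while `.input(...)` persists the choice via `save_config`. A stripped-down sketch of the chaining pattern (callbacks invented for illustration):

```python
import gradio as gr


def step_one(choice: str) -> str:
    return f"hub set to {choice}"


def step_two(status: str) -> str:
    return status + "; model list refreshed"


with gr.Blocks() as demo:
    hub = gr.Dropdown(choices=["huggingface", "modelscope", "openmind"], value="huggingface")
    status = gr.Textbox()
    # .then() runs the second callback after the first finishes, mirroring
    # hub_name.change(switch_hub, ...).then(get_model_info, ...) in the diff
    hub.change(step_one, inputs=[hub], outputs=[status], queue=False).then(
        step_two, inputs=[status], outputs=[status], queue=False
    )

demo.launch()
```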
src/llamafactory/webui/control.py (+10 −1)

```diff
@@ -38,6 +38,15 @@ if is_gradio_available():
 import gradio as gr


+def switch_hub(hub_name: str) -> None:
+    r"""Switch model hub.
+
+    Inputs: top.hub_name
+    """
+    os.environ["USE_MODELSCOPE_HUB"] = "1" if hub_name == "modelscope" else "0"
+    os.environ["USE_OPENMIND_HUB"] = "1" if hub_name == "openmind" else "0"
+
+
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
     r"""Judge if the quantization is available in this finetuning type.
@@ -112,7 +121,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tup
     running_log_path = os.path.join(output_path, RUNNING_LOG)
     if os.path.isfile(running_log_path):
         with open(running_log_path, encoding="utf-8") as f:
-            running_log = f.read()[-20000:]  # avoid lengthy log
+            running_log = "```\n" + f.read()[-20000:] + "\n```\n"  # avoid lengthy log

     trainer_log_path = os.path.join(output_path, TRAINER_LOG)
     if os.path.isfile(trainer_log_path):
```
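Wrapping the log tail in triple backticks makes the Gradio output box render it as a monospaced code block instead of interpreting stray characters as Markdown. The trimming-plus-fencing step in isolation (a sketch; the 20000-character cap comes straight from the diff):

```python
def format_running_log(raw_log: str, limit: int = 20000) -> str:
    """Keep only the log tail and fence it so Markdown rendering leaves it alone."""
    return "```\n" + raw_log[-limit:] + "\n```\n"


print(format_running_log("step 1 | loss *0.5*\nstep 2 | loss *0.4*\n"))
```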
src/llamafactory/webui/engine.py (+3 −1)

```diff
@@ -49,11 +49,13 @@ class Engine:
     def resume(self):
         r"""Get the initial value of gradio components and restores training status if necessary."""
         user_config = load_config() if not self.demo_mode else {}  # do not use config in demo mode
-        lang = user_config.get("lang", None) or "en"
+        lang = user_config.get("lang") or "en"
         init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}

         if not self.pure_chat:
             current_time = get_time()
+            hub_name = user_config.get("hub_name") or "huggingface"
+            init_dict["top.hub_name"] = {"value": hub_name}
             init_dict["train.current_time"] = {"value": current_time}
             init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
             init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
```
src/llamafactory/webui/interface.py (+6 −6)

```diff
@@ -22,6 +22,7 @@ from .components import (
     create_chat_box,
     create_eval_tab,
     create_export_tab,
+    create_footer,
     create_infer_tab,
     create_top,
     create_train_tab,
@@ -38,15 +39,13 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
     engine = Engine(demo_mode=demo_mode, pure_chat=False)
     hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

-    with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Factory ({hostname})", css=CSS) as demo:
+        title = gr.HTML()
+        subtitle = gr.HTML()
         if demo_mode:
-            gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
-            gr.HTML(
-                '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
-                "LLaMA Factory</a> for details.</center></h3>"
-            )
             gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

+        engine.manager.add_elems("head", {"title": title, "subtitle": subtitle})
         engine.manager.add_elems("top", create_top())
         lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")
@@ -63,6 +62,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
         with gr.Tab("Export"):
             engine.manager.add_elems("export", create_export_tab(engine))

+        engine.manager.add_elems("footer", create_footer())
         demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
         lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
         lang.input(save_config, inputs=[lang], queue=False)
```
src/llamafactory/webui/locales.py (+93 −0)

```diff
@@ -13,6 +13,55 @@
 # limitations under the License.

 LOCALES = {
+    "title": {
+        "en": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: Unified Efficient Fine-Tuning of 100+ LLMs</center></h1>",
+        },
+        "ru": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: Унифицированная эффективная тонкая настройка 100+ LLMs</center></h1>",
+        },
+        "zh": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 一站式大模型高效微调平台</center></h1>",
+        },
+        "ko": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 100+ LLMs를 위한 통합 효율적인 튜닝</center></h1>",
+        },
+        "ja": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 100+ LLMs の統合効率的なチューニング</center></h1>",
+        },
+    },
+    "subtitle": {
+        "en": {
+            "value": (
+                "<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub Page</a></center></h3>"
+            ),
+        },
+        "ru": {
+            "value": (
+                "<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "страницу GitHub</a></center></h3>"
+            ),
+        },
+        "zh": {
+            "value": (
+                "<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 主页</a></center></h3>"
+            ),
+        },
+        "ko": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 페이지</a>를 방문하세요.</center></h3>"
+            ),
+        },
+        "ja": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub ページ</a>にアクセスする</center></h3>"
+            ),
+        },
+    },
     "lang": {
         "en": {
             "label": "Language",
@@ -74,6 +123,28 @@ LOCALES = {
             "info": "事前学習済みモデルへのパス、または Hugging Face のモデル識別子。",
         },
     },
+    "hub_name": {
+        "en": {
+            "label": "Hub name",
+            "info": "Choose the model download source.",
+        },
+        "ru": {
+            "label": "Имя хаба",
+            "info": "Выберите источник загрузки модели.",
+        },
+        "zh": {
+            "label": "模型下载源",
+            "info": "选择模型下载源。(网络受限环境推荐使用 ModelScope)",
+        },
+        "ko": {
+            "label": "모델 다운로드 소스",
+            "info": "모델 다운로드 소스를 선택하세요.",
+        },
+        "ja": {
+            "label": "モデルダウンロードソース",
+            "info": "モデルをダウンロードするためのソースを選択してください。",
+        },
+    },
     "finetuning_type": {
         "en": {
             "label": "Finetuning method",
@@ -2849,6 +2920,28 @@ LOCALES = {
             "value": "エクスポート",
         },
     },
+    "device_memory": {
+        "en": {
+            "label": "Device memory",
+            "info": "Current memory usage of the device (GB).",
+        },
+        "ru": {
+            "label": "Память устройства",
+            "info": "Текущая память на устройстве (GB).",
+        },
+        "zh": {
+            "label": "设备显存",
+            "info": "当前设备的显存(GB)。",
+        },
+        "ko": {
+            "label": "디바이스 메모리",
+            "info": "지금 사용 중인 기기 메모리 (GB).",
+        },
+        "ja": {
+            "label": "デバイスメモリ",
+            "info": "現在のデバイスのメモリ(GB)。",
+        },
+    },
 }
```
src/llamafactory/webui/runner.py (+21 −21)

```diff
@@ -16,14 +16,13 @@ import json
 import os
 from collections.abc import Generator
 from copy import deepcopy
-from subprocess import Popen, TimeoutExpired
+from subprocess import PIPE, Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Optional

-from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_npu_available

 from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_accelerator_available, torch_gc, use_ray
+from ..extras.misc import is_accelerator_available, torch_gc
 from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
@@ -114,7 +113,7 @@ class Runner:
         return ""

-    def _finalize(self, lang: str, finish_info: str) -> str:
+    def _finalize(self, lang: str, finish_info: str) -> None:
         r"""Clean the cached memory and resets the runner."""
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
         gr.Info(finish_info)
@@ -123,7 +122,6 @@ class Runner:
         self.running = False
         self.running_data = None
         torch_gc()
-        return finish_info

     def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
         r"""Build and validate the training arguments."""
@@ -314,11 +312,13 @@ class Runner:
             max_samples=int(get("eval.max_samples")),
             per_device_eval_batch_size=get("eval.batch_size"),
             predict_with_generate=True,
+            report_to="none",
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
             output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
             trust_remote_code=True,
+            ddp_timeout=180000000,
         )

         if get("eval.predict"):
@@ -375,7 +375,7 @@ class Runner:
             env["FORCE_TORCHRUN"] = "1"

         # NOTE: DO NOT USE shell=True to avoid security risk
-        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
+        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env, stderr=PIPE, text=True)
         yield from self.monitor()

     def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
@@ -417,7 +417,8 @@ class Runner:
         swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None

         running_log = ""
-        while self.trainer is not None:
+        return_code = -1
+        while return_code == -1:
             if self.aborted:
                 yield {
                     output_box: ALERTS["info_aborting"][lang],
@@ -436,27 +437,26 @@ class Runner:
                     return_dict[swanlab_link] = running_info["swanlab_link"]

                 yield return_dict

             try:
-                self.trainer.wait(2)
-                self.trainer = None
+                stderr = self.trainer.communicate(timeout=2)[1]
+                return_code = self.trainer.returncode
             except TimeoutExpired:
                 continue

-        if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
-                finish_info = ALERTS["info_finished"][lang]
-            else:
-                finish_info = ALERTS["err_failed"][lang]
-        else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
-                finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
-            else:
-                finish_info = ALERTS["err_failed"][lang]
+        if return_code == 0 or self.aborted:
+            finish_info = ALERTS["info_finished"][lang]
+            if self.do_train:
+                finish_log = ALERTS["info_finished"][lang] + "\n\n" + running_log
+            else:
+                finish_log = load_eval_results(os.path.join(output_path, "all_results.json")) + "\n\n" + running_log
+        else:
+            print(stderr)
+            finish_info = ALERTS["err_failed"][lang]
+            finish_log = ALERTS["err_failed"][lang] + f" Exit code: {return_code}\n\n```\n{stderr}\n```\n"

-        return_dict = {
-            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
-            progress_bar: gr.Slider(visible=False),
-        }
+        self._finalize(lang, finish_info)
+        return_dict = {
+            output_box: finish_log,
+            progress_bar: gr.Slider(visible=False),
+        }
         yield return_dict

     def save_args(self, data):
```
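The monitor loop now captures the trainer's stderr instead of discarding it: `stderr=PIPE, text=True` on `Popen`, then `communicate(timeout=2)` in place of `wait(2)`, so a non-zero exit can surface the actual traceback in the UI. The polling skeleton on its own (a sketch with a stand-in command; the real code launches `llamafactory-cli train <config>`):

```python
from subprocess import PIPE, Popen, TimeoutExpired

# stand-in command that exits with a failure code
proc = Popen(["python", "-c", "import sys; sys.exit(1)"], stderr=PIPE, text=True)

return_code = -1
while return_code == -1:
    try:
        # communicate() waits up to 2s, collects stderr, and closes the pipe
        stderr = proc.communicate(timeout=2)[1]
        return_code = proc.returncode
    except TimeoutExpired:
        continue  # still running: loop again and keep streaming progress

print(f"exit code: {return_code}")
if return_code != 0:
    print(stderr)
```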
tests/data/test_formatter.py (+2 −2)

```diff
@@ -110,8 +110,8 @@ def test_glm4_function_formatter():
 def test_glm4_tool_formatter():
     formatter = ToolFormatter(tool_format="glm4")
     assert formatter.apply(content=json.dumps(TOOLS)) == [
-        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-        "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
+        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI公司训练的语言模型 GLM-4 模型开发的,"
+        "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具\n\n"
         f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n"
         "在调用上述函数时,请使用 Json 格式表示调用的参数。"
     ]
```