OpenDAS / LLaMA-Factory · Commits

Commit 7ea81099, authored Apr 07, 2025 by chenych
Parent: 84987715

update llama4
Showing 19 changed files with 236 additions and 151 deletions.
src/llamafactory/webui/components/top.py        +2   -2
src/llamafactory/webui/components/train.py      +4   -4
src/llamafactory/webui/control.py               +14  -23
src/llamafactory/webui/engine.py                +7   -15
src/llamafactory/webui/interface.py             +8   -3
src/llamafactory/webui/manager.py               +15  -27
src/llamafactory/webui/runner.py                +25  -47
src/webui.py                                    +3   -1
tests/check_license.py                          +38  -0   (new)
tests/data/processor/test_pairwise.py           +1   -2
tests/data/processor/test_processor_utils.py    +1   -2
tests/data/test_formatter.py                    +4   -2
tests/data/test_mm_plugin.py                    +31  -10
tests/data/test_template.py                     +5   -6
tests/e2e/test_sglang.py                        +71  -0   (new)
tests/model/model_utils/test_checkpointing.py   +1   -1
tests/model/test_pissa.py                       +1   -3
tests/train/test_sft_trainer.py                 +3   -3
tests/version.txt                               +2   -0   (new)
src/llamafactory/webui/components/top.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING

 from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
     from gradio.components import Component

-def create_top() -> Dict[str, "Component"]:
+def create_top() -> dict[str, "Component"]:
     with gr.Row():
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
         available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
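Note: this hunk is part of a repo-wide migration from the typing-module aliases (Dict, List, Tuple, ...) to the builtin generics standardized by PEP 585, which require Python 3.9+. A minimal sketch of the two styles, side by side:

    # Pre-3.9 style: generic aliases must come from the typing module.
    from typing import Dict, List

    def count_tokens_old(texts: List[str]) -> Dict[str, int]:
        return {text: len(text.split()) for text in texts}

    # PEP 585 style (Python 3.9+): builtin types are subscriptable directly.
    def count_tokens_new(texts: list[str]) -> dict[str, int]:
        return {text: len(text.split()) for text in texts}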
src/llamafactory/webui/components/train.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING

 from transformers.trainer_utils import SchedulerType
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     from ..engine import Engine

-def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
+def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
     input_elems = engine.manager.get_base_elems()
     elem_dict = dict()
@@ -382,8 +382,8 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)

     lang = engine.manager.get_elem_by_id("top.lang")
-    model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name")
-    finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type")
+    model_name: gr.Dropdown = engine.manager.get_elem_by_id("top.model_name")
+    finetuning_type: gr.Dropdown = engine.manager.get_elem_by_id("top.finetuning_type")

     arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
     arg_load_btn.click(
src/llamafactory/webui/control.py

@@ -14,7 +14,7 @@
 import json
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional

 from transformers.trainer_utils import get_last_checkpoint
@@ -39,8 +39,7 @@ if is_gradio_available():

 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
-    r"""
-    Judges if the quantization is available in this finetuning type.
+    r"""Judge if the quantization is available in this finetuning type.

     Inputs: top.finetuning_type
     Outputs: top.quantization_bit
@@ -52,8 +51,7 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":

 def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
-    r"""
-    Gets the available quantization bits.
+    r"""Get the available quantization bits.

     Inputs: top.quantization_method
     Outputs: top.quantization_bit
@@ -68,9 +66,8 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
     return gr.Dropdown(choices=available_bits)

-def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
-    r"""
-    Modifys states after changing the training stage.
+def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> tuple[list[str], bool]:
+    r"""Modify states after changing the training stage.

     Inputs: train.training_stage
     Outputs: train.dataset, train.packing
@@ -78,9 +75,8 @@ def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> tuple
     return [], TRAINING_STAGES[training_stage] == "pt"

-def get_model_info(model_name: str) -> Tuple[str, str]:
-    r"""
-    Gets the necessary information of this model.
+def get_model_info(model_name: str) -> tuple[str, str]:
+    r"""Get the necessary information of this model.

     Inputs: top.model_name
     Outputs: top.model_path, top.template
@@ -88,9 +84,8 @@ def get_model_info(model_name: str) -> tuple[str, str]:
     return get_model_path(model_name), get_template(model_name)

-def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
-    r"""
-    Gets training infomation for monitor.
+def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
+    r"""Get training information for monitor.

     If do_train is True:
         Inputs: top.lang, train.output_path
@@ -110,7 +105,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tup
     trainer_log_path = os.path.join(output_path, TRAINER_LOG)
     if os.path.isfile(trainer_log_path):
-        trainer_log: List[Dict[str, Any]] = []
+        trainer_log: list[dict[str, Any]] = []
         with open(trainer_log_path, encoding="utf-8") as f:
             for line in f:
                 trainer_log.append(json.loads(line))
@@ -143,8 +138,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tup

 def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
-    r"""
-    Lists all available checkpoints.
+    r"""List all available checkpoints.

     Inputs: top.model_name, top.finetuning_type
     Outputs: top.checkpoint_path
@@ -166,8 +160,7 @@ def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":

 def list_config_paths(current_time: str) -> "gr.Dropdown":
-    r"""
-    Lists all the saved configuration files.
+    r"""List all the saved configuration files.

     Inputs: train.current_time
     Outputs: train.config_path
@@ -182,8 +175,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":

 def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
-    r"""
-    Lists all available datasets in the dataset dir for the training stage.
+    r"""List all available datasets in the dataset dir for the training stage.

     Inputs: *.dataset_dir, *.training_stage
     Outputs: *.dataset
@@ -195,8 +187,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S

 def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
-    r"""
-    Lists all the directories that can resume from.
+    r"""List all the directories that can resume from.

     Inputs: top.model_name, top.finetuning_type, train.current_time
     Outputs: train.output_dir
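Note: every docstring edit in this file follows one convention: the summary moves onto the opening r""" line and the verb switches to the imperative mood. This matches the pydocstyle rules commonly enforced by linters such as ruff (D212 and D401; an inference from the pattern, the commit itself names no linter). Schematically:

    # Before: summary on its own line, third-person verb.
    def get_model_info_old(model_name: str) -> tuple[str, str]:
        r"""
        Gets the necessary information of this model.
        """

    # After: summary on the opening line, imperative mood.
    def get_model_info_new(model_name: str) -> tuple[str, str]:
        r"""Get the necessary information of this model."""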
src/llamafactory/webui/engine.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any

 from .chatter import WebChatModel
 from .common import create_ds_config, get_time, load_config
@@ -26,9 +26,7 @@ if TYPE_CHECKING:

 class Engine:
-    r"""
-    A general engine to control the behaviors of Web UI.
-    """
+    r"""A general engine to control the behaviors of Web UI."""

     def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
         self.demo_mode = demo_mode
@@ -39,11 +37,9 @@ class Engine:
         if not demo_mode:
             create_ds_config()

-    def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
-        r"""
-        Updates gradio components according to the (elem_id, properties) mapping.
-        """
-        output_dict: Dict["Component", "Component"] = {}
+    def _update_component(self, input_dict: dict[str, dict[str, Any]]) -> dict["Component", "Component"]:
+        r"""Update gradio components according to the (elem_id, properties) mapping."""
+        output_dict: dict[Component, Component] = {}
         for elem_id, elem_attr in input_dict.items():
             elem = self.manager.get_elem_by_id(elem_id)
             output_dict[elem] = elem.__class__(**elem_attr)
@@ -51,9 +47,7 @@ class Engine:
         return output_dict

     def resume(self):
-        r"""
-        Gets the initial value of gradio components and restores training status if necessary.
-        """
+        r"""Get the initial value of gradio components and restore training status if necessary."""
         user_config = load_config() if not self.demo_mode else {}  # do not use config in demo mode
         lang = user_config.get("lang", None) or "en"
         init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
@@ -79,9 +73,7 @@ class Engine:
         yield self._update_component({"eval.resume_btn": {"value": True}})

     def change_lang(self, lang: str):
-        r"""
-        Updates the displayed language of gradio components.
-        """
+        r"""Update the displayed language of gradio components."""
         return {
             elem: elem.__class__(**LOCALES[elem_name][lang])
             for elem_name, elem in self.manager.get_elem_iter()
src/llamafactory/webui/interface.py

@@ -15,7 +15,7 @@
 import os
 import platform

-from ..extras.misc import is_env_enabled
+from ..extras.misc import fix_proxy, is_env_enabled
 from ..extras.packages import is_gradio_available
 from .common import save_config
 from .components import (
@@ -48,7 +48,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
             gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

         engine.manager.add_elems("top", create_top())
-        lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
+        lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")

         with gr.Tab("Train"):
             engine.manager.add_elems("train", create_train_tab(engine))
@@ -72,8 +72,9 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":

 def create_web_demo() -> "gr.Blocks":
     engine = Engine(pure_chat=True)
+    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

-    with gr.Blocks(title="Web Demo", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Factory Web Demo ({hostname})", css=CSS) as demo:
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], scale=1)
         engine.manager.add_elems("top", dict(lang=lang))
@@ -91,6 +92,8 @@ def run_web_ui() -> None:
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
     create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
@@ -98,4 +101,6 @@ def run_web_demo() -> None:
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
     create_web_demo().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
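Note: fix_proxy is new in the import list, but only its call site appears in this diff. A hypothetical sketch of what such a helper typically does, exempting local addresses from any configured HTTP proxy so the browser can reach the Gradio server; the body below is an assumption, not the actual implementation in ..extras.misc:

    import os

    def fix_proxy(ipv6_enabled: bool = False) -> None:
        # Hypothetical body: ensure requests to the local server bypass
        # any http_proxy/https_proxy configured in the environment.
        no_proxy_hosts = "localhost,127.0.0.1,0.0.0.0"
        if ipv6_enabled:
            no_proxy_hosts += ",::1"  # IPv6 loopback
        os.environ["no_proxy"] = no_proxy_hosts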
src/llamafactory/webui/manager.py

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
+from collections.abc import Generator
+from typing import TYPE_CHECKING

 if TYPE_CHECKING:
@@ -20,54 +21,41 @@ if TYPE_CHECKING:

 class Manager:
-    r"""
-    A class to manage all the gradio components in Web UI.
-    """
+    r"""A class to manage all the gradio components in Web UI."""

     def __init__(self) -> None:
-        self._id_to_elem: Dict[str, "Component"] = {}
-        self._elem_to_id: Dict["Component", str] = {}
+        self._id_to_elem: dict[str, Component] = {}
+        self._elem_to_id: dict[Component, str] = {}

-    def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
-        r"""
-        Adds elements to manager.
-        """
+    def add_elems(self, tab_name: str, elem_dict: dict[str, "Component"]) -> None:
+        r"""Add elements to manager."""
         for elem_name, elem in elem_dict.items():
             elem_id = f"{tab_name}.{elem_name}"
             self._id_to_elem[elem_id] = elem
             self._elem_to_id[elem] = elem_id

-    def get_elem_list(self) -> List["Component"]:
-        r"""
-        Returns the list of all elements.
-        """
+    def get_elem_list(self) -> list["Component"]:
+        r"""Return the list of all elements."""
         return list(self._id_to_elem.values())

-    def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
-        r"""
-        Returns an iterator over all elements with their names.
-        """
+    def get_elem_iter(self) -> Generator[tuple[str, "Component"], None, None]:
+        r"""Return an iterator over all elements with their names."""
         for elem_id, elem in self._id_to_elem.items():
             yield elem_id.split(".")[-1], elem

     def get_elem_by_id(self, elem_id: str) -> "Component":
-        r"""
-        Gets element by id.
+        r"""Get element by id.

         Example: top.lang, train.dataset
         """
         return self._id_to_elem[elem_id]

     def get_id_by_elem(self, elem: "Component") -> str:
-        r"""
-        Gets id by element.
-        """
+        r"""Get id by element."""
         return self._elem_to_id[elem]

-    def get_base_elems(self) -> Set["Component"]:
-        r"""
-        Gets the base elements that are commonly used.
-        """
+    def get_base_elems(self) -> set["Component"]:
+        r"""Get the base elements that are commonly used."""
         return {
             self._id_to_elem["top.lang"],
             self._id_to_elem["top.model_name"],
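Note: Generator moves from typing to collections.abc here because the typing aliases for container ABCs are deprecated since Python 3.9. A small sketch of the resulting idiom, mirroring Manager.get_elem_iter:

    from collections.abc import Generator

    def iter_named(elems: dict[str, object]) -> Generator[tuple[str, object], None, None]:
        # Yield (short_name, element) pairs, like get_elem_iter does
        # for ids such as "top.lang".
        for elem_id, elem in elems.items():
            yield elem_id.split(".")[-1], elem

    # Example: list(iter_named({"top.lang": "dropdown"})) == [("lang", "dropdown")]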
src/llamafactory/webui/runner.py

@@ -14,9 +14,10 @@
 import json
 import os
+from collections.abc import Generator
 from copy import deepcopy
 from subprocess import Popen, TimeoutExpired
-from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_npu_available
@@ -51,17 +52,16 @@ if TYPE_CHECKING:

 class Runner:
-    r"""
-    A class to manage the running status of the trainers.
-    """
+    r"""A class to manage the running status of the trainers."""

     def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
+        r"""Init a runner."""
         self.manager = manager
         self.demo_mode = demo_mode
         """ Resume """
-        self.trainer: Optional["Popen"] = None
+        self.trainer: Optional[Popen] = None
         self.do_train = True
-        self.running_data: Dict["Component", Any] = None
+        self.running_data: dict[Component, Any] = None
         """ State """
         self.aborted = False
         self.running = False
@@ -71,10 +71,8 @@ class Runner:
         if self.trainer is not None:
             abort_process(self.trainer.pid)

-    def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
-        r"""
-        Validates the configuration.
-        """
+    def _initialize(self, data: dict["Component", Any], do_train: bool, from_preview: bool) -> str:
+        r"""Validate the configuration."""
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         dataset = get("train.dataset") if do_train else get("eval.dataset")
@@ -116,9 +114,7 @@ class Runner:
         return ""

     def _finalize(self, lang: str, finish_info: str) -> str:
-        r"""
-        Cleans the cached memory and resets the runner.
-        """
+        r"""Clean the cached memory and reset the runner."""
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
         gr.Info(finish_info)
         self.trainer = None
@@ -128,10 +124,8 @@ class Runner:
         torch_gc()
         return finish_info

-    def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
-        r"""
-        Builds and validates the training arguments.
-        """
+    def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
+        r"""Build and validate the training arguments."""
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
@@ -291,10 +285,8 @@ class Runner:
         return args

-    def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
-        r"""
-        Builds and validates the evaluation arguments.
-        """
+    def _parse_eval_args(self, data: dict["Component", Any]) -> dict[str, Any]:
+        r"""Build and validate the evaluation arguments."""
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
@@ -345,10 +337,8 @@ class Runner:
         return args

-    def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
-        r"""
-        Previews the training commands.
-        """
+    def _preview(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", str], None, None]:
+        r"""Preview the training commands."""
         output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
         error = self._initialize(data, do_train, from_preview=True)
         if error:
@@ -358,10 +348,8 @@ class Runner:
         args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
         yield {output_box: gen_cmd(args)}

-    def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
-        r"""
-        Starts the training process.
-        """
+    def _launch(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", Any], None, None]:
+        r"""Start the training process."""
         output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
         error = self._initialize(data, do_train, from_preview=False)
         if error:
@@ -383,10 +371,8 @@ class Runner:
         self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
         yield from self.monitor()

-    def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
-        r"""
-        Builds a dictionary containing the current training configuration.
-        """
+    def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
+        r"""Build a dictionary containing the current training configuration."""
         config_dict = {}
         skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
         for elem, value in data.items():
@@ -409,9 +395,7 @@ class Runner:
         yield from self._launch(data, do_train=False)

     def monitor(self):
-        r"""
-        Monitors the training progress and logs.
-        """
+        r"""Monitor the training progress and logs."""
         self.aborted = False
         self.running = True
@@ -469,9 +453,7 @@ class Runner:
         yield return_dict

     def save_args(self, data):
-        r"""
-        Saves the training configuration to config path.
-        """
+        r"""Save the training configuration to config path."""
         output_box = self.manager.get_elem_by_id("train.output_box")
         error = self._initialize(data, do_train=True, from_preview=True)
         if error:
@@ -487,27 +469,23 @@ class Runner:
         return {output_box: ALERTS["info_config_saved"][lang] + save_path}

     def load_args(self, lang: str, config_path: str):
-        r"""
-        Loads the training configuration from config path.
-        """
+        r"""Load the training configuration from config path."""
         output_box = self.manager.get_elem_by_id("train.output_box")
         config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
         if config_dict is None:
             gr.Warning(ALERTS["err_config_not_found"][lang])
             return {output_box: ALERTS["err_config_not_found"][lang]}

-        output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
+        output_dict: dict[Component, Any] = {output_box: ALERTS["info_config_loaded"][lang]}
         for elem_id, value in config_dict.items():
             output_dict[self.manager.get_elem_by_id(elem_id)] = value

         return output_dict

     def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
-        r"""
-        Restore the training status if output_dir exists.
-        """
+        r"""Restore the training status if output_dir exists."""
         output_box = self.manager.get_elem_by_id("train.output_box")
-        output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
+        output_dict: dict[Component, Any] = {output_box: LOCALES["output_box"][lang]["value"]}
         if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
             gr.Warning(ALERTS["warn_output_dir_exists"][lang])
             output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]
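Note: for context on the Runner flow these hunks touch, _launch spawns llamafactory-cli train as a subprocess and monitor polls its log. A minimal sketch of that pattern, assuming the trainer writes a JSON-lines log named trainer_log.jsonl (the log file name, the argument file path, and the 2-second poll interval are illustrative assumptions, not shown in this diff):

    import json
    import os
    import time
    from subprocess import Popen

    output_dir = "saves/demo"  # illustrative path
    proc = Popen(["llamafactory-cli", "train", os.path.join(output_dir, "train_args.yaml")])

    log_path = os.path.join(output_dir, "trainer_log.jsonl")
    while proc.poll() is None:  # None means the trainer is still running
        if os.path.isfile(log_path):
            with open(log_path, encoding="utf-8") as f:
                trainer_log = [json.loads(line) for line in f]
            if trainer_log:
                print(trainer_log[-1])  # most recent progress record
        time.sleep(2)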
src/webui.py

@@ -14,7 +14,7 @@
 import os

-from llamafactory.extras.misc import is_env_enabled
+from llamafactory.extras.misc import fix_proxy, is_env_enabled
 from llamafactory.webui.interface import create_ui
@@ -22,6 +22,8 @@ def main():
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
     create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
tests/check_license.py  (new file)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from pathlib import Path

KEYWORDS = ("Copyright", "2025", "LlamaFactory")


def main():
    path_list: list[Path] = []
    for check_dir in sys.argv[1:]:
        path_list.extend(Path(check_dir).glob("**/*.py"))

    for path in path_list:
        with open(path.absolute(), encoding="utf-8") as f:
            file_content = f.read().strip().split("\n")
            if not file_content[0]:
                continue

            print(f"Check license: {path}")
            assert all(keyword in file_content[0] for keyword in KEYWORDS), f"File {path} does not contain license."


if __name__ == "__main__":
    main()
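Note: the new checker takes one or more directories on the command line (e.g. python tests/check_license.py src tests, an assumed invocation) and asserts that the first non-empty line of every .py file contains all three keywords. The same check for a single file, inlined:

    from pathlib import Path

    KEYWORDS = ("Copyright", "2025", "LlamaFactory")

    # Path is illustrative; any project .py file with the standard header works.
    first_line = Path("src/webui.py").read_text(encoding="utf-8").strip().split("\n")[0]
    assert all(keyword in first_line for keyword in KEYWORDS), "missing license header"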
tests/data/processor/test_pairwise.py

@@ -14,7 +14,6 @@
 import os
 import random
-from typing import Dict, List

 import pytest
 from datasets import load_dataset
@@ -43,7 +42,7 @@ TRAIN_ARGS = {
 }

-def _convert_sharegpt_to_openai(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str, str]]:
     role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}
     new_messages = []
     for message in messages:
tests/data/processor/test_processor_utils.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple

 import pytest
@@ -31,5 +30,5 @@ from llamafactory.data.processor.processor_utils import infer_seqlen
         ((10, 10, 1000), (10, 10)),
     ],
 )
-def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
+def test_infer_seqlen(test_input: tuple[int, int, int], test_output: tuple[int, int]):
     assert test_output == infer_seqlen(*test_input)
tests/data/test_formatter.py

@@ -112,7 +112,8 @@ def test_glm4_tool_formatter():
     assert formatter.apply(content=json.dumps(TOOLS)) == [
         "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
         "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
-        f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
+        f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n"
+        "在调用上述函数时,请使用 Json 格式表示调用的参数。"
     ]
@@ -136,7 +137,8 @@ def test_llama3_tool_formatter():
     wrapped_tool = {"type": "function", "function": TOOLS[0]}
     assert formatter.apply(content=json.dumps(TOOLS)) == [
         f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
-        "You have access to the following functions. To call a function, please respond with JSON for a function call. "
+        "You have access to the following functions. "
+        "To call a function, please respond with JSON for a function call. "
         """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
         f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
     ]
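Note: both hunks only re-wrap long expected strings. Adjacent string literals are concatenated at compile time, so splitting one literal across two lines leaves the test's expected value unchanged:

    one_line = "You have access to the following functions. To call a function, please respond with JSON for a function call. "
    two_lines = (
        "You have access to the following functions. "
        "To call a function, please respond with JSON for a function call. "
    )
    assert one_line == two_lines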
tests/data/test_mm_plugin.py

@@ -13,13 +13,14 @@
 # limitations under the License.
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence
+from typing import TYPE_CHECKING, Any

 import pytest
 import torch
 from PIL import Image

 from llamafactory.data.mm_plugin import get_mm_plugin
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
@@ -69,12 +70,12 @@ LABELS = [0, 1, 2, 3, 4]
 BATCH_IDS = [[1] * 1024]

-def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
+    image_processor: BaseImageProcessor = getattr(processor, "image_processor")
     return image_processor(images=IMAGES, return_tensors="pt")

-def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
+def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
     assert batch_a.keys() == batch_b.keys()
     for key in batch_a.keys():
         if isinstance(batch_a[key], torch.Tensor):
@@ -96,11 +97,11 @@ def _check_plugin(
     plugin: "BasePlugin",
     tokenizer: "PreTrainedTokenizer",
     processor: "ProcessorMixin",
-    expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
-    expected_input_ids: List[int] = INPUT_IDS,
-    expected_labels: List[int] = LABELS,
-    expected_mm_inputs: Dict[str, Any] = {},
-    expected_no_mm_inputs: Dict[str, Any] = {},
+    expected_mm_messages: list[dict[str, str]] = MM_MESSAGES,
+    expected_input_ids: list[int] = INPUT_IDS,
+    expected_labels: list[int] = LABELS,
+    expected_mm_inputs: dict[str, Any] = {},
+    expected_no_mm_inputs: dict[str, Any] = {},
 ) -> None:
     # test mm_messages
     if plugin.__class__.__name__ != "BasePlugin":
@@ -135,6 +136,27 @@ def test_base_plugin():
     _check_plugin(**check_inputs)

+@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
+def test_gemma3_plugin():
+    image_seqlen = 256
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
+    gemma3_plugin = get_mm_plugin(name="gemma3", image_token="<image_soft_token>")
+    image_tokens_expanded = "<image_soft_token>" * image_seqlen
+    check_inputs = {"plugin": gemma3_plugin, **tokenizer_module}
+    check_inputs["expected_mm_messages"] = [
+        {
+            key: value.replace("<image>", f"\n\n<start_of_image>{image_tokens_expanded}<end_of_image>\n\n")
+            for key, value in message.items()
+        }
+        for message in MM_MESSAGES
+    ]
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
+    check_inputs["expected_mm_inputs"].pop("num_crops")
+    check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * 1024]
+    check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[0] * 1024]}
+    _check_plugin(**check_inputs)
+
 def test_llava_plugin():
     image_seqlen = 576
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
@@ -210,7 +232,6 @@ def test_pixtral_plugin():
         for message in MM_MESSAGES
     ]
     check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
-    check_inputs["expected_mm_inputs"].pop("image_sizes")
     check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
     _check_plugin(**check_inputs)
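Note: the new test_gemma3_plugin is gated twice, on HF_TOKEN (the checkpoint is gated on the Hub) and on transformers >= 4.50.0 (Gemma 3 support). When several tests share the gate, the marker can be factored out; a sketch, where HF_TOKEN = os.getenv("HF_TOKEN") mirrors the test module's existing setup but is not shown in this diff:

    import os

    import pytest

    from llamafactory.extras.packages import is_transformers_version_greater_than

    HF_TOKEN = os.getenv("HF_TOKEN")  # assumption: matches the module's setup

    # Reusable marker for tests that need the gated Gemma 3 checkpoint.
    requires_gemma3 = pytest.mark.skipif(
        not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"),
        reason="Gated model.",
    )

    @requires_gemma3
    def test_something_gemma3():
        ...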
tests/data/test_template.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING

 import pytest
 from transformers import AutoTokenizer
@@ -40,10 +40,9 @@ MESSAGES = [

 def _check_tokenization(
-    tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
+    tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
 ) -> None:
-    r"""
-    Checks token ids and texts.
+    r"""Check token ids and texts.

     encode(text) == token_ids
     decode(token_ids) == text
@@ -54,8 +53,7 @@ def _check_tokenization(

 def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
-    r"""
-    Checks template.
+    r"""Check template.

     Args:
         model_id: the model id on hugging face hub.
@@ -63,6 +61,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
         prompt_str: the string corresponding to the prompt part.
         answer_str: the string corresponding to the answer part.
+        use_fast: whether to use fast tokenizer.
     """
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
     content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
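Note: _check_template compares LLaMA-Factory's template output against the tokenizer's own chat template. A standalone sketch of the reference side; the model id and messages are illustrative, and gated models may need an access token:

    from transformers import AutoTokenizer

    messages = [
        {"role": "user", "content": "How are you"},
        {"role": "assistant", "content": "I am fine!"},
    ]
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    content_str = tokenizer.apply_chat_template(messages, tokenize=False)
    print(content_str)  # the ground-truth prompt+answer string for this template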
tests/e2e/test_sglang.py  (new file)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import pytest

from llamafactory.chat import ChatModel
from llamafactory.extras.packages import is_sglang_available

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

INFER_ARGS = {
    "model_name_or_path": MODEL_NAME,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "infer_backend": "sglang",
    "do_sample": False,
    "max_new_tokens": 1,
}

MESSAGES = [
    {"role": "user", "content": "Hi"},
]


@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_chat():
    r"""Test the SGLang engine's basic chat functionality."""
    chat_model = ChatModel(INFER_ARGS)
    response = chat_model.chat(MESSAGES)[0]
    # TODO: Change to EXPECTED_RESPONSE
    print(response.response_text)


@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_stream_chat():
    r"""Test the SGLang engine's streaming chat functionality."""
    chat_model = ChatModel(INFER_ARGS)
    response = ""
    for token in chat_model.stream_chat(MESSAGES):
        response += token

    print("Complete response:", response)
    assert response, "Should receive a non-empty response"


# Run tests if executed directly
if __name__ == "__main__":
    if not is_sglang_available():
        print("SGLang is not available. Please install it.")
        sys.exit(1)

    test_chat()
    test_stream_chat()
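Note: because of the __main__ block, this suite can be exercised either through pytest (pytest tests/e2e/test_sglang.py) or directly (python tests/e2e/test_sglang.py); both paths skip or exit cleanly when SGLang is not installed.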
tests/model/model_utils/test_checkpointing.py

@@ -62,5 +62,5 @@ def test_upcast_layernorm():
 def test_upcast_lmhead_output():
     model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
     inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
-    outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
+    outputs: torch.Tensor = model.get_output_embeddings()(inputs)
     assert outputs.dtype == torch.float32
View file @
7ea81099
...
@@ -48,8 +48,6 @@ INFER_ARGS = {
...
@@ -48,8 +48,6 @@ INFER_ARGS = {
"infer_dtype"
:
"float16"
,
"infer_dtype"
:
"float16"
,
}
}
OS_NAME
=
os
.
getenv
(
"OS_NAME"
,
""
)
@
pytest
.
mark
.
xfail
(
reason
=
"PiSSA initialization is not stable in different platform."
)
@
pytest
.
mark
.
xfail
(
reason
=
"PiSSA initialization is not stable in different platform."
)
def
test_pissa_train
():
def
test_pissa_train
():
...
@@ -58,7 +56,7 @@ def test_pissa_train():
...
@@ -58,7 +56,7 @@ def test_pissa_train():
compare_model
(
model
,
ref_model
)
compare_model
(
model
,
ref_model
)
@
pytest
.
mark
.
xfail
(
OS_NAME
.
startswith
(
"windows"
),
reason
=
"Known connection error
on Windows
."
)
@
pytest
.
mark
.
xfail
(
reason
=
"Known connection error."
)
def
test_pissa_inference
():
def
test_pissa_inference
():
model
=
load_infer_model
(
**
INFER_ARGS
)
model
=
load_infer_model
(
**
INFER_ARGS
)
ref_model
=
load_reference_model
(
TINY_LLAMA_PISSA
,
TINY_LLAMA_PISSA
,
use_pissa
=
True
,
is_trainable
=
False
)
ref_model
=
load_reference_model
(
TINY_LLAMA_PISSA
,
TINY_LLAMA_PISSA
,
use_pissa
=
True
,
is_trainable
=
False
)
...
...
tests/train/test_sft_trainer.py

@@ -14,7 +14,7 @@
 import os
 from dataclasses import dataclass, field
-from typing import Any, Dict, List
+from typing import Any

 import pytest
 from transformers import DataCollatorWithPadding
@@ -46,9 +46,9 @@ TRAIN_ARGS = {

 @dataclass
 class DataCollatorWithVerbose(DataCollatorWithPadding):
-    verbose_list: List[Dict[str, Any]] = field(default_factory=list)
+    verbose_list: list[dict[str, Any]] = field(default_factory=list)

-    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         self.verbose_list.extend(features)
         batch = super().__call__(features)
         return {k: v[:, :1] for k, v in batch.items()}  # truncate input length
tests/version.txt  (new file)

# change if test fails
0.9.3.101