OpenDAS / LLaMA-Factory · Commits

Commit ca625f43, authored Mar 30, 2026 by shihm
Commit message: uodata
Parent commit: 7164651d

Showing 20 changed files with 527 additions and 23 deletions (+527, -23):
- src/llamafactory/v1/utils/packages.py (+43, -0)
- src/llamafactory/v1/utils/plugin.py (+89, -0)
- src/llamafactory/v1/utils/pytest.py (+35, -0)
- src/llamafactory/v1/utils/types.py (+110, -0)
- src/llamafactory/webui/chatter.py (+5, -5)
- src/llamafactory/webui/common.py (+6, -6)
- src/llamafactory/webui/components/export.py (+3, -3)
- src/llamafactory/webui/control.py (+2, -2)
- src/llamafactory/webui/locales.py (+15, -5)
- src/llamafactory/webui/runner.py (+2, -2)
- tests/conftest.py (+168, -0)
- tests/data/processor/test_feedback.py (+1, -0)
- tests/data/processor/test_pairwise.py (+1, -0)
- tests/data/processor/test_processor_utils.py (+1, -0)
- tests/data/processor/test_supervised.py (+4, -0)
- tests/data/processor/test_unsupervised.py (+2, -0)
- tests/data/test_collator.py (+4, -0)
- tests/data/test_converter.py (+4, -0)
- tests/data/test_formatter.py (+27, -0)
- tests/data/test_loader.py (+5, -0)
src/llamafactory/v1/utils/packages.py (new file, mode 100644)

```python
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.metadata
import importlib.util
from functools import lru_cache
from typing import TYPE_CHECKING

from packaging import version


if TYPE_CHECKING:
    from packaging.version import Version


def _is_package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None


def _get_package_version(name: str) -> "Version":
    try:
        return version.parse(importlib.metadata.version(name))
    except Exception:
        return version.parse("0.0.0")


@lru_cache
def is_transformers_version_greater_than(content: str):
    return _get_package_version("transformers") >= version.parse(content)
```
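For orientation, a minimal usage sketch of these helpers; it is not part of the commit. The `trl` probe is purely illustrative, and the `"4.57.0"` gate merely mirrors the value used later in `tests/conftest.py`:

```python
# Illustrative only: how the helpers above are meant to be called.
from llamafactory.v1.utils.packages import (
    _is_package_available,
    is_transformers_version_greater_than,
)

# find_spec-based probe: checks importability without importing the package
if _is_package_available("trl"):
    print("trl is installed")

# memoized by lru_cache, so repeated gates cost one comparison in total
if is_transformers_version_greater_than("4.57.0"):
    print("use the newer transformers code path")
```

Note that despite its name, `is_transformers_version_greater_than` compares with `>=`, so the check is inclusive of the given version.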
src/llamafactory/v1/utils/plugin.py (new file, mode 100644)

````python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable

from . import logging


logger = logging.get_logger(__name__)


class BasePlugin:
    """Base class for plugins.

    A plugin is a callable object that can be registered and called by name.
    """

    _registry: dict[str, Callable] = {}

    def __init__(self, name: str | None = None):
        """Initialize the plugin with a name.

        Args:
            name (str): The name of the plugin.
        """
        self.name = name

    @property
    def register(self):
        """Decorator to register a function as a plugin.

        Example usage:
        ```python
        @PrintPlugin("hello").register
        def print_hello():
            print("Hello world!")
        ```
        """
        if self.name is None:
            raise ValueError("Plugin name is not specified.")

        if self.name in self._registry:
            logger.warning_rank0_once(f"Plugin {self.name} is already registered.")

        def decorator(func: Callable) -> Callable:
            self._registry[self.name] = func
            return func

        return decorator

    def __call__(self, *args, **kwargs):
        """Call the registered function with the given arguments.

        Example usage:
        ```python
        PrintPlugin("hello")()
        ```
        """
        if self.name not in self._registry:
            raise ValueError(f"Plugin {self.name} is not registered.")

        return self._registry[self.name](*args, **kwargs)


if __name__ == "__main__":
    """
    python -m llamafactory.v1.utils.plugin
    """

    class PrintPlugin(BasePlugin):
        pass

    @PrintPlugin("hello").register
    def print_hello():
        print("Hello world!")

    PrintPlugin("hello")()
```
````
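One consequence of `_registry` living on `BasePlugin` itself is that every subclass shares a single dict by default, so two plugin families registering the same name would collide. A subclass can opt into its own namespace by shadowing the attribute. A minimal sketch (the `CollatorPlugin` name is hypothetical, not from this commit):

```python
from collections.abc import Callable

from llamafactory.v1.utils.plugin import BasePlugin


class CollatorPlugin(BasePlugin):  # hypothetical subclass
    _registry: dict[str, Callable] = {}  # own dict, isolated from BasePlugin._registry


@CollatorPlugin("identity").register
def identity_collator(batch):
    return batch


# lookup goes through CollatorPlugin's registry, not the shared one
assert CollatorPlugin("identity")(batch=[1, 2]) == [1, 2]
```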
src/llamafactory/v1/utils/pytest.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import contextmanager


@contextmanager
def dist_env(local_rank: int = 0, world_size: int = 1, master_port: int = 25595):
    """Set distributed environment variables."""
    env_vars = {
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(master_port),
        "RANK": str(local_rank),
        "LOCAL_RANK": str(local_rank),
        "WORLD_SIZE": str(world_size),
        "LOCAL_WORLD_SIZE": str(world_size),
    }
    os.environ.update(env_vars)
    try:
        yield
    finally:
        for key in env_vars.keys():
            os.environ.pop(key, None)
```
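A sketch of the intended use, assuming a single-process group over the `gloo` backend (the test itself is hypothetical, not from this commit):

```python
# Hypothetical test showing dist_env in action.
import torch.distributed as dist

from llamafactory.v1.utils.pytest import dist_env


def test_single_process_group():
    with dist_env(local_rank=0, world_size=1):
        # init_process_group's default env:// store reads MASTER_ADDR,
        # MASTER_PORT, RANK, and WORLD_SIZE from the variables set above
        dist.init_process_group(backend="gloo")
        assert dist.get_world_size() == 1
        dist.destroy_process_group()
```

Note that the `finally` block unconditionally pops the keys rather than restoring prior values, so any `MASTER_ADDR` etc. set before entering the context is lost on exit.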
src/llamafactory/v1/utils/types.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union


if TYPE_CHECKING:
    import datasets
    import numpy as np
    import torch
    import torch.utils.data
    import transformers
    from torch.distributed import ProcessGroup
    from torch.distributed.fsdp import FullyShardedDataParallel

    Tensor = torch.Tensor
    TensorLike = Union[int, float, list[int], list[float], np.ndarray, Tensor]
    TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]
    HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
    DataCollator = transformers.DataCollator
    DataLoader = torch.utils.data.DataLoader
    HFConfig = transformers.PretrainedConfig
    HFModel = transformers.PreTrainedModel
    DistModel = Union[torch.nn.parallel.DistributedDataParallel, FullyShardedDataParallel]
    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
    Optimizer = torch.optim.Optimizer
    Scheduler = torch.optim.lr_scheduler.LRScheduler
    ProcessGroup = ProcessGroup
else:
    Tensor = None
    TensorLike = None
    TorchDataset = None
    HFDataset = None
    DataCollator = None
    DataLoader = None
    HFConfig = None
    HFModel = None
    DistModel = None
    Processor = None
    Optimizer = None
    Scheduler = None
    ProcessGroup = None


class DatasetInfo(TypedDict, total=False):
    path: str
    """Local file path."""

    source: NotRequired[Literal["hf_hub", "ms_hub", "local"]]
    """Dataset source, default to "hf_hub"."""

    split: NotRequired[str]
    """Dataset split, default to "train"."""

    converter: NotRequired[str]
    """Dataset converter, default to None."""

    size: NotRequired[int]
    """Number of samples, default to all samples."""

    weight: NotRequired[float]
    """Dataset weight, default to 1.0."""

    streaming: NotRequired[bool]
    """Is streaming dataset, default to False."""


class DistributedConfig(TypedDict, total=False):
    mp_replicate_size: NotRequired[int]
    """Model parallel replicate size, default to 1."""

    mp_shard_size: NotRequired[int]
    """Model parallel shard size, default to world_size // mp_replicate_size."""

    dp_size: NotRequired[int]
    """Data parallel size, default to world_size // cp_size."""

    cp_size: NotRequired[int]
    """Context parallel size, default to 1."""

    timeout: NotRequired[int]
    """Timeout for distributed communication, default to 600."""


class Content(TypedDict):
    type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
    value: str


class Message(TypedDict):
    role: Literal["system", "user", "assistant", "tool"]
    content: list[Content]
    loss_weight: float


class SFTSample(TypedDict):
    messages: list[Message]
    extra_info: NotRequired[str]
    _dataset_name: NotRequired[str]


class DPOSample(TypedDict):
    chosen_messages: list[Message]
    rejected_messages: list[Message]
    extra_info: NotRequired[str]
    _dataset_name: NotRequired[str]


Sample = Union[SFTSample, DPOSample]
```
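Since `Content`, `Message`, and the sample types are plain `TypedDict`s, a conforming record is just a nested dict. A minimal sketch; the `0.0`/`1.0` values for `loss_weight` are an assumption about the intended semantics (mask the prompt, train on the response), not something this commit specifies:

```python
from llamafactory.v1.utils.types import SFTSample

# TypedDicts impose no runtime checks; the annotation only guides type checkers.
sample: SFTSample = {
    "messages": [
        {"role": "user", "content": [{"type": "text", "value": "Hi"}], "loss_weight": 0.0},
        {"role": "assistant", "content": [{"type": "text", "value": "Hello!"}], "loss_weight": 1.0},
    ],
}
```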
src/llamafactory/webui/chatter.py

```diff
@@ -16,7 +16,7 @@ import json
 import os
 from collections.abc import Generator
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 
 from transformers.utils import is_torch_npu_available
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
-        self.engine: Optional[BaseEngine] = None
+        self.engine: BaseEngine | None = None
 
         if not lazy_init:  # read arguments from command line
             super().__init__()
@@ -197,9 +197,9 @@ class WebChatModel(ChatModel):
         lang: str,
         system: str,
         tools: str,
-        image: Optional[Any],
-        video: Optional[Any],
-        audio: Optional[Any],
+        image: Any | None,
+        video: Any | None,
+        audio: Any | None,
         max_new_tokens: int,
         top_p: float,
         temperature: float,
```
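The remaining web UI diffs follow the same pattern: `Optional[X]` and `Union[X, Y]` annotations are rewritten to PEP 604 syntax. The two spellings denote the same type, but `X | None` evaluated at runtime requires Python >= 3.10 (annotation-only uses can fall back to `from __future__ import annotations`); a minimal sketch of the equivalence:

```python
from typing import Optional

def before(engine: Optional[str] = None) -> None: ...  # old spelling
def after(engine: str | None = None) -> None: ...  # PEP 604, Python >= 3.10

# both forms compare equal on Python 3.10+
assert Optional[str] == (str | None)
```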
src/llamafactory/webui/common.py

```diff
@@ -17,7 +17,7 @@ import os
 import signal
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Optional, Union
+from typing import Any
 
 from psutil import Process
 from yaml import safe_dump, safe_load
@@ -36,8 +36,8 @@ from ..extras.misc import use_modelscope, use_openmind
 logger = logging.get_logger(__name__)
 
-DEFAULT_CACHE_DIR = "cache"
-DEFAULT_CONFIG_DIR = "config"
+DEFAULT_CACHE_DIR = "llamaboard_cache"
+DEFAULT_CONFIG_DIR = "llamaboard_config"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user_config.yaml"
@@ -71,7 +71,7 @@ def _get_config_path() -> os.PathLike:
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
 
-def load_config() -> dict[str, Union[str, dict[str, Any]]]:
+def load_config() -> dict[str, str | dict[str, Any]]:
     r"""Load user config if exists."""
     try:
         with open(_get_config_path(), encoding="utf-8") as f:
@@ -81,7 +81,7 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:
 def save_config(
-    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+    lang: str, hub_name: str | None = None, model_name: str | None = None, model_path: str | None = None
 ) -> None:
     r"""Save user config."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
@@ -151,7 +151,7 @@ def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
     return {}
 
-def load_args(config_path: str) -> Optional[dict[str, Any]]:
+def load_args(config_path: str) -> dict[str, Any] | None:
     r"""Load the training configuration from config path."""
     try:
         with open(config_path, encoding="utf-8") as f:
```
src/llamafactory/webui/components/export.py

```diff
@@ -14,7 +14,7 @@
 import json
 from collections.abc import Generator
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
 
 from ...extras.constants import PEFT_METHODS
 from ...extras.misc import torch_gc
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
 GPTQ_BITS = ["8", "4", "3", "2"]
 
-def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
+def can_quantize(checkpoint_path: str | list[str]) -> "gr.Dropdown":
     if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
         return gr.Dropdown(value="none", interactive=False)
     else:
@@ -49,7 +49,7 @@ def save_model(
     model_name: str,
     model_path: str,
     finetuning_type: str,
-    checkpoint_path: Union[str, list[str]],
+    checkpoint_path: str | list[str],
     template: str,
     export_size: int,
     export_quantization_bit: str,
```
src/llamafactory/webui/control.py

```diff
@@ -14,7 +14,7 @@
 import json
 import os
-from typing import Any, Optional
+from typing import Any
 
 from transformers.trainer_utils import get_last_checkpoint
@@ -206,7 +206,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S
     return gr.Dropdown(choices=datasets)
 
-def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
+def list_output_dirs(model_name: str | None, finetuning_type: str, current_time: str) -> "gr.Dropdown":
     r"""List all the directories that can resume from.
 
     Inputs: top.model_name, top.finetuning_type, train.current_time
```
src/llamafactory/webui/locales.py

```diff
@@ -34,31 +34,41 @@ LOCALES = {
         "en": {
             "value": (
                 "<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "GitHub Page</a></center></h3>"
+                "GitHub Page</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
+                "Documentation</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+                "Blog</a></center></h3>"
             ),
         },
         "ru": {
             "value": (
                 "<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "страницу GitHub</a></center></h3>"
+                "страницу GitHub</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
+                "Документацию</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+                "Блог</a></center></h3>"
             ),
         },
         "zh": {
             "value": (
                 "<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "GitHub 主页</a></center></h3>"
+                "GitHub 主页</a> <a href='https://llamafactory.readthedocs.io/zh-cn/latest/' target='_blank'>"
+                "官方文档</a> <a href='https://blog.llamafactory.net/' target='_blank'>"
+                "博客</a></center></h3>"
             ),
         },
         "ko": {
             "value": (
                 "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "GitHub 페이지</a>를 방문하세요.</center></h3>"
+                "GitHub 페이지</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
+                "공식 문서</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+                "블로그</a>를 방문하세요.</center></h3>"
             ),
         },
         "ja": {
             "value": (
                 "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "GitHub ページ</a>にアクセスする</center></h3>"
+                "GitHub ページ</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
+                "ドキュメント</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+                "ブログ</a>にアクセスする</center></h3>"
             ),
         },
     },
```
src/llamafactory/webui/runner.py

```diff
@@ -17,7 +17,7 @@ import os
 from collections.abc import Generator
 from copy import deepcopy
 from subprocess import PIPE, Popen, TimeoutExpired
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 
 from transformers.utils import is_torch_npu_available
@@ -59,7 +59,7 @@ class Runner:
         self.manager = manager
         self.demo_mode = demo_mode
         """ Resume """
-        self.trainer: Optional[Popen] = None
+        self.trainer: Popen | None = None
         self.do_train = True
         self.running_data: dict[Component, Any] = None
         """ State """
```
tests/conftest.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LLaMA-Factory test configuration.

Contains shared fixtures, pytest configuration, and custom markers.
"""

import os
from typing import Optional

import pytest
import torch
import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch

from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import patch_valuehead_model


CURRENT_DEVICE = get_current_device().type


def pytest_configure(config: Config):
    """Register custom pytest markers."""
    config.addinivalue_line(
        "markers",
        "slow: marks tests as slow (deselect with '-m \"not slow\"' or set RUN_SLOW=1 to run)",
    )
    config.addinivalue_line(
        "markers",
        "runs_on: test requires specific device type, e.g., @pytest.mark.runs_on(['cuda'])",
    )
    config.addinivalue_line(
        "markers",
        "require_distributed(num_devices): allow multi-device execution (default: 2)",
    )


def _handle_runs_on(items: list[Item]):
    """Skip tests whose required device TYPE (cpu/cuda/npu) does not match the current device."""
    for item in items:
        marker = item.get_closest_marker("runs_on")
        if not marker:
            continue

        devices = marker.args[0]
        if isinstance(devices, str):
            devices = [devices]

        if CURRENT_DEVICE not in devices:
            item.add_marker(pytest.mark.skip(reason=f"test requires one of {devices} (current: {CURRENT_DEVICE})"))


def _handle_slow_tests(items: list[Item]):
    """Skip slow tests unless RUN_SLOW is enabled."""
    if not is_env_enabled("RUN_SLOW"):
        skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
        for item in items:
            if "slow" in item.keywords:
                item.add_marker(skip_slow)


def _get_visible_devices_env() -> Optional[str]:
    """Return device visibility env var name."""
    if CURRENT_DEVICE == "cuda":
        return "CUDA_VISIBLE_DEVICES"
    elif CURRENT_DEVICE == "npu":
        return "ASCEND_RT_VISIBLE_DEVICES"
    else:
        return None


def _handle_device_visibility(items: list[Item]):
    """Handle device visibility based on test markers."""
    env_key = _get_visible_devices_env()
    if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
        return

    # Parse visible devices
    visible_devices_env = os.environ.get(env_key)
    if visible_devices_env is None:
        available = get_device_count()
    else:
        visible_devices = [v for v in visible_devices_env.split(",") if v != ""]
        available = len(visible_devices)

    for item in items:
        marker = item.get_closest_marker("require_distributed")
        if not marker:
            continue

        required = marker.args[0] if marker.args else 2
        if available < required:
            item.add_marker(pytest.mark.skip(reason=f"test requires {required} devices, but only {available} visible"))


def pytest_collection_modifyitems(config: Config, items: list[Item]):
    """Modify test collection based on markers and environment."""
    # Handle version compatibility (from HEAD)
    if not is_transformers_version_greater_than("4.57.0"):
        skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
        for item in items:
            if "tests_v1" in str(item.fspath):
                item.add_marker(skip_bc)

    _handle_slow_tests(items)
    _handle_runs_on(items)
    _handle_device_visibility(items)


@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
    """Cleanup distributed state after each test."""
    yield
    if dist.is_initialized():
        dist.destroy_process_group()


@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
    """Set environment variables for distributed tests if specific devices are requested."""
    env_key = _get_visible_devices_env()
    if not env_key:
        return

    # Save old environment for logic checks, monkeypatch handles restoration
    old_value = os.environ.get(env_key)
    marker = request.node.get_closest_marker("require_distributed")
    if marker:  # distributed test
        required = marker.args[0] if marker.args else 2
        specific_devices = marker.args[1] if len(marker.args) > 1 else None
        if specific_devices:
            devices_str = ",".join(map(str, specific_devices))
        else:
            devices_str = ",".join(str(i) for i in range(required))

        monkeypatch.setenv(env_key, devices_str)
    else:  # non-distributed test
        if old_value:
            visible_devices = [v for v in old_value.split(",") if v != ""]
            monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
        else:
            monkeypatch.setenv(env_key, "0")

        if CURRENT_DEVICE == "cuda":
            monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
        elif CURRENT_DEVICE == "npu":
            monkeypatch.setattr(torch.npu, "device_count", lambda: 1)


@pytest.fixture
def fix_valuehead_cpu_loading():
    """Fix valuehead model loading."""
    patch_valuehead_model()
```
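To see how the three markers registered above interact, here is a hypothetical test module (not in this commit) combining them:

```python
# Hypothetical usage of the conftest markers.
import pytest


@pytest.mark.slow  # deselected unless RUN_SLOW=1
@pytest.mark.runs_on(["cuda", "npu"])  # skipped on cpu/mps hosts
@pytest.mark.require_distributed(2)  # skipped when fewer than 2 devices are visible
def test_two_device_step():
    # when this runs, _manage_distributed_env has set CUDA_VISIBLE_DEVICES
    # (or ASCEND_RT_VISIBLE_DEVICES) to "0,1" for the duration of the test
    ...
```

The test-file diffs below apply exactly this `runs_on` marker to the existing data-processing tests.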
tests/data/processor/test_feedback.py

```diff
@@ -42,6 +42,7 @@ TRAIN_ARGS = {
 }
 
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [16])
 def test_feedback_data(num_samples: int):
     train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
```
tests/data/processor/test_pairwise.py

```diff
@@ -51,6 +51,7 @@ def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str
     return new_messages
 
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [16])
 def test_pairwise_data(num_samples: int):
     train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
```
tests/data/processor/test_processor_utils.py

```diff
@@ -18,6 +18,7 @@ import pytest
 from llamafactory.data.processor.processor_utils import infer_seqlen
 
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize(
     "test_input,test_output",
     [
```
tests/data/processor/test_supervised.py

```diff
@@ -42,6 +42,7 @@ TRAIN_ARGS = {
 }
 
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [16])
 def test_supervised_single_turn(num_samples: int):
     train_dataset = load_dataset_module(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)["train_dataset"]
@@ -61,6 +62,7 @@ def test_supervised_single_turn(num_samples: int):
     assert train_dataset["input_ids"][index] == ref_input_ids
 
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [8])
 def test_supervised_multi_turn(num_samples: int):
     train_dataset = load_dataset_module(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)[
@@ -74,6 +76,7 @@ def test_supervised_multi_turn(num_samples: int):
     assert train_dataset["input_ids"][index] == ref_input_ids
 
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [4])
 def test_supervised_train_on_prompt(num_samples: int):
     train_dataset = load_dataset_module(
@@ -88,6 +91,7 @@ def test_supervised_train_on_prompt(num_samples: int):
     assert train_dataset["labels"][index] == ref_ids
 
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [4])
 def test_supervised_mask_history(num_samples: int):
     train_dataset = load_dataset_module(
```
tests/data/processor/test_unsupervised.py

```diff
@@ -42,9 +42,11 @@ TRAIN_ARGS = {
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
+    "report_to": "none",  # transformers compatibility
 }
 
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [16])
 def test_unsupervised_data(num_samples: int):
     train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
```
tests/data/test_collator.py

```diff
@@ -14,6 +14,7 @@
 import os
 
+import pytest
 import torch
 from PIL import Image
 from transformers import AutoConfig, AutoModelForVision2Seq
@@ -28,6 +29,7 @@ from llamafactory.model import load_tokenizer
 TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_base_collator():
     model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA3, "template": "default"})
     tokenizer_module = load_tokenizer(model_args)
@@ -71,6 +73,7 @@ def test_base_collator():
     assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_multimodal_collator():
     model_args, data_args, *_ = get_infer_args(
         {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
@@ -126,6 +129,7 @@ def test_multimodal_collator():
     assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
 
+@pytest.mark.runs_on(["cpu"])
 def test_4d_attention_mask():
     o = 0.0
     x = torch.finfo(torch.float16).min
```
tests/data/test_converter.py

```diff
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pytest
+
 from llamafactory.data import Role
 from llamafactory.data.converter import get_dataset_converter
 from llamafactory.data.parser import DatasetAttr
 from llamafactory.hparams import DataArguments
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_alpaca_converter():
     dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
     data_args = DataArguments()
@@ -38,6 +41,7 @@ def test_alpaca_converter():
     }
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_sharegpt_converter():
     dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
     data_args = DataArguments()
```
tests/data/test_formatter.py

```diff
@@ -15,6 +15,8 @@
 import json
 from datetime import datetime
 
+import pytest
+
 from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
@@ -36,16 +38,19 @@ TOOLS = [
 ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_empty_formatter():
     formatter = EmptyFormatter(slots=["\n"])
     assert formatter.apply() == ["\n"]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_string_formatter():
     formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
     assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps(FUNCTION)
@@ -55,6 +60,7 @@ def test_function_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_multi_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps([FUNCTION] * 2)
@@ -65,6 +71,7 @@ def test_multi_function_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_default_tool_formatter():
     formatter = ToolFormatter(tool_format="default")
     assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -83,12 +90,14 @@ def test_default_tool_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_default_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
     result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_default_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
     result = (
@@ -101,12 +110,14 @@ def test_default_multi_tool_extractor():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_glm4_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_glm4_tool_formatter():
     formatter = ToolFormatter(tool_format="glm4")
     assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -117,12 +128,14 @@ def test_glm4_tool_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_glm4_tool_extractor():
     formatter = ToolFormatter(tool_format="glm4")
     result = """test_tool\n{"foo": "bar", "size": 10}\n"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
     tool_calls = json.dumps(FUNCTION)
@@ -131,6 +144,7 @@ def test_llama3_function_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_multi_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
     tool_calls = json.dumps([FUNCTION] * 2)
@@ -141,6 +155,7 @@ def test_llama3_multi_function_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_tool_formatter():
     formatter = ToolFormatter(tool_format="llama3")
     date = datetime.now().strftime("%d %b %Y")
@@ -154,12 +169,14 @@ def test_llama3_tool_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_tool_extractor():
     formatter = ToolFormatter(tool_format="llama3")
     result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="llama3")
     result = (
@@ -172,6 +189,7 @@ def test_llama3_multi_tool_extractor():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_function_formatter():
     formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
     tool_calls = json.dumps(FUNCTION)
@@ -181,6 +199,7 @@ def test_mistral_function_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_multi_function_formatter():
     formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
     tool_calls = json.dumps([FUNCTION] * 2)
@@ -192,6 +211,7 @@ def test_mistral_multi_function_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_tool_formatter():
     formatter = ToolFormatter(tool_format="mistral")
     wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -200,12 +220,14 @@ def test_mistral_tool_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_tool_extractor():
     formatter = ToolFormatter(tool_format="mistral")
     result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="mistral")
     result = (
@@ -218,6 +240,7 @@ def test_mistral_multi_tool_extractor():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
     tool_calls = json.dumps(FUNCTION)
@@ -226,6 +249,7 @@ def test_qwen_function_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_multi_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
     tool_calls = json.dumps([FUNCTION] * 2)
@@ -236,6 +260,7 @@ def test_qwen_multi_function_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_tool_formatter():
     formatter = ToolFormatter(tool_format="qwen")
     wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -249,12 +274,14 @@ def test_qwen_tool_formatter():
     ]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_tool_extractor():
     formatter = ToolFormatter(tool_format="qwen")
     result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="qwen")
     result = (
```
tests/data/test_loader.py

```diff
@@ -14,6 +14,8 @@
 import os
 
+import pytest
+
 from llamafactory.train.test_utils import load_dataset_module
@@ -38,18 +40,21 @@ TRAIN_ARGS = {
 }
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_load_train_only():
     dataset_module = load_dataset_module(**TRAIN_ARGS)
     assert dataset_module.get("train_dataset") is not None
     assert dataset_module.get("eval_dataset") is None
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_load_val_size():
     dataset_module = load_dataset_module(val_size=0.1, **TRAIN_ARGS)
     assert dataset_module.get("train_dataset") is not None
    assert dataset_module.get("eval_dataset") is not None
 
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_load_eval_data():
     dataset_module = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS)
     assert dataset_module.get("train_dataset") is not None
```