Commit ca625f43 authored by shihm

uodata

parent 7164651d
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
import importlib.util
from functools import lru_cache
from typing import TYPE_CHECKING

from packaging import version

if TYPE_CHECKING:
    from packaging.version import Version


def _is_package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None


def _get_package_version(name: str) -> "Version":
    try:
        return version.parse(importlib.metadata.version(name))
    except Exception:
        return version.parse("0.0.0")


@lru_cache
def is_transformers_version_greater_than(content: str):
    return _get_package_version("transformers") >= version.parse(content)
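A short usage sketch of these helpers. The import path is an assumption inferred from the `llamafactory.v1.utils.plugin` module below; adjust it to wherever this file actually lives.

```python
# Hedged usage sketch; the module path is a guess based on this commit's layout.
from llamafactory.v1.utils.packages import (
    _get_package_version,
    _is_package_available,
    is_transformers_version_greater_than,
)

if _is_package_available("transformers"):
    # _get_package_version falls back to 0.0.0 when metadata is missing, so it never raises.
    print("transformers:", _get_package_version("transformers"))
    # @lru_cache memoizes the comparison, so repeated checks are cheap.
    print(">= 4.57.0:", is_transformers_version_greater_than("4.57.0"))
else:
    print("transformers is not installed")
```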
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable

from . import logging

logger = logging.get_logger(__name__)


class BasePlugin:
    """Base class for plugins.

    A plugin is a callable object that can be registered and called by name.
    """

    _registry: dict[str, Callable] = {}

    def __init__(self, name: str | None = None):
        """Initialize the plugin with a name.

        Args:
            name (str): The name of the plugin.
        """
        self.name = name

    @property
    def register(self):
        """Decorator to register a function as a plugin.

        Example usage:
        ```python
        @PrintPlugin("hello").register
        def print_hello():
            print("Hello world!")
        ```
        """
        if self.name is None:
            raise ValueError("Plugin name is not specified.")

        if self.name in self._registry:
            logger.warning_rank0_once(f"Plugin {self.name} is already registered.")

        def decorator(func: Callable) -> Callable:
            self._registry[self.name] = func
            return func

        return decorator

    def __call__(self, *args, **kwargs):
        """Call the registered function with the given arguments.

        Example usage:
        ```python
        PrintPlugin("hello")()
        ```
        """
        if self.name not in self._registry:
            raise ValueError(f"Plugin {self.name} is not registered.")

        return self._registry[self.name](*args, **kwargs)


if __name__ == "__main__":
    """
    python -m llamafactory.v1.utils.plugin
    """

    class PrintPlugin(BasePlugin):
        pass

    @PrintPlugin("hello").register
    def print_hello():
        print("Hello world!")

    PrintPlugin("hello")()
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager


@contextmanager
def dist_env(local_rank: int = 0, world_size: int = 1, master_port: int = 25595):
    """Set distributed environment variables."""
    env_vars = {
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(master_port),
        "RANK": str(local_rank),
        "LOCAL_RANK": str(local_rank),
        "WORLD_SIZE": str(world_size),
        "LOCAL_WORLD_SIZE": str(world_size),
    }
    os.environ.update(env_vars)
    try:
        yield
    finally:
        for key in env_vars.keys():
            os.environ.pop(key, None)
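A hedged usage sketch, assuming `dist_env` above is importable. The gloo backend keeps it CPU-only, so it runs as a single process; `init_process_group` with the default `env://` method reads exactly the variables the context manager sets.

```python
import torch.distributed as dist

# dist_env exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE on entry and
# pops them again on exit, so tests cannot leak state into each other.
with dist_env(local_rank=0, world_size=1):
    dist.init_process_group(backend="gloo")
    assert dist.get_rank() == 0 and dist.get_world_size() == 1
    dist.destroy_process_group()
```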
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union

if TYPE_CHECKING:
    import datasets
    import numpy as np
    import torch
    import torch.utils.data
    import transformers
    from torch.distributed import ProcessGroup
    from torch.distributed.fsdp import FullyShardedDataParallel

    Tensor = torch.Tensor
    TensorLike = Union[int, float, list[int], list[float], np.ndarray, Tensor]
    TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]
    HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
    DataCollator = transformers.DataCollator
    DataLoader = torch.utils.data.DataLoader
    HFConfig = transformers.PretrainedConfig
    HFModel = transformers.PreTrainedModel
    DistModel = Union[torch.nn.parallel.DistributedDataParallel, FullyShardedDataParallel]
    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
    Optimizer = torch.optim.Optimizer
    Scheduler = torch.optim.lr_scheduler.LRScheduler
    ProcessGroup = ProcessGroup
else:
    Tensor = None
    TensorLike = None
    TorchDataset = None
    HFDataset = None
    DataCollator = None
    DataLoader = None
    HFConfig = None
    HFModel = None
    DistModel = None
    Processor = None
    Optimizer = None
    Scheduler = None
    ProcessGroup = None


class DatasetInfo(TypedDict, total=False):
    path: str
    """Local file path."""
    source: NotRequired[Literal["hf_hub", "ms_hub", "local"]]
    """Dataset source, defaults to "hf_hub"."""
    split: NotRequired[str]
    """Dataset split, defaults to "train"."""
    converter: NotRequired[str]
    """Dataset converter, defaults to None."""
    size: NotRequired[int]
    """Number of samples, defaults to all samples."""
    weight: NotRequired[float]
    """Dataset weight, defaults to 1.0."""
    streaming: NotRequired[bool]
    """Whether to stream the dataset, defaults to False."""


class DistributedConfig(TypedDict, total=False):
    mp_replicate_size: NotRequired[int]
    """Model parallel replicate size, defaults to 1."""
    mp_shard_size: NotRequired[int]
    """Model parallel shard size, defaults to world_size // mp_replicate_size."""
    dp_size: NotRequired[int]
    """Data parallel size, defaults to world_size // cp_size."""
    cp_size: NotRequired[int]
    """Context parallel size, defaults to 1."""
    timeout: NotRequired[int]
    """Timeout for distributed communication, defaults to 600."""


class Content(TypedDict):
    type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
    value: str


class Message(TypedDict):
    role: Literal["system", "user", "assistant", "tool"]
    content: list[Content]
    loss_weight: float


class SFTSample(TypedDict):
    messages: list[Message]
    extra_info: NotRequired[str]
    _dataset_name: NotRequired[str]


class DPOSample(TypedDict):
    chosen_messages: list[Message]
    rejected_messages: list[Message]
    extra_info: NotRequired[str]
    _dataset_name: NotRequired[str]


Sample = Union[SFTSample, DPOSample]
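To make these schemas concrete, a minimal illustrative sample. The values are invented for this sketch, and the import path is a guess based on the commit's layout; `loss_weight` semantics are my reading of the field, not stated in the commit.

```python
# Import path is a guess; adjust to wherever this types module lives.
from llamafactory.v1.utils.types import DatasetInfo, SFTSample

# Illustrative DatasetInfo entry: every field is optional (total=False).
info: DatasetInfo = {"path": "data/demo.json", "source": "local", "split": "train", "weight": 1.0}

# Illustrative SFTSample: loss_weight presumably scales the token loss,
# so 0.0 would mask the prompt and 1.0 would train on the response.
sample: SFTSample = {
    "messages": [
        {
            "role": "user",
            "content": [{"type": "text", "value": "What is 2 + 2?"}],
            "loss_weight": 0.0,
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "value": "4"}],
            "loss_weight": 1.0,
        },
    ],
    "_dataset_name": "demo",  # hypothetical dataset name
}
```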
@@ -16,7 +16,7 @@ import json
import os
from collections.abc import Generator
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

from transformers.utils import is_torch_npu_available
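The webui hunks below all apply one mechanical pattern: `typing.Optional[X]` becomes `X | None` and `typing.Union[A, B]` becomes `A | B` (PEP 604 syntax, available since Python 3.10). A quick equivalence check:

```python
from typing import Optional, Union

# On Python 3.10+, the | operator builds unions equal to typing's forms.
assert int | None == Optional[int]
assert (int | str) == Union[int, str]
```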
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
    def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
        self.manager = manager
        self.demo_mode = demo_mode
-        self.engine: Optional[BaseEngine] = None
+        self.engine: BaseEngine | None = None

        if not lazy_init:  # read arguments from command line
            super().__init__()
@@ -197,9 +197,9 @@ class WebChatModel(ChatModel):
        lang: str,
        system: str,
        tools: str,
-        image: Optional[Any],
-        video: Optional[Any],
-        audio: Optional[Any],
+        image: Any | None,
+        video: Any | None,
+        audio: Any | None,
        max_new_tokens: int,
        top_p: float,
        temperature: float,
......
@@ -17,7 +17,7 @@ import os
import signal
from collections import defaultdict
from datetime import datetime
-from typing import Any, Optional, Union
+from typing import Any

from psutil import Process
from yaml import safe_dump, safe_load
@@ -36,8 +36,8 @@ from ..extras.misc import use_modelscope, use_openmind

logger = logging.get_logger(__name__)

-DEFAULT_CACHE_DIR = "cache"
-DEFAULT_CONFIG_DIR = "config"
+DEFAULT_CACHE_DIR = "llamaboard_cache"
+DEFAULT_CONFIG_DIR = "llamaboard_config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
@@ -71,7 +71,7 @@ def _get_config_path() -> os.PathLike:
    return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


-def load_config() -> dict[str, Union[str, dict[str, Any]]]:
+def load_config() -> dict[str, str | dict[str, Any]]:
    r"""Load user config if exists."""
    try:
        with open(_get_config_path(), encoding="utf-8") as f:
@@ -81,7 +81,7 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:
def save_config(
-    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+    lang: str, hub_name: str | None = None, model_name: str | None = None, model_path: str | None = None
) -> None:
    r"""Save user config."""
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
@@ -151,7 +151,7 @@ def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
        return {}


-def load_args(config_path: str) -> Optional[dict[str, Any]]:
+def load_args(config_path: str) -> dict[str, Any] | None:
    r"""Load the training configuration from config path."""
    try:
        with open(config_path, encoding="utf-8") as f:
......
@@ -14,7 +14,7 @@
import json
from collections.abc import Generator
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING

from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
GPTQ_BITS = ["8", "4", "3", "2"]


-def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
+def can_quantize(checkpoint_path: str | list[str]) -> "gr.Dropdown":
    if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
        return gr.Dropdown(value="none", interactive=False)
    else:
@@ -49,7 +49,7 @@ def save_model(
    model_name: str,
    model_path: str,
    finetuning_type: str,
-    checkpoint_path: Union[str, list[str]],
+    checkpoint_path: str | list[str],
    template: str,
    export_size: int,
    export_quantization_bit: str,
......
@@ -14,7 +14,7 @@
import json
import os
-from typing import Any, Optional
+from typing import Any

from transformers.trainer_utils import get_last_checkpoint
@@ -206,7 +206,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S
    return gr.Dropdown(choices=datasets)


-def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
+def list_output_dirs(model_name: str | None, finetuning_type: str, current_time: str) -> "gr.Dropdown":
    r"""List all the directories that can resume from.

    Inputs: top.model_name, top.finetuning_type, train.current_time
......
@@ -34,31 +34,41 @@ LOCALES = {
        "en": {
            "value": (
                "<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "GitHub Page</a></center></h3>"
+                "GitHub Page</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
+                "Documentation</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+                "Blog</a></center></h3>"
            ),
        },
        "ru": {
            "value": (
                "<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "страницу GitHub</a></center></h3>"
+                "страницу GitHub</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
+                "Документацию</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+                "Блог</a></center></h3>"
            ),
        },
        "zh": {
            "value": (
                "<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "GitHub 主页</a></center></h3>"
+                "GitHub 主页</a> <a href='https://llamafactory.readthedocs.io/zh-cn/latest/' target='_blank'>"
+                "官方文档</a> <a href='https://blog.llamafactory.net/' target='_blank'>"
+                "博客</a></center></h3>"
            ),
        },
        "ko": {
            "value": (
                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "GitHub 페이지</a>를 방문하세요.</center></h3>"
+                "GitHub 페이지</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
+                "공식 문서</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+                "블로그</a>를 방문하세요.</center></h3>"
            ),
        },
        "ja": {
            "value": (
                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
-                "GitHub ページ</a>にアクセスする</center></h3>"
+                "GitHub ページ</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
+                "ドキュメント</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+                "ブログ</a>にアクセスする</center></h3>"
            ),
        },
    },
......
@@ -17,7 +17,7 @@ import os
from collections.abc import Generator
from copy import deepcopy
from subprocess import PIPE, Popen, TimeoutExpired
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

from transformers.utils import is_torch_npu_available
@@ -59,7 +59,7 @@ class Runner:
        self.manager = manager
        self.demo_mode = demo_mode
        """ Resume """
-        self.trainer: Optional[Popen] = None
+        self.trainer: Popen | None = None
        self.do_train = True
        self.running_data: dict[Component, Any] = None
        """ State """
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA-Factory test configuration.
Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import pytest
import torch
import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import patch_valuehead_model
CURRENT_DEVICE = get_current_device().type
def pytest_configure(config: Config):
"""Register custom pytest markers."""
config.addinivalue_line(
"markers",
"slow: marks tests as slow (deselect with '-m \"not slow\"' or set RUN_SLOW=1 to run)",
)
config.addinivalue_line(
"markers",
"runs_on: test requires specific device type, e.g., @pytest.mark.runs_on(['cuda'])",
)
config.addinivalue_line(
"markers",
"require_distributed(num_devices): allow multi-device execution (default: 2)",
)
def _handle_runs_on(items: list[Item]):
"""Skip tests on specified device TYPES (cpu/cuda/npu)."""
for item in items:
marker = item.get_closest_marker("runs_on")
if not marker:
continue
devices = marker.args[0]
if isinstance(devices, str):
devices = [devices]
if CURRENT_DEVICE not in devices:
item.add_marker(pytest.mark.skip(reason=f"test requires one of {devices} (current: {CURRENT_DEVICE})"))
def _handle_slow_tests(items: list[Item]):
"""Skip slow tests unless RUN_SLOW is enabled."""
if not is_env_enabled("RUN_SLOW"):
skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
elif CURRENT_DEVICE == "npu":
return "ASCEND_RT_VISIBLE_DEVICES"
else:
return None
def _handle_device_visibility(items: list[Item]):
"""Handle device visibility based on test markers."""
env_key = _get_visible_devices_env()
if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
return
# Parse visible devices
visible_devices_env = os.environ.get(env_key)
if visible_devices_env is None:
available = get_device_count()
else:
visible_devices = [v for v in visible_devices_env.split(",") if v != ""]
available = len(visible_devices)
for item in items:
marker = item.get_closest_marker("require_distributed")
if not marker:
continue
required = marker.args[0] if marker.args else 2
if available < required:
item.add_marker(pytest.mark.skip(reason=f"test requires {required} devices, but only {available} visible"))
def pytest_collection_modifyitems(config: Config, items: list[Item]):
"""Modify test collection based on markers and environment."""
# Handle version compatibility (from HEAD)
if not is_transformers_version_greater_than("4.57.0"):
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
for item in items:
if "tests_v1" in str(item.fspath):
item.add_marker(skip_bc)
_handle_slow_tests(items)
_handle_runs_on(items)
_handle_device_visibility(items)
@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
"""Cleanup distributed state after each test."""
yield
if dist.is_initialized():
dist.destroy_process_group()
@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
env_key = _get_visible_devices_env()
if not env_key:
return
# Save old environment for logic checks, monkeypatch handles restoration
old_value = os.environ.get(env_key)
marker = request.node.get_closest_marker("require_distributed")
if marker: # distributed test
required = marker.args[0] if marker.args else 2
specific_devices = marker.args[1] if len(marker.args) > 1 else None
if specific_devices:
devices_str = ",".join(map(str, specific_devices))
else:
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
@pytest.fixture
def fix_valuehead_cpu_loading():
"""Fix valuehead model loading."""
patch_valuehead_model()
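A hedged sketch of a test module that exercises these markers; the test names and bodies are illustrative, not from the commit.

```python
import pytest
import torch.distributed as dist


@pytest.mark.runs_on(["cpu", "mps"])  # skipped when CURRENT_DEVICE is cuda/npu
def test_cpu_only_math():
    assert 2 + 2 == 4


@pytest.mark.slow  # deselected unless RUN_SLOW=1
@pytest.mark.require_distributed(2)  # skipped if fewer than 2 devices are visible
def test_two_device_env():
    # _manage_distributed_env has already pinned CUDA_VISIBLE_DEVICES
    # (or the NPU equivalent) to "0,1" before this test runs.
    assert dist.is_available()
```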
@@ -42,6 +42,7 @@ TRAIN_ARGS = {
}


+@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_feedback_data(num_samples: int):
    train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
......
@@ -51,6 +51,7 @@ def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str
    return new_messages


+@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_pairwise_data(num_samples: int):
    train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
......
@@ -18,6 +18,7 @@ import pytest
from llamafactory.data.processor.processor_utils import infer_seqlen


+@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
    "test_input,test_output",
    [
......
@@ -42,6 +42,7 @@ TRAIN_ARGS = {
}


+@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_supervised_single_turn(num_samples: int):
    train_dataset = load_dataset_module(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)["train_dataset"]
@@ -61,6 +62,7 @@ def test_supervised_single_turn(num_samples: int):
    assert train_dataset["input_ids"][index] == ref_input_ids


+@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [8])
def test_supervised_multi_turn(num_samples: int):
    train_dataset = load_dataset_module(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)[
@@ -74,6 +76,7 @@ def test_supervised_multi_turn(num_samples: int):
    assert train_dataset["input_ids"][index] == ref_input_ids


+@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [4])
def test_supervised_train_on_prompt(num_samples: int):
    train_dataset = load_dataset_module(
@@ -88,6 +91,7 @@ def test_supervised_train_on_prompt(num_samples: int):
    assert train_dataset["labels"][index] == ref_ids


+@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [4])
def test_supervised_mask_history(num_samples: int):
    train_dataset = load_dataset_module(
......
@@ -42,9 +42,11 @@ TRAIN_ARGS = {
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
+    "report_to": "none",  # transformers compatibility
}


+@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_unsupervised_data(num_samples: int):
    train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
......
@@ -14,6 +14,7 @@
import os

+import pytest
import torch
from PIL import Image
from transformers import AutoConfig, AutoModelForVision2Seq
@@ -28,6 +29,7 @@ from llamafactory.model import load_tokenizer
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")


+@pytest.mark.runs_on(["cpu", "mps"])
def test_base_collator():
    model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA3, "template": "default"})
    tokenizer_module = load_tokenizer(model_args)
@@ -71,6 +73,7 @@ def test_base_collator():
    assert batch_input[k].eq(torch.tensor(expected_input[k])).all()


+@pytest.mark.runs_on(["cpu", "mps"])
def test_multimodal_collator():
    model_args, data_args, *_ = get_infer_args(
        {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
@@ -126,6 +129,7 @@ def test_multimodal_collator():
    assert batch_input[k].eq(torch.tensor(expected_input[k])).all()


+@pytest.mark.runs_on(["cpu"])
def test_4d_attention_mask():
    o = 0.0
    x = torch.finfo(torch.float16).min
......
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import pytest
+
from llamafactory.data import Role
from llamafactory.data.converter import get_dataset_converter
from llamafactory.data.parser import DatasetAttr
from llamafactory.hparams import DataArguments


+@pytest.mark.runs_on(["cpu", "mps"])
def test_alpaca_converter():
    dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
    data_args = DataArguments()
@@ -38,6 +41,7 @@ def test_alpaca_converter():
    }


+@pytest.mark.runs_on(["cpu", "mps"])
def test_sharegpt_converter():
    dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
    data_args = DataArguments()
......
@@ -15,6 +15,8 @@
import json
from datetime import datetime

+import pytest
+
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
@@ -36,16 +38,19 @@ TOOLS = [
]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_empty_formatter():
    formatter = EmptyFormatter(slots=["\n"])
    assert formatter.apply() == ["\n"]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_string_formatter():
    formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
    assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
    tool_calls = json.dumps(FUNCTION)
@@ -55,6 +60,7 @@ def test_function_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_multi_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
    tool_calls = json.dumps([FUNCTION] * 2)
@@ -65,6 +71,7 @@ def test_multi_function_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_default_tool_formatter():
    formatter = ToolFormatter(tool_format="default")
    assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -83,12 +90,14 @@ def test_default_tool_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_default_tool_extractor():
    formatter = ToolFormatter(tool_format="default")
    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_default_multi_tool_extractor():
    formatter = ToolFormatter(tool_format="default")
    result = (
@@ -101,12 +110,14 @@ def test_default_multi_tool_extractor():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
    tool_calls = json.dumps(FUNCTION)
    assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_tool_formatter():
    formatter = ToolFormatter(tool_format="glm4")
    assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -117,12 +128,14 @@ def test_glm4_tool_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_tool_extractor():
    formatter = ToolFormatter(tool_format="glm4")
    result = """test_tool\n{"foo": "bar", "size": 10}\n"""
    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
    tool_calls = json.dumps(FUNCTION)
@@ -131,6 +144,7 @@ def test_llama3_function_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_multi_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
    tool_calls = json.dumps([FUNCTION] * 2)
@@ -141,6 +155,7 @@ def test_llama3_multi_function_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_tool_formatter():
    formatter = ToolFormatter(tool_format="llama3")
    date = datetime.now().strftime("%d %b %Y")
@@ -154,12 +169,14 @@ def test_llama3_tool_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_tool_extractor():
    formatter = ToolFormatter(tool_format="llama3")
    result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_multi_tool_extractor():
    formatter = ToolFormatter(tool_format="llama3")
    result = (
@@ -172,6 +189,7 @@ def test_llama3_multi_tool_extractor():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_function_formatter():
    formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
    tool_calls = json.dumps(FUNCTION)
@@ -181,6 +199,7 @@ def test_mistral_function_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_multi_function_formatter():
    formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
    tool_calls = json.dumps([FUNCTION] * 2)
@@ -192,6 +211,7 @@ def test_mistral_multi_function_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_tool_formatter():
    formatter = ToolFormatter(tool_format="mistral")
    wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -200,12 +220,14 @@ def test_mistral_tool_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_tool_extractor():
    formatter = ToolFormatter(tool_format="mistral")
    result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_multi_tool_extractor():
    formatter = ToolFormatter(tool_format="mistral")
    result = (
@@ -218,6 +240,7 @@ def test_mistral_multi_tool_extractor():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
    tool_calls = json.dumps(FUNCTION)
@@ -226,6 +249,7 @@ def test_qwen_function_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_multi_function_formatter():
    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
    tool_calls = json.dumps([FUNCTION] * 2)
@@ -236,6 +260,7 @@ def test_qwen_multi_function_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_tool_formatter():
    formatter = ToolFormatter(tool_format="qwen")
    wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -249,12 +274,14 @@ def test_qwen_tool_formatter():
    ]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_tool_extractor():
    formatter = ToolFormatter(tool_format="qwen")
    result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


+@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_multi_tool_extractor():
    formatter = ToolFormatter(tool_format="qwen")
    result = (
......
@@ -14,6 +14,8 @@
import os

+import pytest
+
from llamafactory.train.test_utils import load_dataset_module
@@ -38,18 +40,21 @@ TRAIN_ARGS = {
}


+@pytest.mark.runs_on(["cpu", "mps"])
def test_load_train_only():
    dataset_module = load_dataset_module(**TRAIN_ARGS)
    assert dataset_module.get("train_dataset") is not None
    assert dataset_module.get("eval_dataset") is None


+@pytest.mark.runs_on(["cpu", "mps"])
def test_load_val_size():
    dataset_module = load_dataset_module(val_size=0.1, **TRAIN_ARGS)
    assert dataset_module.get("train_dataset") is not None
    assert dataset_module.get("eval_dataset") is not None


+@pytest.mark.runs_on(["cpu", "mps"])
def test_load_eval_data():
    dataset_module = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS)
    assert dataset_module.get("train_dataset") is not None
......