Commit ca625f43 authored by shihm's avatar shihm
Browse files

uodata

parent 7164651d
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/utils/dist_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions used by the distributed interface.
Including:
- Environment info (rank, world_size, local_rank, etc.)
- Accelerator info (device type, device count, etc.)
- Collective communication operations (all_gather, all_reduce, broadcast)
- Synchronize processes and ensure main-process-first execution order
"""
import os
from collections.abc import Callable
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache, wraps
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
from ..utils.types import ProcessGroup, Tensor, TensorLike
@unique
class DeviceType(str, Enum):
CPU = "cpu"
CUDA = "cuda"
META = "meta"
MPS = "mps"
NPU = "npu"
XPU = "xpu"
@unique
class ReduceOp(str, Enum):
SUM = "sum"
MEAN = "mean"
MAX = "max"
MIN = "min"
def requires_accelerator(fn):
"""Decorator to check if torch.accelerator is available.
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
"""
@wraps(fn)
def wrapper(*args, **kwargs):
if not hasattr(torch, "accelerator"):
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
return fn(*args, **kwargs)
return wrapper
def is_distributed() -> bool:
"""Check if distributed environment is available."""
return os.getenv("RANK") is not None
def get_rank() -> int:
"""Get rank."""
return int(os.getenv("RANK", "0"))
def get_world_size() -> int:
"""Get world size."""
return int(os.getenv("WORLD_SIZE", "1"))
def get_local_rank() -> int:
"""Get local rank."""
return int(os.getenv("LOCAL_RANK", "0"))
def get_local_world_size() -> int:
"""Get local world size."""
return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@lru_cache
@requires_accelerator
def get_current_accelerator(check_available: bool = True) -> torch.device:
"""Get current accelerator."""
accelerator = torch.accelerator.current_accelerator(check_available=check_available)
return accelerator or torch.device(DeviceType.CPU.value)
@lru_cache
@requires_accelerator
def get_device_count() -> int:
"""Get the number of available devices."""
return torch.accelerator.device_count()
@requires_accelerator
def synchronize() -> None:
"""Synchronize all processes."""
torch.accelerator.synchronize()
@requires_accelerator
def set_device() -> None:
"""Set current accelerator."""
torch.accelerator.set_device_index(get_local_rank())
def is_torch_cuda_available():
"""Check if CUDA is available."""
return get_current_accelerator().type == DeviceType.CUDA
def is_torch_mps_available():
"""Check if MPS is available."""
return get_current_accelerator().type == DeviceType.MPS
def is_torch_npu_available():
"""Check if NPU is available."""
return get_current_accelerator().type == DeviceType.NPU
def is_torch_xpu_available():
"""Check if XPU is available."""
return get_current_accelerator().type == DeviceType.XPU
def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) -> TensorLike:
"""Operate tensorlike data on current accelerator."""
device = get_current_accelerator()
is_tensor = isinstance(data, torch.Tensor)
is_ndarray = isinstance(data, np.ndarray)
if is_tensor:
orig_device = data.device
data = data.to(device=device)
elif is_ndarray:
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
else:
data = torch.tensor(data, dtype=torch.float, device=device)
result = fn(data, **kwargs)
if is_tensor:
return result.to(orig_device)
elif is_ndarray:
return result.cpu().numpy()
elif result.numel() == 1:
return result.item()
else:
return result.tolist()
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
"""Gathers the tensor from all ranks and stacks them at the first dim."""
world_size = get_world_size()
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
return output_tensor.view(-1, *tensor.size())
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor:
"""Performs all reduce in the given process group."""
reduce_ops = {
ReduceOp.MEAN: dist.ReduceOp.SUM,
ReduceOp.SUM: dist.ReduceOp.SUM,
ReduceOp.MAX: dist.ReduceOp.MAX,
ReduceOp.MIN: dist.ReduceOp.MIN,
}
dist.all_reduce(tensor, op=reduce_ops[op], group=group)
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
tensor /= dist.get_world_size(group=group)
return tensor
def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor:
"""Broadcasts the tensor from the src process to all other processes."""
dist.broadcast(tensor, src=src, group=group)
return tensor
@contextmanager
def main_process_first(local_only: bool = True) -> None:
"""A context manager for torch distributed environment to do something on the main process firstly."""
if get_world_size() > 1:
is_main_process = get_local_rank() == 0 if local_only else get_rank() == 0
try:
if not is_main_process:
dist.barrier()
yield
finally:
if is_main_process:
dist.barrier()
else:
yield
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/distributed/parallel_state.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A unified interface for model parallelism and data parallelism.
Supports model parallelism types:
- mp_replicate: Replicate model across multiple devices.
- mp_shard: Shard model across multiple devices.
And data parallelism types:
- dp: Data parallelism.
- cp: Context parallelism.
"""
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import Any, Optional
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from . import helper
class Dim(str, Enum):
"""Dimension names."""
MP_REPLICATE = "mp_replicate"
MP_SHARD = "mp_shard"
DP = "dp"
CP = "cp"
@dataclass
class DistributedStrategy:
"""Distributed strategy."""
mp_replicate_size: int = 1
"""Model parallel replicate size, default to 1."""
mp_shard_size: int | None = None
"""Model parallel shard size, default to world_size // mp_replicate_size."""
dp_size: int | None = None
"""Data parallel size, default to world_size // cp_size."""
cp_size: int = 1
"""Context parallel size, default to 1."""
def __post_init__(self) -> None:
if not helper.is_distributed():
self.mp_shard_size = 1
elif self.mp_shard_size is None:
self.mp_shard_size = helper.get_world_size() // self.mp_replicate_size
elif self.mp_replicate_size * self.mp_shard_size != helper.get_world_size():
raise ValueError(
f"mp_replicate_size * mp_shard_size must equal to world_size, "
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {helper.get_world_size()}."
)
if not helper.is_distributed():
self.dp_size = 1
elif self.dp_size is None:
self.dp_size = helper.get_world_size() // self.cp_size
elif self.dp_size * self.cp_size != helper.get_world_size():
raise ValueError(
f"dp_size * cp_size must equal to world_size, "
f"got {self.dp_size} * {self.cp_size} != {helper.get_world_size()}."
)
@property
def model_mesh_shape(self) -> tuple[int, int]:
"""Model parallel mesh shape."""
return (self.mp_replicate_size, self.mp_shard_size)
@property
def model_mesh_dim_names(self) -> tuple[str, str]:
"""Model parallel mesh dimension names."""
return (Dim.MP_REPLICATE.value, Dim.MP_SHARD.value)
@property
def data_mesh_shape(self) -> tuple[int, int]:
"""Data parallel mesh shape."""
return (self.dp_size, self.cp_size)
@property
def data_mesh_dim_names(self) -> tuple[str, str]:
"""Data parallel mesh dimension names."""
return (Dim.DP.value, Dim.CP.value)
class DistributedInterface:
"""Distributed interface."""
_instance: Optional["DistributedInterface"] = None
_initialized: bool = False
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
"""Singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, config: DistributedConfig | None = None) -> None:
if self._initialized:
return
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()
self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size()
self.current_accelerator = helper.get_current_accelerator()
self.device_count = helper.get_device_count()
if config is None:
self.strategy = DistributedStrategy()
timeout = 18000
else:
self.strategy = DistributedStrategy(
mp_replicate_size=config.get("mp_replicate_size", 1),
mp_shard_size=config.get("mp_shard_size", None),
dp_size=config.get("dp_size", None),
cp_size=config.get("cp_size", 1),
)
timeout = config.get("timeout", 18000)
if self._is_distributed:
helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
mesh_shape=self.strategy.model_mesh_shape,
mesh_dim_names=self.strategy.model_mesh_dim_names,
)
self.data_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
mesh_shape=self.strategy.data_mesh_shape,
mesh_dim_names=self.strategy.data_mesh_dim_names,
)
else:
self.model_device_mesh = None
self.data_device_mesh = None
self._initialized = True
def __str__(self) -> str:
return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
"""Get device mesh for specified dimension."""
if dim is None:
raise ValueError("dim must be specified.")
elif self.model_device_mesh is None:
return None
elif dim in self.strategy.data_mesh_dim_names:
return self.data_device_mesh[dim.value]
else:
return self.model_device_mesh[dim.value]
def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
"""Get process group for specified dimension."""
if self.model_device_mesh is None or dim is None:
return None
else:
return self.get_device_mesh(dim).get_group()
def get_rank(self, dim: Dim | None = None) -> int:
"""Get parallel rank for specified dimension."""
if self.model_device_mesh is None:
return 0
elif dim is None:
return self._rank
else:
return self.get_device_mesh(dim).get_local_rank()
def get_world_size(self, dim: Dim | None = None) -> int:
"""Get parallel size for specified dimension."""
if self.model_device_mesh is None:
return 1
elif dim is None:
return self._world_size
else:
return self.get_device_mesh(dim).size()
def get_local_rank(self) -> int:
"""Get parallel local rank."""
return self._local_rank
def get_local_world_size(self) -> int:
"""Get parallel local world size."""
return self._local_world_size
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
"""Gather tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
else:
return data
def all_reduce(
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
) -> TensorLike:
"""Reduce tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
else:
return data
def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
"""Broadcast tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
else:
return data
def sync(self) -> None:
"""Synchronize all processes."""
helper.synchronize()
def barrier(self) -> None:
"""Barrier all processes."""
barrier()
def destroy(self) -> None:
"""Destroy all processes."""
destroy_process_group()
if __name__ == "__main__":
print(DistributedInterface(DistributedStrategy()))
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from pathlib import Path
from typing import Any
from omegaconf import OmegaConf
from transformers import HfArgumentParser
from ...extras.misc import is_env_enabled
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments
InputArgument = dict[str, Any] | list[str] | None
def validate_args(
data_args: DataArguments,
model_args: ModelArguments,
training_args: TrainingArguments,
sample_args: SampleArguments,
):
"""Validate arguments."""
if (
model_args.quant_config is not None
and training_args.dist_config is not None
and training_args.dist_config.name == "deepspeed"
):
raise ValueError("Quantization is not supported with deepspeed backend.")
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
"""Parse arguments from command line or config file."""
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
if args is None:
if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = OmegaConf.create(json.load(Path(sys.argv[1]).absolute()))
args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
else: # list of strings
args = sys.argv[1:]
if isinstance(args, dict):
(*parsed_args,) = parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
else:
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args, return_remaining_strings=True)
if unknown_args and not allow_extra_keys:
print(parser.format_help())
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
validate_args(*parsed_args)
return tuple(parsed_args)
if __name__ == "__main__":
print(get_args())
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/training_args.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from enum import Enum, unique
class PluginConfig(dict):
"""Dictionary that allows attribute access."""
@property
def name(self) -> str:
"""Plugin name."""
if "name" not in self:
raise ValueError("Plugin configuration must have a 'name' field.")
return self["name"]
PluginArgument = PluginConfig | dict | str | None
@unique
class ModelClass(str, Enum):
"""Auto class for model config."""
LLM = "llm"
CLS = "cls"
OTHER = "other"
@unique
class SampleBackend(str, Enum):
HF = "hf"
VLLM = "vllm"
def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.
Args:
data: The string or dictionary to convert.
Returns:
The converted dictionary.
"""
for key, value in data.items():
if isinstance(value, dict):
data[key] = _convert_str_dict(value)
elif isinstance(value, str):
if value.lower() in ("true", "false"):
data[key] = value.lower() == "true"
elif value.isdigit():
data[key] = int(value)
elif value.replace(".", "", 1).isdigit():
data[key] = float(value)
return data
def get_plugin_config(config: PluginArgument) -> PluginConfig | None:
"""Get the plugin configuration from the argument value.
Args:
config: The argument value to get the plugin configuration from.
Returns:
The plugin configuration.
"""
if config is None:
return None
if isinstance(config, str) and config.startswith("{"):
config = json.loads(config)
config = _convert_str_dict(config)
if "name" not in config:
raise ValueError("Plugin configuration must have a 'name' field.")
return PluginConfig(config)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@dataclass
class DataArguments:
dataset: str | None = field(
default=None,
metadata={"help": "Path to the dataset."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Cutoff length for the dataset."},
)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
model: str = field(
metadata={"help": "Path to the model or model identifier from Hugging Face."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
use_fast_processor: bool = field(
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
)
peft_config: PluginConfig | None = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
)
kernel_config: PluginConfig | None = field(
default=None,
metadata={"help": "Kernel configuration for the model."},
)
quant_config: PluginConfig | None = field(
default=None,
metadata={"help": "Quantization configuration for the model."},
)
def __post_init__(self) -> None:
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
self.quant_config = get_plugin_config(self.quant_config)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from .arg_utils import SampleBackend
@dataclass
class SampleArguments:
sample_backend: SampleBackend = field(
default=SampleBackend.HF,
metadata={"help": "Sampling backend, default to 'hf'."},
)
max_new_tokens: int = field(
default=128,
metadata={"help": "Maximum number of new tokens to generate."},
)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from uuid import uuid4
from .arg_utils import PluginConfig, get_plugin_config
@dataclass
class TrainingArguments:
output_dir: str = field(
default=os.path.join("outputs", str(uuid4())),
metadata={"help": "Path to the output directory."},
)
micro_batch_size: int = field(
default=1,
metadata={"help": "Micro batch size for training."},
)
global_batch_size: int = field(
default=1,
metadata={"help": "Global batch size for training."},
)
learning_rate: float = field(
default=1e-4,
metadata={"help": "Learning rate for training."},
)
bf16: bool = field(
default=False,
metadata={"help": "Use bf16 for training."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)
def __post_init__(self) -> None:
self.dist_config = get_plugin_config(self.dist_config)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import AsyncGenerator
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.inference_engine import HuggingFaceEngine
from .utils.rendering import Renderer
class BaseSampler:
"""Base sampler.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
if args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(args, model_args, model, renderer)
else:
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
async for token in self.engine.generate(messages, tools):
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return await self.engine.batch_infer(dataset)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of trainer.
Init Phase:
1. Init dataloader.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
5. Init scheduler.
Train Phase:
1. Train Loop
"""
from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
class BaseTrainer:
def __init__(
self,
args: TrainingArguments,
model: HFModel,
processor: Processor,
dataset: TorchDataset,
) -> None:
self.args = args
self.model = model
self.processor = processor
self.dataset = dataset
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None
def init_model_and_optimizer(self) -> None:
pass
def create_dataloader(self) -> None:
pass
def fit(self) -> None:
pass
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from ..config.sample_args import SampleArguments, SampleBackend
from .model_loader import ModelLoader
class BaseEngine(ABC):
@abstractmethod
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
@abstractmethod
async def generate(self):
pass
@abstractmethod
async def batch_infer(self):
pass
class HuggingFaceEngine(BaseEngine):
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
self.args = sample_args
class ChatSampler:
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
if sample_args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(model_loader, sample_args)
else:
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of data engine.
Init Data engine:
1. Parse dataset info from arguments.
2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary).
Get Data Sample:
1. Get sample from data index.
2. Convert sample to standard format.
3. Return sample.
"""
import os
from collections.abc import Iterable
from typing import Any
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch.utils.data import Dataset
from ..config.data_args import DataArguments
from ..utils.types import DatasetInfo, HFDataset, Sample
class DataEngine(Dataset):
"""Data engine.
Args:
data_args: Data arguments.
"""
def __init__(self, data_args: DataArguments) -> None:
self.args = data_args
"""Data arguments."""
self.datasets: dict[str, HFDataset] = {}
"""Dict of (dataset_name, dataset)"""
self.dataset_infos: dict[str, DatasetInfo] = {}
"""Dict of (dataset_name, dataset_info)"""
self.data_index: list[tuple[str, int]] = []
"""List of (dataset_name, sample_index)"""
self.streaming: bool = False
"""Whether dataset is streaming."""
self._get_dataset_info()
self._load_dataset()
self._build_data_index()
def _get_dataset_info(self) -> None:
"""Get dataset info from data arguments."""
if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file
self.dataset_infos = OmegaConf.load(self.args.dataset)
elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
repo_id, filename = os.path.split(self.args.dataset)
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
self.dataset_infos = OmegaConf.load(filepath)
elif os.path.exists(self.args.dataset): # local file(s)
self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"path": self.args.dataset}}
def _load_dataset(self) -> None:
"""Load datasets according to dataset info."""
for dataset_name, dataset_info in self.dataset_infos.items():
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
self.streaming |= streaming
if dataset_info.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
self.datasets[dataset_name] = DataLoaderPlugin(dataset_info["source"]).load(dataset_info)
def _build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
streaming = self.dataset_infos[dataset_name].get("streaming", False)
if streaming:
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight: # data index plugin
from ..plugins.data_plugins.loader import DataIndexPlugin
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
"""Convert dataset sample.
Args:
raw_sample (dict[str, Any]): Raw dataset sample.
dataset_name (str): Dataset name.
Returns:
Sample: Dataset sample.
"""
converter = self.dataset_infos[dataset_name].get("converter")
if converter is not None:
from ..plugins.data_plugins.converter import DataConverterPlugin
return {"_dataset_name": dataset_name, **DataConverterPlugin(converter)(raw_sample)}
else:
return {"_dataset_name": dataset_name, **raw_sample}
def __len__(self) -> int:
"""Get dataset length.
Returns:
int: Dataset length.
"""
if self.streaming:
return -1
else:
return len(self.data_index)
def __getitem__(self, index: int | Any) -> Sample | list[Sample]:
"""Get dataset item.
Args:
index (int): Dataset index.
Returns:
Sample: Dataset item.
"""
if self.streaming:
raise ValueError("Streaming dataset does not support index access.")
if isinstance(index, int):
dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else: # data selector plugin
from ..plugins.data_plugins.loader import DataSelectorPlugin
selected_index = DataSelectorPlugin().select(self.data_index, index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
for dataset_name, sample_index in selected_index
]
else:
dataset_name, sample_index = selected_index
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
def __iter__(self) -> Iterable[Sample]:
"""Get dataset iterator.
Returns:
Iterable[Sample]: Dataset iterator.
"""
# NOTE: hf iterable dataset uses worker ids while map dataset does not
# NOTE: add worker id and shuffle to the map dataset
# https://github.com/huggingface/datasets/blob/4.0.0/src/datasets/iterable_dataset.py#L2214
raise NotImplementedError()
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
"""
from ..config.arg_parser import get_args
data_args, *_ = get_args()
data_engine = DataEngine(data_args=data_args)
print(data_engine[0])
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model engine.
How to use:
model_engine = ModelEngine(model_args, is_train=True)
model_engine.processor: Get the tokenizer or multi-modal processor.
model_engine.renderer: Get the renderer.
model_engine.model_config: Get the model configuration.
model_engine.model: Get the HF model.
Init workflow:
1. Init processor.
2. Init render.
2. Init model config.
3. Init model.
4. Init adapter.
"""
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoProcessor
from ..accelerator.helper import DeviceType
from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
from .utils.rendering import Renderer
logger = logging.get_logger(__name__)
class ModelEngine:
"""Model engine.
Args:
model_args: Model arguments.
is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
self.args = model_args
"""Model arguments."""
self.is_train = is_train
"""Whether to train the model."""
self.processor = self._init_processor()
"""Tokenizer or multi-modal processor."""
self.renderer = Renderer(self.args.template, self.processor)
"""Renderer."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
"""HF model."""
def _init_processor(self) -> Processor:
"""Init processor.
NOTE: Transformers v5 always use fast tokenizer.
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
"""
return AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def _init_model_config(self) -> HFConfig:
"""Init model config."""
return AutoConfig.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def _init_model(self) -> HFModel:
"""Init model.
Transformers can choose the proper model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if self.args.init_config is not None:
from ..plugins.model_plugins.initialization import InitPlugin
init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_device
init_kwargs = {"device_map": init_device}
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin
init_kwargs = QuantizationPlugin(self.args.quant_config.name)(
init_kwargs=init_kwargs,
config=self.model_config,
tokenizer=self.processor,
model_args=self.args,
is_trainable=self.is_train,
)
if self.args.model_class == ModelClass.LLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
AutoClass = AutoModelForImageTextToText
else:
AutoClass = AutoModelForCausalLM
elif self.args.model_class == ModelClass.CLS:
from transformers import AutoModelForTokenClassification
AutoClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
AutoClass = AutoModel
if init_device.type == DeviceType.META:
assert self.args.quant_config is None, "Quantization is not supported with meta device."
with init_empty_weights():
model = AutoClass.from_config(self.model_config)
else:
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
trust_remote_code=self.args.trust_remote_code,
**init_kwargs,
)
if self.args.peft_config is None:
if self.is_train:
logger.info_rank0("Fine-tuning mode: full tuning")
model = model.to(torch.float32)
else:
logger.info_rank0("Inference the original model")
else:
from ..plugins.model_plugins.peft import PeftPlugin
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model, include_kernels=self.args.kernel_config.get("include_kernels")
)
return model
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
"""
from ..config.arg_parser import get_args
model_args, *_ = get_args()
model_engine = ModelEngine(model_args=model_args)
print(model_engine.processor)
print(model_engine.model_config)
print(model_engine.model)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model loader.
Init Phase:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.
"""
import torch
from transformers import AutoConfig, AutoProcessor
from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
logger = logging.get_logger(__name__)
class ModelLoader:
"""Model loader.
Args:
model_args: Model arguments.
is_trainable: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
self.args = model_args
"""Model arguments."""
self.is_train = is_train
"""Whether to train the model."""
self.processor = self._init_processor()
"""Tokenizer or multi-modal processor."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
"""HF model."""
def _init_processor(self) -> Processor:
"""Init processor."""
return AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
)
def _init_model_config(self) -> HFConfig:
"""Init model config."""
return AutoConfig.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def _init_model(self) -> HFModel:
"""Init model.
Let transformers handle the model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if self.args.model_class == ModelClass.LLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
AutoClass = AutoModelForImageTextToText
else:
AutoClass = AutoModelForCausalLM
elif self.args.model_class == ModelClass.CLS:
from transformers import AutoModelForTokenClassification
AutoClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
AutoClass = AutoModel
# map the entire model to the current accelerator
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=DistributedInterface().current_accelerator,
trust_remote_code=self.args.trust_remote_code,
)
if self.args.peft_config is None:
if self.is_train:
logger.info_rank0("Fine-tuning mode: full tuning")
model = model.to(torch.float32)
else:
logger.info_rank0("Inference the original model")
else:
from ..plugins.model_plugins.peft import PeftPlugin
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
)
return model
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
"""
from ..config.arg_parser import get_args
_, model_args, *_ = get_args()
model_loader = ModelLoader(model_args=model_args)
print(model_loader.processor)
print(model_loader.model_config)
print(model_loader.model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment