# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import os
import torch
import torch.nn as nn
from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
from torch.distributed.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
fully_shard,
)
from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules:
if isinstance(no_split_modules, (list, tuple)):
for name, module in model.named_modules():
for cls_name in no_split_modules:
if module.__class__.__name__ == cls_name:
return module.__class__
if hasattr(model, "model") and hasattr(model.model, "layers"):
return type(model.model.layers[0])
if hasattr(model, "layers"):
return type(model.layers[0])
return None
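# Illustrative sketch (toy classes, added for exposition): shows how
# get_transformer_layer_cls resolves the per-layer FSDP wrap class from a model's
# _no_split_modules hint, falling back to the decoder layer list otherwise.
def _example_get_transformer_layer_cls() -> None:
    class ToyBlock(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(8, 8)

    class ToyModel(nn.Module):
        _no_split_modules = ["ToyBlock"]

        def __init__(self) -> None:
            super().__init__()
            self.layers = nn.ModuleList(ToyBlock() for _ in range(2))

    assert get_transformer_layer_cls(ToyModel()) is ToyBlock  # matched by class name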
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
if DistributedInterface().get_rank() == 0:
logger.info("Gathering state dict for saving...")
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(model, options=options)
if DistributedInterface().get_rank() == 0:
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
processor.save_pretrained(output_dir, max_shard_size="4GB")
logger.info(f"Model saved to {output_dir}")
class FSDP2Engine:
def __init__(self, dist_config: dict):
self.dist_interface = DistributedInterface()
self.rank = self.dist_interface.get_rank()
self.local_rank = self.dist_interface.get_local_rank()
self.world_size = self.dist_interface.get_world_size()
self.mixed_precision = dist_config.get("mixed_precision", "bf16")
self.reshard_after_forward = dist_config.get("reshard_after_forward", True)
self.offload_params = dist_config.get("offload_params", False)
self.pin_memory = dist_config.get("pin_memory", True)
self.dcp_path = dist_config.get("dcp_path", None)
self.device_mesh = self.dist_interface.data_device_mesh
if self.device_mesh is None:
logger.warning(
"Device Mesh not found in DistributedInterface. FSDP2 might fail if not running in distributed mode."
)
if self.device_mesh is not None:
try:
self.fsdp_mesh = self.device_mesh["dp"]
except Exception:
self.fsdp_mesh = self.device_mesh
logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
else:
self.fsdp_mesh = None
def get_mp_policy(self) -> MixedPrecisionPolicy:
if self.mixed_precision == "bf16":
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
elif self.mixed_precision == "fp16":
param_dtype = torch.float16
reduce_dtype = torch.float32
else:
param_dtype = torch.float32
reduce_dtype = torch.float32
return MixedPrecisionPolicy(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
cast_forward_inputs=True,
)
def is_lora_module_wrap(self, model) -> bool:
return any(isinstance(module, LoraLayer) for module in model.modules())
def prepare_model(self, model: HFModel) -> HFModel:
if self.fsdp_mesh is None:
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
return model
mp_policy = self.get_mp_policy()
layer_cls = get_transformer_layer_cls(model)
if layer_cls is None:
logger.warning(
"Could not identify Transformer Layer class, applying FSDP to the whole model structure only."
)
transformer_layer_cls_to_wrap = set()
else:
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
transformer_layer_cls_to_wrap = {layer_cls}
if self.is_lora_module_wrap(model):
lora_modules = []
for module in model.modules():
if len(list(module.children())) != 0:
continue
if any(param.requires_grad for param in module.parameters(recurse=False)):
lora_modules.append(module)
for module in lora_modules:
fully_shard(
module,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
logger.info("Applying FSDP wrap for LoRA layer separately.")
for name, module in model.named_modules():
should_wrap = False
if type(module) in transformer_layer_cls_to_wrap:
should_wrap = True
elif isinstance(module, nn.Embedding):
if not getattr(model.config, "tie_word_embeddings", True):
should_wrap = True
if should_wrap:
fully_shard(
module,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
# BaseTrainer is the single source of truth for gradient checkpointing.
# FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
if getattr(model, "is_gradient_checkpointing", False):
if self.rank == 0:
logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
fully_shard(
model,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
return model
@torch.no_grad()
def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str | None = None):
if self.rank == 0:
logger.info("Materializing sharded model params...")
device = get_current_accelerator()
model.to_empty(device=device)
if dcp_path and os.path.exists(dcp_path):
if self.rank == 0:
logger.info(f"DCP path found at {dcp_path}. Using efficient Sharded Loading (DCP Load).")
self._load_from_dcp(model, dcp_path)
else:
if self.rank == 0:
if dcp_path:
logger.warning(f"DCP path {dcp_path} not found.")
logger.info("Using HF Meta Loading (Chunk Load).")
self._load_weights_from_hf_checkpoint(model, hf_model_path)
return model
def _save_non_persistent_buffers(self, model: HFModel) -> dict:
"""Save non-persistent buffers, such as inv_freq."""
saved = {}
for mod_name, module in model.named_modules():
for buf_name in module._non_persistent_buffers_set:
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
buf = getattr(module, buf_name, None)
if buf is not None:
saved[fqn] = copy.deepcopy(buf)
if self.rank == 0 and saved:
logger.info(f"Saved {len(saved)} non-persistent buffers")
return saved
def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict):
"""Register saved non-persistent buffers to model."""
if not saved_buffers:
return
device = get_current_accelerator()
for fqn, buf in saved_buffers.items():
buf = buf.to(device)
if "." in fqn:
parent_fqn, buf_name = fqn.rsplit(".", 1)
parent_module = model.get_submodule(parent_fqn)
else:
buf_name = fqn
parent_module = model
parent_module.register_buffer(buf_name, buf, persistent=False)
if self.rank == 0:
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
def shard_model(self, model: HFModel) -> HFModel:
if model.device.type == "meta":
non_persistent_buffers = self._save_non_persistent_buffers(model)
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()
model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
# re-tie weights that may have been broken when the embeddings were not FSDP-wrapped
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()
self._restore_non_persistent_buffers(model, non_persistent_buffers)
else:
model = self.prepare_model(model)
return model
def _load_from_dcp(self, model: HFModel, dcp_path: str):
import torch.distributed.checkpoint as dcp
try:
if self.rank == 0:
logger.info(f"Loading distributed checkpoint from {dcp_path} ...")
options = StateDictOptions(full_state_dict=False, cpu_offload=True)
local_state_dict = get_model_state_dict(model, options=options)
dcp.load(state_dict=local_state_dict, checkpoint_id=dcp_path)
set_model_state_dict(model, local_state_dict, options=options)
if self.rank == 0:
logger.info("DCP weights loaded successfully.")
except Exception as e:
logger.error(f"Failed to load from DCP: {e}")
raise e
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
import glob
import json
hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
if self.rank == 0:
logger.info(f"Loading weights from {hf_model_path} ...")
index_file = os.path.join(hf_model_path, "model.safetensors.index.json")
is_safetensors = True
checkpoint_files = []
if os.path.exists(index_file):
with open(index_file) as f:
index = json.load(f)
checkpoint_files = sorted(set(index["weight_map"].values()))
checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
elif os.path.exists(os.path.join(hf_model_path, "model.safetensors")):
checkpoint_files = [os.path.join(hf_model_path, "model.safetensors")]
else:
is_safetensors = False
index_file = os.path.join(hf_model_path, "pytorch_model.bin.index.json")
if os.path.exists(index_file):
with open(index_file) as f:
index = json.load(f)
checkpoint_files = sorted(set(index["weight_map"].values()))
checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
elif os.path.exists(os.path.join(hf_model_path, "pytorch_model.bin")):
checkpoint_files = [os.path.join(hf_model_path, "pytorch_model.bin")]
else:
checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.safetensors")))
if checkpoint_files:
is_safetensors = True
else:
checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.bin")))
if not checkpoint_files:
raise ValueError(f"No checkpoint files found in {hf_model_path}")
param_map = dict(model.named_parameters())
total_files = len(checkpoint_files)
for i, ckpt_file in enumerate(checkpoint_files):
if self.rank == 0:
logger.info(f"[{i + 1}/{total_files}] Loading {os.path.basename(ckpt_file)} ...")
if is_safetensors:
from safetensors import safe_open
with safe_open(ckpt_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key in param_map:
tensor = f.get_tensor(key)
self._copy_weights(param_map[key], tensor)
else:
state_dict = torch.load(ckpt_file, map_location="cpu")
for key, tensor in state_dict.items():
if key in param_map:
self._copy_weights(param_map[key], tensor)
del state_dict
gc.collect()
def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
"""Resolve a HF model identifier or local path to a local directory containing checkpoint files.
- If `hf_model_path` is an existing directory, return it.
- If it's a file path, return its parent directory.
- Otherwise treat it as a Hugging Face Hub repo id and download/resolve to the local cache dir.
"""
if not hf_model_path:
return hf_model_path
# Local directory or file path.
if os.path.isdir(hf_model_path):
return hf_model_path
if os.path.isfile(hf_model_path):
return os.path.dirname(hf_model_path)
# HuggingFace Hub repo id: snapshot to local cache so we can glob/index files.
try:
from huggingface_hub import snapshot_download
except ImportError as e:
raise ValueError(
f"hf_model_path='{hf_model_path}' does not exist locally and huggingface_hub is not available "
f"to download it. Please provide a local model directory or install huggingface_hub. Error: {e}"
) from e
revision = os.getenv("HF_REVISION")
offline = os.getenv("HF_HUB_OFFLINE") == "1" or os.getenv("TRANSFORMERS_OFFLINE") == "1"
# In distributed runs, let rank0 download first to avoid N-way concurrent downloads.
if torch.distributed.is_available() and torch.distributed.is_initialized():
if self.rank == 0:
local_dir = snapshot_download(
repo_id=hf_model_path,
revision=revision,
local_files_only=offline,
allow_patterns=[
"*.safetensors",
"*.bin",
"*.index.json",
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"config.json",
],
)
logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")
torch.distributed.barrier()
if self.rank != 0:
local_dir = snapshot_download(
repo_id=hf_model_path,
revision=revision,
local_files_only=True,
allow_patterns=[
"*.safetensors",
"*.bin",
"*.index.json",
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"config.json",
],
)
return local_dir
local_dir = snapshot_download(
repo_id=hf_model_path,
revision=revision,
local_files_only=offline,
allow_patterns=[
"*.safetensors",
"*.bin",
"*.index.json",
"model.safetensors",
"model.safetensors.index.json",
"pytorch_model.bin",
"pytorch_model.bin.index.json",
"config.json",
],
)
if self.rank == 0:
logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")
return local_dir
def _copy_weights(self, param, loaded_tensor):
from torch.distributed._tensor import DTensor, Shard
if loaded_tensor.dtype != param.dtype:
loaded_tensor = loaded_tensor.to(param.dtype)
if isinstance(param, DTensor):
shard_placement = None
mesh_dim = -1
for i, placement in enumerate(param.placements):
if isinstance(placement, Shard):
shard_placement = placement
mesh_dim = i
break
local_tensor = param.to_local()
if shard_placement is None:
local_tensor.copy_(loaded_tensor)
else:
dim = shard_placement.dim
mesh = param.device_mesh
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
return
rank_in_dim = my_coordinate[mesh_dim]
world_size_in_dim = mesh.size(mesh_dim)
full_size = param.shape[dim]
chunk_size = (full_size + world_size_in_dim - 1) // world_size_in_dim
start = rank_in_dim * chunk_size
end = min(start + chunk_size, full_size)
if start >= full_size:
return
sliced_tensor = loaded_tensor.narrow(dim, start, end - start)
slices = [slice(None)] * local_tensor.ndim
slices[dim] = slice(0, sliced_tensor.shape[dim])
local_tensor[tuple(slices)].copy_(sliced_tensor)
else:
param.data.copy_(loaded_tensor)
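# Illustrative sketch (plain tensors, no DTensor or process group; added for
# exposition): mirrors the even-chunk arithmetic _copy_weights uses to slice a full
# checkpoint tensor into per-rank shards along the sharded dimension.
def _example_copy_weights_chunking() -> None:
    full = torch.arange(10).reshape(5, 2)  # full parameter, sharded on dim 0
    world_size = 2
    chunk_size = (full.shape[0] + world_size - 1) // world_size  # ceil(5 / 2) == 3
    for rank in range(world_size):
        start = rank * chunk_size
        end = min(start + chunk_size, full.shape[0])
        if start >= full.shape[0]:
            continue  # ranks past the end of the data hold only padding
        local_shard = full.narrow(0, start, end - start)
        # rank 0 receives rows 0..2; rank 1 receives the partial chunk of rows 3..4
        assert local_shard.shape[0] == end - start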
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin
if TYPE_CHECKING:
from ....utils.types import HFModel, Processor
class DistributedPlugin(BasePlugin):
def __call__(self, model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
return super().__call__(model, dist_config, **kwargs)
@DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .fsdp2 import FSDP2Engine
return FSDP2Engine(dist_config).shard_model(model)
@DistributedPlugin("fsdp2").register("save_model")
def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None:
from .fsdp2 import save_model
return save_model(model, output_dir, processor)
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .deepspeed import DeepSpeedEngine
return DeepSpeedEngine(
dist_config,
num_micro_batch=kwargs.get("num_micro_batch"),
micro_batch_size=kwargs.get("micro_batch_size"),
).shard_model(model)
@DistributedPlugin("deepspeed").register("save_model")
def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None:
from .deepspeed import save_model
return save_model(model, output_dir, processor)
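# Illustrative sketch (hypothetical values; added for exposition): the registered
# entry points above are plain functions, so the fsdp2 backend can also be exercised
# directly with a dist_config mapping like the one FSDP2Engine reads in its __init__
# ("mixed_precision", "reshard_after_forward", "offload_params", "pin_memory",
# "dcp_path"). A distributed process group must already be initialized.
def _example_shard_with_fsdp2(model: HFModel) -> HFModel:
    dist_config = {"mixed_precision": "bf16", "reshard_after_forward": True}
    return shard_model_fsdp2(model, dist_config)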
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
class LRSchedulerPlugin(BasePlugin):
pass
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
class OptimizerPlugin(BasePlugin):
pass
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from collections.abc import Generator
from threading import Thread
from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args
from ..core.base_sampler import BaseSampler
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..core.utils.rendering import Renderer
from ..utils.types import HFModel, Message, Sample, TorchDataset
class SyncSampler(BaseSampler):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
super().__init__(args, model_args, model, renderer)
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
"""Generate tokens synchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
generator = super().generate(messages, tools)
while True:
try:
token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result()
yield token
except StopAsyncIteration:
break
def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples synchronously.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result()
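# Illustrative, self-contained sketch (added for exposition) of the bridge pattern
# used above: an async generator is driven from synchronous code by scheduling
# __anext__ on a background event loop and blocking on the returned future.
def _example_async_to_sync_bridge() -> None:
    async def _agen():
        for token in ("a", "b", "c"):
            yield token

    loop = asyncio.new_event_loop()

    def _run_loop() -> None:
        asyncio.set_event_loop(loop)
        loop.run_forever()

    Thread(target=_run_loop, daemon=True).start()
    agen = _agen()
    tokens = []
    while True:
        try:
            tokens.append(asyncio.run_coroutine_threadsafe(agen.__anext__(), loop).result())
        except StopAsyncIteration:
            break
    loop.call_soon_threadsafe(loop.stop)
    assert tokens == ["a", "b", "c"]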
def run_chat(args: InputArgument = None):
model_args, data_args, _, sample_args = get_args(args)
if sample_args.sample_backend != SampleBackend.HF:
model_args.init_plugin = {"name": "init_on_meta"}
model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
if data_args.train_dataset is not None:
dataset = DataEngine(data_args.train_dataset)
sampler.batch_infer(dataset)
else:
if os.name != "nt":
try:
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
print("History has been removed.")
continue
messages.append({"role": "user", "content": [{"type": "text", "value": query}]})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in sampler.generate(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append(model_engine.renderer.parse_message(response))
if __name__ == "__main__":
run_chat()
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_loader import ModelLoader
class SFTTrainer(BaseTrainer):
pass
def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
model_loader = ModelLoader(model_args)
trainer = SFTTrainer(
args=training_args,
model=model_loader.model,
processor=model_loader.processor,
dataset=data_engine,
)
trainer.fit()
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any
class DynamicBatchSizeBuffer:
"""A buffer to store samples for dynamic batch size."""
def __init__(self):
self._buffer: list[dict[str, Any]] = []
self._buffer_sample_lengths: list[int] = []
self._deleted_indices: set[int] = set()
self._current_index: int = 0
self._total_token_count: int = 0
def append(self, item: dict[str, Any]) -> None:
"""Append a sample to the buffer.
Args:
item: A sample to append to the buffer.
The sample should be a dict with the following keys:
- input_ids: torch.Tensor of shape (seq_len, )
- attention_mask: torch.Tensor of shape (seq_len, )
"""
self._buffer.append(item)
sample_length = int(item["attention_mask"].sum().item())
self._buffer_sample_lengths.append(sample_length)
self._total_token_count += sample_length
def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, Any]]:
"""Get samples from the buffer that fit within the token budget.
Args:
max_tokens_per_iteration: Maximum number of tokens to retrieve.
force: If True, the first available sample will be returned even
if it exceeds the token budget.
Returns:
A list of samples that fit within the token budget.
Raises:
AssertionError: If no samples are found (should not happen in normal operation).
"""
cum_seq_len = 0
samples = []
while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
if self._current_index in self._deleted_indices:
self._current_index += 1
continue
seq_len = self._buffer_sample_lengths[self._current_index]
remaining_tokens = max_tokens_per_iteration - cum_seq_len
# Check if we can add this sample
can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)
if can_add:
cum_seq_len += seq_len
samples.append(self._buffer[self._current_index])
self._deleted_indices.add(self._current_index)
self._current_index += 1
assert len(samples) > 0, "No samples found in buffer"
return samples
def __len__(self) -> int:
"""Return the number of samples in the buffer."""
return len(self._buffer)
@property
def total_token_count(self) -> int:
"""Return the total number of tokens in the buffer."""
return self._total_token_count
def flush(self) -> None:
tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
self._total_token_count -= tokens_to_remove
buffer_length = len(self._buffer)
self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
self._buffer_sample_lengths = [
self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
]
self._current_index = 0
self._deleted_indices.clear()
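# Illustrative sketch (tiny tensors; added for exposition): fill the buffer, draw one
# token-budgeted micro batch, then flush to drop the consumed samples.
def _example_dynamic_batch_buffer() -> None:
    import torch  # local import; this module does not depend on torch at the top level

    buf = DynamicBatchSizeBuffer()
    for seq_len in (3, 4, 5):
        buf.append({
            "input_ids": torch.zeros(seq_len, dtype=torch.long),
            "attention_mask": torch.ones(seq_len, dtype=torch.long),
        })
    batch = buf.get_samples(max_tokens_per_iteration=8)  # fits the 3- and 4-token samples
    assert len(batch) == 2
    buf.flush()  # removes the consumed samples and their token counts
    assert len(buf) == 1 and buf.total_token_count == 5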
class BaseBatchingQueue(ABC):
"""Base class for batching queue."""
@abstractmethod
def is_full_filled(self) -> bool:
raise NotImplementedError("Subclasses must implement `is_full_filled`")
@abstractmethod
def put_item(self, item: dict[str, Any]) -> None:
raise NotImplementedError("Subclasses must implement `put_item`")
@abstractmethod
def get_micro_batch(self, step: int) -> list[dict[str, Any]]:
raise NotImplementedError("Subclasses must implement `get_micro_batch`")
@abstractmethod
def empty(self) -> bool:
raise NotImplementedError("Subclasses must implement `empty`")
class IdentityPacker:
def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
self.token_micro_bsz = token_micro_bsz
self.bsz_warmup_steps = bsz_warmup_steps
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
def __call__(self, samples):
return samples
def get_token_num_to_request(self, cur_step, warmup):
return (
(self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
+ self.bsz_warmup_init_mbtoken
if warmup
else self.token_micro_bsz
)
class TextBatchingQueue(BaseBatchingQueue):
"""Batching text queue for text data."""
def __init__(
self,
token_micro_bsz,
buffer_size: int = 500,
bsz_warmup_steps: int = -1,
bsz_warmup_init_mbtoken: int = 200,
) -> None:
super().__init__()
self._step = 0
self.token_micro_bsz = token_micro_bsz
self.bsz_warmup_steps = bsz_warmup_steps
self.buffer_size = buffer_size # minimum samples in buffer
self.buffer = DynamicBatchSizeBuffer()
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken # training warmup args
assert self.bsz_warmup_init_mbtoken >= 0
self.packer = IdentityPacker(
token_micro_bsz=token_micro_bsz,
bsz_warmup_steps=bsz_warmup_steps,
bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
)
def is_full_filled(self) -> bool:
return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz
def put_item(self, item: dict[str, Any]) -> None:
if len(item["input_ids"]) == 1:
print("WARNING: EMPTY STRING.")
return
self.buffer.append(item)
def get_token_num_to_request(self):
if self.packer is not None:
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
return self.packer.get_token_num_to_request(self._step, warmup=warmup)
else:
return self.get_cur_token_micro_bsz()
def get_cur_token_micro_bsz(self):
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
if warmup:
return (
self.token_micro_bsz - self.bsz_warmup_init_mbtoken
) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
else:
return self.token_micro_bsz
def get_micro_batch(self, step: int) -> list[dict[str, Any]]:
"""Get a micro batch from the buffer according to the current step.
Args:
step: the current step.
Returns:
data: a list of samples.
"""
self._step = step
n_token_per_iter = self.get_token_num_to_request()
cur_token_micro_bsz = self.get_cur_token_micro_bsz()
assert cur_token_micro_bsz % n_token_per_iter == 0, (
"The current token micro batch size should be divisible by the token num requested per iteration."
)
n_iter = int(cur_token_micro_bsz // n_token_per_iter)
data = []
for _ in range(n_iter):
samples = self.buffer.get_samples(n_token_per_iter)
if self.packer:
samples = self.packer(samples) # may pack multiple samples into one; the result stays wrapped in a list.
data.extend(samples)
self.buffer.flush() # remove the selected samples.
return data
def empty(self) -> bool:
return len(self.buffer) == 0
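# Illustrative end-to-end sketch (tiny tensors, warmup disabled; added for
# exposition): feed items until the queue reports it is filled, then draw a micro
# batch for a given step.
def _example_text_batching_queue() -> None:
    import torch  # local import; this module does not depend on torch at the top level

    queue = TextBatchingQueue(token_micro_bsz=8, buffer_size=2, bsz_warmup_steps=-1)
    for seq_len in (3, 4, 5):
        queue.put_item({
            "input_ids": torch.zeros(seq_len, dtype=torch.long),
            "attention_mask": torch.ones(seq_len, dtype=torch.long),
        })
    if queue.is_full_filled():
        micro_batch = queue.get_micro_batch(step=0)  # samples within the 8-token budget
        assert len(micro_batch) == 2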
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's verl library.
# https://github.com/volcengine/verl/blob/v0.6.1/verl/utils/torch_dtypes.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import torch
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
from ..accelerator.interface import DistributedInterface
class DtypeRegistry:
HALF_LIST = ["fp16", "float16", "half", torch.float16]
FLOAT_LIST = ["fp32", "float32", "float", torch.float32]
BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]
class DtypeInterface:
"""Type of precision used."""
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator)
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator)
_is_fp32_available = True
@staticmethod
def is_available(precision: str | torch.dtype) -> bool:
if precision in DtypeRegistry.HALF_LIST:
return DtypeInterface._is_fp16_available
elif precision in DtypeRegistry.FLOAT_LIST:
return DtypeInterface._is_fp32_available
elif precision in DtypeRegistry.BFLOAT_LIST:
return DtypeInterface._is_bf16_available
else:
raise RuntimeError(f"Unexpected precision: {precision}")
@staticmethod
def is_fp16(precision: str | torch.dtype) -> bool:
return precision in DtypeRegistry.HALF_LIST
@staticmethod
def is_fp32(precision: str | torch.dtype) -> bool:
return precision in DtypeRegistry.FLOAT_LIST
@staticmethod
def is_bf16(precision: str | torch.dtype) -> bool:
return precision in DtypeRegistry.BFLOAT_LIST
@staticmethod
def to_dtype(precision: str | torch.dtype) -> torch.dtype:
if precision in DtypeRegistry.HALF_LIST:
return torch.float16
elif precision in DtypeRegistry.FLOAT_LIST:
return torch.float32
elif precision in DtypeRegistry.BFLOAT_LIST:
return torch.bfloat16
else:
raise RuntimeError(f"Unexpected precision: {precision}")
@staticmethod
def to_str(precision: torch.dtype) -> str:
if precision == torch.float16:
return "float16"
elif precision == torch.float32:
return "float32"
elif precision == torch.bfloat16:
return "bfloat16"
else:
raise RuntimeError(f"Unexpected precision: {precision}")
@contextmanager
def set_dtype(self, precision: str | torch.dtype):
original_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.to_dtype(precision))
try:
yield
finally:
torch.set_default_dtype(original_dtype)
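# Illustrative usage sketch (added for exposition): normalize mixed-precision
# aliases to torch dtypes and temporarily switch the default dtype via the
# context manager defined above.
def _example_dtype_interface() -> None:
    assert DtypeInterface.to_dtype("bf16") is torch.bfloat16
    assert DtypeInterface.to_str(torch.float16) == "float16"
    with DtypeInterface().set_dtype("fp32"):
        assert torch.get_default_dtype() is torch.float32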
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import socket
def find_available_port() -> int:
"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
return port
def is_env_enabled(env_var: str, default: str = "0") -> bool:
"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]
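# Illustrative usage sketch (the variable name is hypothetical; added for
# exposition): pick a free port for a local rendezvous and read a boolean feature
# flag from the environment.
def _example_env_helpers() -> None:
    port = find_available_port()
    assert 0 < port < 65536
    os.environ["MY_FEATURE_FLAG"] = "yes"
    assert is_env_enabled("MY_FEATURE_FLAG")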
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import PreTrainedTokenizer
from transformers import set_seed as hf_set_seed
from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
def set_seed(seed: int) -> None:
"""Set seed for reproducibility.
Args:
seed: Random seed.
"""
hf_set_seed(seed)
def is_tokenizer(processor: Processor) -> bool:
"""Check if processor is tokenizer.
Args:
processor: Processor.
Returns:
Whether processor is tokenizer.
"""
return not hasattr(processor, "tokenizer")
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
"""Get tokenizer from processor.
Args:
processor: Processor.
Returns:
Tokenizer.
"""
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
def _pad_and_truncate(tensor: Tensor, max_seqlen: int, pad_value: int = 0) -> Tensor:
if tensor.shape[-1] >= max_seqlen:
return tensor[..., :max_seqlen]
pad_shape = list(tensor.shape)
pad_shape[-1] = max_seqlen - tensor.shape[-1]
pad_tensor = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
return torch.cat([tensor, pad_tensor], dim=-1)
def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchInput]:
max_length = min(max(len(sample["input_ids"]) for sample in samples), max_seqlen)
padded_samples = []
for sample in samples:
padded_sample = {}
for key, value in sample.items():
if "label" in key:
pad_value = IGNORE_INDEX
else:
pad_value = 0
if not isinstance(value, str):
padded_sample[key] = _pad_and_truncate(torch.tensor(value), max_length, pad_value)
else:
padded_sample[key] = value
padded_samples.append(padded_sample)
return padded_samples
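# Illustrative sketch (hypothetical inputs; added for exposition): samples are
# padded up to the longest sequence in the batch (capped at max_seqlen); label
# keys are padded with IGNORE_INDEX and everything else with zero.
def _example_pad_and_truncate() -> None:
    samples = [
        {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
        {"input_ids": [4, 5], "labels": [4, 5]},
    ]
    padded = pad_and_truncate(samples, max_seqlen=8)  # target length is min(3, 8) == 3
    assert padded[1]["input_ids"].tolist() == [4, 5, 0]
    assert padded[1]["labels"].tolist() == [4, 5, IGNORE_INDEX]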
def compute_valid_tokens(batches: list[BatchInput]) -> int:
"""Compute valid tokens in batches.
Args:
batches: Batches.
Returns:
Number of valid tokens.
"""
device = DistributedInterface().current_device
return sum(
(batch["labels"].to(device, non_blocking=True) != IGNORE_INDEX).sum().item()
for batch in batches
if "labels" in batch
)
# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/utils/logging.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import threading
from functools import lru_cache
from typing import Optional
_thread_lock = threading.RLock()
_default_handler: Optional["logging.Handler"] = None
_default_log_level: "logging._Level" = logging.INFO
class _Logger(logging.Logger):
"""A logger that supports rank0 logging."""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def warning_rank0_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def _get_default_logging_level() -> "logging._Level":
"""Return the default logging level."""
env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
else:
raise ValueError(f"Unknown logging level: {env_level_str}.")
return _default_log_level
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> "_Logger":
return logging.getLogger(_get_library_name())
def _configure_library_root_logger() -> None:
"""Configure root logger using a stdout stream handler with an explicit format."""
global _default_handler
with _thread_lock:
if _default_handler: # already configured
return
formatter = logging.Formatter(
fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter)
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
def get_logger(name: str | None = None) -> "_Logger":
"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)
def add_handler(handler: "logging.Handler") -> None:
"""Add a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
"""Remove a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.info(*args, **kwargs)
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
@lru_cache(None)
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_rank0_once = warning_rank0_once
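# Illustrative usage sketch (added for exposition): obtain a library logger and emit
# messages only on the local rank-0 process via the helpers patched in above.
def _example_rank0_logging() -> None:
    logger = get_logger(__name__)
    logger.info_rank0("visible on LOCAL_RANK == 0 only")
    logger.warning_rank0_once("emitted at most once per unique message")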
# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/utils/logging.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .types import ModelInput
class StatefulBuffer:
"""A buffer that stores model inputs."""
def __init__(self, max_buffer_size: int = 1_000_000_000) -> None:
self._buffer: list[ModelInput] = []
self._buffer_size: int = 0
self._max_buffer_size: int = max_buffer_size
def __len__(self) -> int:
return len(self._buffer)
@property
def size(self) -> int:
return self._buffer_size
def put(self, samples: list[ModelInput]) -> None:
"""Add samples to the buffer."""
num_tokens = sum(len(sample["input_ids"]) for sample in samples)
if self._buffer_size + num_tokens > self._max_buffer_size:
raise ValueError(f"Buffer size exceeds max buffer size {self._max_buffer_size}.")
self._buffer.extend(samples)
self._buffer_size += num_tokens
def get(self, value: int) -> list[ModelInput]:
"""Get samples from the buffer and remove them."""
samples = self._buffer[:value]
self._buffer_size -= sum(len(sample["input_ids"]) for sample in samples)
del self._buffer[:value]
return samples
def clear(self) -> None:
"""Clear the buffer."""
self._buffer = []
self._buffer_size = 0
def state_dict(self) -> dict:
"""Returns the state of the buffer."""
return {
"buffer": self._buffer,
"buffer_size": self._buffer_size,
}
def load_state_dict(self, state_dict: dict) -> None:
"""Loads the state into the buffer."""
self._buffer = state_dict["buffer"]
self._buffer_size = state_dict["buffer_size"]
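# Illustrative usage sketch (hypothetical samples; added for exposition): the buffer
# tracks its token count, serves FIFO slices, and round-trips its state for
# checkpointing.
def _example_stateful_buffer() -> None:
    buffer = StatefulBuffer(max_buffer_size=100)
    buffer.put([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
    assert buffer.size == 5  # total token count across both samples
    first = buffer.get(1)  # pops the first sample, leaving two tokens
    assert len(first) == 1 and buffer.size == 2
    restored = StatefulBuffer()
    restored.load_state_dict(buffer.state_dict())  # checkpoint round-trip
    assert restored.size == 2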