Commit ca625f43 authored by shihm

uodata

parent 7164651d
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of kernel interface.
Init Phase:
1. Scan all kernels.
2. Register default kernels.
3. Define kernel plugin.
"""
import importlib
from pathlib import Path
from ....utils.logging import get_logger
from ....utils.plugin import BasePlugin
from .registry import Registry
logger = get_logger(__name__)
def scan_all_kernels():
r"""Scan all kernels in the ``ops`` directory.
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
Returns:
dict[str, type[BaseKernel]]: A dictionary of registered kernels.
.. note::
This function assumes that the ``ops`` directory is located in the same directory as this file.
It recursively searches for ``.py`` files and constructs the module path for import.
"""
ops_path = Path(__file__).parent / "ops"
if not ops_path.exists():
return Registry.get_registered_kernels()
base_package = __package__
for file_path in ops_path.rglob("*.py"):
if file_path.name == "__init__.py":
continue
# calculate the relative path:
# file_path = .../kernels_v2/ops/mlp/npu_swiglu.py
# rel_path = ops/mlp/npu_swiglu.py
rel_path = file_path.relative_to(Path(__file__).parent)
# build module path:
module_name = ".".join(rel_path.parts)[:-3]
full_module_name = f"{base_package}.{module_name}"
try:
importlib.import_module(full_module_name)
except Exception as e:
logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}")
return Registry.get_registered_kernels()
default_kernels = scan_all_kernels()
def get_default_kernels():
r"""Get a list of default registered kernel IDs.
Returns:
list[str]: List of kernel IDs.
"""
return list(default_kernels.keys())
def apply_kernel(kernel_id: str, **kwargs):
r"""Applies a specific kernel to the model.
Args:
kernel_id (str): The ID of the kernel to apply.
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance.
Returns:
HFModel: The model with applied kernel.
"""
kernel = default_kernels.get(kernel_id)
if kernel is None:
raise ValueError(f"Kernel {kernel_id} not found")
return kernel.apply(**kwargs)
class KernelPlugin(BasePlugin):
r"""Plugin for managing kernel optimizations."""
pass
@KernelPlugin("auto").register
def apply_default_kernels(**kwargs):
r"""Applies all default registered kernels to the model.
Args:
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance and the include_kernels configuration.
Returns:
HFModel: The model with applied kernels.
"""
if not kwargs.get("include_kernels"): # None/False/empty string
return kwargs.get("model")
elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, **kwargs)
return kwargs.get("model")
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused MoE kernels.
Init Phase:
1. Define GMM functions.
2. Define NPU fused MoE functions.
3. Register NPU fused MoE kernel.
"""
import types
import torch
import torch.nn.functional as F
try:
import torch_npu
except ImportError:
pass
from ......accelerator.helper import DeviceType
from ......utils.packages import is_transformers_version_greater_than
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
class GmmFunction(torch.autograd.Function):
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod
def forward(ctx, x, weight, group_list):
r"""Performs the forward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object to save tensors for backward pass.
x (Tensor): Input tensor.
weight (Tensor): Weight tensor.
group_list (list): List of group sizes.
Returns:
Tensor: The result of the grouped matrix multiplication.
"""
ctx.save_for_backward(x, weight)
ctx.group_list = group_list
fwd_output = torch_npu.npu_grouped_matmul(
[x], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
)[0]
return fwd_output
@staticmethod
def backward(ctx, grad_output):
r"""Performs the backward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object containing saved tensors.
grad_output (Tensor): Gradient with respect to the output.
Returns:
tuple: Gradients with respect to input, weight, and None for group_list.
"""
input_tensor, weight = ctx.saved_tensors
group_list = ctx.group_list
weight = torch.transpose(weight, 1, 2)
grad_input = torch_npu.npu_grouped_matmul(
[grad_output], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
)[0]
grad_weight = torch_npu.npu_grouped_matmul(
[input_tensor.T],
[grad_output],
bias=None,
group_list=group_list,
split_item=3,
group_type=2,
group_list_type=1,
)[0]
return grad_input, grad_weight, None
class HybridGmmFunction(torch.autograd.Function):
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod
def forward(ctx, num_experts, *args):
r"""Performs the forward pass of Hybrid GMM.
Args:
ctx: Context object to save tensors.
num_experts (int): Number of experts.
*args: Variable length argument list containing inputs and weights.
Returns:
tuple: The outputs of the grouped matrix multiplication.
"""
x_list = list(args[:num_experts])
weight_list = list(args[num_experts:])
split_sizes = [x.shape[0] for x in x_list]
ctx.split_sizes = split_sizes
ctx.num_experts = num_experts
ctx.save_for_backward(*args)
outputs = torch_npu.npu_grouped_matmul(
x_list, weight_list, bias=None, group_list=None, split_item=0, group_type=-1
)
return tuple(outputs)
@staticmethod
def backward(ctx, *grad_outputs):
r"""Performs the backward pass of Hybrid GMM.
Args:
ctx: Context object containing saved tensors.
*grad_outputs: Gradients with respect to the outputs.
Returns:
tuple: Gradients with respect to inputs and weights.
"""
saved_tensors = ctx.saved_tensors
num_experts = ctx.num_experts
split_sizes = ctx.split_sizes
x_list = list(saved_tensors[:num_experts])
weight_list = list(saved_tensors[num_experts:])
grad_outputs_contiguous = [g.contiguous() for g in grad_outputs]
w_t_list = [w.t() for w in weight_list]
grad_x_list = torch_npu.npu_grouped_matmul(
grad_outputs_contiguous,  # list[Tensor], each [M_i, N]
w_t_list,  # list[Tensor], each [N, K] (view)
bias=None,
group_list=None,
split_item=0,
group_type=-1,
)
x_concat = torch.cat(x_list, dim=0)
dy_concat = torch.cat(grad_outputs_contiguous, dim=0) # [Total_M, N]
group_list = torch.tensor(split_sizes, device=x_concat.device, dtype=torch.int64)
grad_w_stack = torch_npu.npu_grouped_matmul(
[x_concat.t()],
[dy_concat],
bias=None,
group_list=group_list,
split_item=3,
group_type=2,
group_list_type=1,
)[0]
if grad_w_stack.dim() == 3:
grad_w_list = list(torch.unbind(grad_w_stack, dim=0))
else:
raise RuntimeError(f"Unexpected grad_w_stack shape: {grad_w_stack.shape}")
return (None, *grad_x_list, *grad_w_list)
class NpuMoeFused:
r"""Container for NPU fused MoE forward functions."""
@staticmethod
def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
r"""Forward pass for MoE experts using NPU fused operations.
Args:
self: The MoE layer instance.
hidden_states (Tensor): Input hidden states.
routing_weights (Tensor): Routing weights.
router_indices (Tensor): Router indices.
Returns:
Tensor: Output tensor after expert computation.
"""
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
hidden_states, router_indices.to(torch.int32)
)
tokens_per_expert = torch.histc(router_indices.float(), bins=self.num_experts, min=0, max=self.num_experts).long()
intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, self.gate_up_proj, tokens_per_expert)
intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
output = GmmFunction.apply(intermediate_activations, self.down_proj, tokens_per_expert)
next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights)
next_states = next_states.view(batch_size, -1, self.hidden_size)
return next_states
@staticmethod
def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""Forward pass for sparse MoE block using NPU optimization.
Args:
self: The MoE sparse block instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: The routed output.
"""
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
router_logits = self.gate(hidden_states)
routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
routed_out = self.experts(hidden_states, routing_weights, router_indices)
return routed_out
class Qwen3NpuMoeFused:
r"""Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args:
self: The Qwen3 MoE block instance.
hidden_states (Tensor): Input hidden states.
Returns:
tuple: A tuple containing the next states and router logits.
"""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, selected_experts.int())
tokens_per_expert = torch.histc(
selected_experts.float(), bins=self.num_experts, min=0, max=self.num_experts
).long()
split_sizes = tokens_per_expert.tolist()
input_list = list(torch.split(permuted_hidden_states, split_sizes, dim=0))
gate_weights = [e.gate_proj.weight.t() for e in self.experts]
up_weights = [e.up_proj.weight.t() for e in self.experts]
down_weights = [e.down_proj.weight.t() for e in self.experts]
gate_out_tuple = HybridGmmFunction.apply(len(input_list), *input_list, *gate_weights)
up_out_tuple = HybridGmmFunction.apply(len(input_list), *input_list, *up_weights)
inter_list = [F.silu(g) * u for g, u in zip(gate_out_tuple, up_out_tuple)]
down_out_tuple = HybridGmmFunction.apply(len(inter_list), *inter_list, *down_weights)
grouped_output = torch.cat(down_out_tuple, dim=0)
next_states = torch_npu.npu_moe_token_unpermute(grouped_output, row_ids_map, probs=routing_weights)
next_states = next_states.view(batch_size, sequence_length, -1)
return next_states, router_logits
# moe patch config mapping
kernel_moe_mapping = {
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
}
}
if not is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
}
@register_kernel
class NpuFusedMoEKernel(BaseKernel):
r"""NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the NPU fused MoE kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched MoE forward functions.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")
archs = getattr(model.config, "architectures", [])
target_moe_mapping = None
for arch in archs:
if arch in kernel_moe_mapping:
target_moe_mapping = kernel_moe_mapping[arch]
break
if target_moe_mapping is None:
return model
for module in model.modules():
class_name = module.__class__.__name__
if class_name in target_moe_mapping:
new_forward_func = target_moe_mapping[class_name]
module.forward = types.MethodType(new_forward_func, module)
return model
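# --- Illustration of the per-instance patching used by apply() above ---
# A self-contained toy sketch (the class and functions below are hypothetical):
# binding a function with ``types.MethodType`` replaces ``forward`` for one instance
# only, which is exactly how the fused MoE forwards are installed on matched modules.
class _ToyBlock:
    def forward(self, x):
        return x + 1


def _fused_forward(self, x):
    # receives the instance as ``self``, just like the NpuMoeFused forwards
    return x * 2


_block = _ToyBlock()
_block.forward = types.MethodType(_fused_forward, _block)
assert _block.forward(3) == 6       # the patched ("fused") path
assert _ToyBlock().forward(3) == 4  # other instances keep the original forward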
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused SwiGLU kernels.
Init Phase:
1. Define SwiGLU forward functions.
2. Register NPU fused SwiGLU kernel.
"""
import re
import types
import torch
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
try:
import torch_npu
except ImportError:
pass
def npu_swiglu_forward(self, hidden_state):
r"""SwiGLU forward pass for NPU.
Args:
self: The MLP layer instance.
hidden_state (Tensor): Input hidden state.
Returns:
Tensor: Output of SwiGLU.
"""
return self.down_proj(
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
)
def _npu_swiglu_glm4_forward(self, hidden_states):
r"""SwiGLU forward pass for GLM4 on NPU.
Args:
self: The GLM4 MLP layer instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: Output of SwiGLU.
"""
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
r"""SwiGLU forward pass for Gemma3nText on NPU.
Args:
self: The Gemma3nText MLP layer instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: Output of SwiGLU.
"""
gate_proj = self.gate_proj(hidden_states)
if self.activation_sparsity > 0.0:
gate_proj = self._gaussian_topk(gate_proj)
down_proj = self.down_proj(
torch_npu.npu_swiglu(torch.cat((gate_proj, self.up_proj(hidden_states)), dim=-1), dim=-1)
)
return down_proj
@register_kernel
class NpuSwiGluKernel(BaseKernel):
r"""NPU Kernel for fused SwiGLU activation."""
# Only the following MLP module classes are supported.
expect_modules = frozenset(
{
"Qwen3VLMoeTextMLP",
"Qwen3VLTextMLP",
"Qwen3OmniMoeThinkerTextMLP",
"Qwen3OmniMoeMLP",
"Qwen3OmniMoeTalkerTextMLP",
"Qwen3OmniMoeCode2WavMlp",
"Qwen3NextMLP",
"Qwen3MoeMLP",
"Qwen3MLP",
"Qwen2MLP",
"Qwen2MoeMLP",
"Qwen2_5_VLMLP",
"Qwen2_5OmniMLP",
"Llama4TextMLP",
"LlamaMLP",
"Glm4MLP",
"Glm4MoeMLP",
"Glm4vMoeTextMLP",
"Gemma3MLP",
"Gemma2MLP",
"Gemma3nTextMLP",
"Phi3MLP",
"DeepseekV2MLP",
"DeepseekV3MLP",
"SeedOssMLP",
}
)
_kernel_id = "npu_fused_swiglu"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Applies the NPU fused SwiGLU kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched SwiGLU forward functions.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError("torch_npu is not available but NpuSwiGluKernel was called.")
# Mapping of specific mlp modules to their corresponding kernel implementations
kernel_mapping = {
"Glm4MLP": _npu_swiglu_glm4_forward,
"Glm4vTextMLP": _npu_swiglu_glm4_forward,
"Phi3MLP": _npu_swiglu_glm4_forward,
"Gemma3nTextMLP": _npu_swiglu_gemma3ntext_forward,
}
swiglu_pattern = re.compile("MLP", re.IGNORECASE)
for name, module in model.named_modules():
# Match any module whose class name contains "MLP"
if (
re.search(swiglu_pattern, module.__class__.__name__)
and module.__class__.__name__ in cls.expect_modules
):
# Bind function as an instance method to preserve `self` semantics
# and replace the original forward
kernel_func = kernel_mapping.get(module.__class__.__name__, npu_swiglu_forward)
module.forward = types.MethodType(kernel_func, module)
return model
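# --- CPU reference for the fused activation ---
# A sketch assuming ``torch_npu.npu_swiglu(x, dim)`` splits ``x`` in half along ``dim``
# and returns ``silu(a) * b``; useful for unit-testing the patched forwards on CPU.
# Verify against the torch_npu documentation before relying on it.
import torch.nn.functional as F


def swiglu_reference(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Plain-PyTorch stand-in for torch_npu.npu_swiglu."""
    gate, up = x.chunk(2, dim=dim)
    return F.silu(gate) * up


# With this reference, npu_swiglu_forward above is equivalent to the stock SwiGLU MLP:
#   down_proj(silu(gate_proj(h)) * up_proj(h))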
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused RMSNorm kernels.
Init Phase:
1. Define RMSNorm forward function.
2. Register NPU fused RMSNorm kernel.
"""
import re
import types
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
def npu_rms_norm_forward(self, hidden_states):
r"""NPU forward implementation for RMSNorm.
Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`.
hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
Returns:
Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
"""
import torch_npu
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
@register_kernel
class NpuRMSNormKernel(BaseKernel):
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
_kernel_id = "npu_fused_rmsnorm"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive).
- Bind `npu_rms_norm_forward` as an instance method via `types.MethodType` to
replace the original `forward`.
- Do not modify weights, hyperparameters, or module structure to ensure
numerical behavior and interface consistency.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with NPU fused RMSNorm.
Raises:
RuntimeError: If torch_npu is not available.
ValueError: If the model is not provided.
"""
model = kwargs.get("model")
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules():
# Match any module whose class name contains "RMSNorm"
if re.search(rms_norm_pattern, module.__class__.__name__):
# Bind function as an instance method to preserve `self` semantics
# and replace the original forward
module.forward = types.MethodType(npu_rms_norm_forward, module)
return model
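# --- CPU reference for the fused RMSNorm ---
# A sketch of the stock transformers formulation the kernel replaces;
# ``torch_npu.npu_rms_norm`` is assumed to be numerically equivalent to it.
import torch


def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Plain-PyTorch RMSNorm matching the usual HF implementation."""
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)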
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused RoPE kernels.
Init Phase:
1. Define RoPE forward functions.
2. Register NPU fused RoPE kernel.
"""
import sys
import torch
from ......accelerator.helper import DeviceType
from ......utils.logging import get_logger
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
logger = get_logger(__name__)
try:
import torch_npu
except ImportError:
pass
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
position_ids (Tensor, optional): Position IDs. Default: ``None``.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
Returns:
tuple: (q_embed, k_embed) The embedded query and key tensors.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
mrope_section (Tensor): Multimodal RoPE section.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
Returns:
tuple: (q_embed, k_embed) The embedded query and key tensors.
"""
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
@register_kernel
class NpuRoPEKernel(BaseKernel):
r"""NPU Kernel for Rotary Position Embedding."""
_kernel_id = "npu_fused_rope"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched RoPE functions.
Raises:
RuntimeError: If dependencies are not met.
ValueError: If the model is not provided.
"""
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
module_name = module.__class__.__module__
if module_name in _modules:
continue
try:
target_module = sys.modules[module_name]
if hasattr(target_module, "apply_rotary_pos_emb"):
if getattr(target_module, "apply_rotary_pos_emb") is not _apply_rotary_pos_emb:
setattr(target_module, "apply_rotary_pos_emb", _apply_rotary_pos_emb)
_modules.add(module_name)
if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
if (
getattr(target_module, "apply_multimodal_rotary_pos_emb")
is not _apply_multimodal_rotary_pos_emb_qwen25_vl
):
setattr(
target_module,
"apply_multimodal_rotary_pos_emb",
_apply_multimodal_rotary_pos_emb_qwen25_vl,
)
_modules.add(module_name)
except Exception as e:
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
return model
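# --- Reference formulation being replaced ---
# A sketch of the stock transformers ``apply_rotary_pos_emb`` (rotate_half based);
# the patch above swaps it for the single fused ``torch_npu.npu_rotary_mul`` call.
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def _apply_rotary_pos_emb_reference(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    return q_embed, k_embed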
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of kernel registry.
Init Phase:
1. Define kernel registry.
2. Register kernels.
"""
from typing import Optional
from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel
__all__ = ["Registry", "register_kernel"]
class Registry:
r"""Registry for managing kernel implementations.
Storage structure: ``{ "kernel_id": Class }``
"""
_kernels: dict[str, type[BaseKernel]] = {}
@classmethod
def register(cls, kernel_cls: type[BaseKernel]):
r"""Decorator to register a kernel class.
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
Args:
kernel_cls (type[BaseKernel]): The kernel class to register.
Returns:
type[BaseKernel]: The registered kernel class.
Raises:
TypeError: If the class does not inherit from :class:`BaseKernel`.
ValueError: If the kernel ID is missing or already registered.
"""
if not issubclass(kernel_cls, BaseKernel):
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
kernel_id = kernel_cls.get_kernel_id()
device = kernel_cls.get_device()
# Skip registration if the kernel's required device type does not match the current accelerator.
if device != get_current_accelerator().type:
return
if not kernel_id:
raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")
if kernel_id in cls._kernels:
raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")
cls._kernels[kernel_id] = kernel_cls
return kernel_cls
@classmethod
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
r"""Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
Returns:
Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
"""
return cls._kernels.get(kernel_id)
@classmethod
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
r"""Returns a dictionary of all registered kernels.
Returns:
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
"""
return cls._kernels
# export decorator alias
register_kernel = Registry.register
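# --- Usage sketch: registering a new kernel (the class below is hypothetical) ---
# Any BaseKernel subclass decorated with ``register_kernel`` is picked up by
# ``scan_all_kernels`` as long as its module lives under the ``ops`` directory and its
# ``_device`` matches the current accelerator.
#
#   @register_kernel
#   class MyFusedKernel(BaseKernel):
#       _kernel_id = "my_fused_kernel"  # unique id referenced by include_kernels
#       _device = DeviceType.NPU        # only registered when running on an NPU
#
#       @classmethod
#       def apply(cls, **kwargs):
#           model = kwargs["model"]
#           ...  # patch the model in place
#           return model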
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, TypedDict
from peft import LoraConfig, PeftModel, get_peft_model
from ...utils.plugin import BasePlugin
from ...utils.types import HFModel
class LoraConfigDict(TypedDict, total=False):
name: Literal["lora"]
"""Plugin name."""
r: int
"""Lora rank."""
lora_alpha: int
"""Lora alpha."""
target_modules: list[str]
"""Target modules."""
class FreezeConfigDict(TypedDict, total=False):
name: Literal["freeze"]
"""Plugin name."""
freeze_trainable_layers: int
"""Freeze trainable layers."""
freeze_trainable_modules: list[str] | None
"""Freeze trainable modules."""
class PeftPlugin(BasePlugin):
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
return super().__call__(model, config)
@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model
@PeftPlugin("freeze").register
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError()
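# --- Usage sketch (illustrative): wrapping a model with LoRA ---
# ``base_model`` is a hypothetical HF model. Whether the ``name`` key is stripped
# before reaching ``get_lora_model`` depends on the BasePlugin dispatch, so this
# sketch calls the registered function directly with LoraConfig fields only.
#
#   lora_config = {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
#   peft_model = get_lora_model(base_model, lora_config, is_train=True)
#   peft_model.print_trainable_parameters()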
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor
logger = logging.get_logger(__name__)
class RenderingPlugin(BasePlugin):
_attempted_template_imports: set[str] = set()
def _ensure_template_imported(self) -> None:
if self.name is None or self.name in self._attempted_template_imports:
return
full_module_name = f"{__package__}.templates.{self.name}"
self._attempted_template_imports.add(self.name)
try:
importlib.import_module(full_module_name)
except Exception as exc:
logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")
def __getitem__(self, method_name: str):
self._ensure_template_imported()
return super().__getitem__(method_name)
def render_messages(
self,
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the template format."""
return self["render_messages"](processor, messages, tools, is_generate, enable_thinking)
def parse_messages(self, generated_text: str) -> Message:
"""Parse messages in the template format."""
return self["parse_messages"](generated_text)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[float],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
def _get_last_query_index(messages: list[Message]) -> int:
"""Find the last user query index, excluding wrapped tool responses."""
last_query_index = len(messages) - 1
for idx in range(len(messages) - 1, -1, -1):
message = messages[idx]
if message["role"] != "user":
continue
user_text = ""
is_plain_text = True
for content in message["content"]:
if content["type"] != "text":
is_plain_text = False
break
user_text += content["value"]
if not is_plain_text:
continue
if not (user_text.startswith("<tool_response>") and user_text.endswith("</tool_response>")):
last_query_index = idx
break
return last_query_index
def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]:
"""Split assistant message into text, reasoning and tool calls."""
text_content = ""
reasoning_content = ""
tool_calls: list[ToolCall] = []
for content in message["content"]:
if content["type"] == "text":
text_content += content["value"]
elif content["type"] == "reasoning":
reasoning_content += content["value"]
elif content["type"] == "tool_call":
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
tool_calls.append(tool_call)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return text_content, reasoning_content, tool_calls
@RenderingPlugin("qwen3").register("render_messages")
def render_qwen3_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
last_query_index = _get_last_query_index(messages)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
text_content, reasoning_content, tool_calls = _split_assistant_content(message)
if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content):
temp_str += "<think>\n" + reasoning_content.strip("\n") + "\n</think>\n\n" + text_content.lstrip("\n")
else:
temp_str += text_content
for tool_call_idx, tool_call in enumerate(tool_calls):
if (tool_call_idx == 0 and text_content) or tool_call_idx > 0:
temp_str += "\n"
arguments = tool_call.get("arguments")
if isinstance(arguments, str):
arguments_str = arguments
else:
arguments_str = json.dumps(arguments, ensure_ascii=False)
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ arguments_str
+ "}\n</tool_call>"
)
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking is False:
temp_str += "<think>\n\n</think>\n\n"
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3").register("parse_message")
def parse_qwen3_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "think":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[float],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking:
raise ValueError("The qwen3_nothink template does not support thinking mode.")
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.objects import StatefulBuffer
from ...utils.plugin import BasePlugin
from ...utils.types import BatchInfo, BatchInput, DataLoader
class BatchingPlugin(BasePlugin):
def compute_length(self, data_provider: DataLoader) -> int:
"""Compute the length of the batch generator.
The approximate length is used to calculate the lr schedule.
"""
raise NotImplementedError()
def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
"""Fill the buffer with data."""
raise NotImplementedError()
def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
"""Generate a batch from the buffer."""
raise NotImplementedError()