"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.
The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
Modifications made by Yanning Chen, 2024.
"""
import functools
import importlib
import operator
from typing import Callable
import torch
import triton
import triton.language as tl
from packaging.version import Version
from liger_kernel.utils import infer_device
def is_hip() -> bool:
return torch.version.hip is not None
def ensure_contiguous(fn):
@functools.wraps(fn)
def wrapper(ctx, *args, **kwargs):
def maybe_to_contiguous(x):
return x.contiguous() if isinstance(x, torch.Tensor) else x
args = [maybe_to_contiguous(arg) for arg in args]
kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
return fn(ctx, *args, **kwargs)
return wrapper
def calculate_settings(n):
# reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
MAX_FUSED_SIZE = 65536
BLOCK_SIZE = triton.next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(
f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
)
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32 if not is_hip() else 16
elif BLOCK_SIZE >= 8192:
num_warps = 16
elif BLOCK_SIZE >= 2048:
num_warps = 8
return BLOCK_SIZE, num_warps
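# Illustrative example (not part of the original file): for n = 4096 the row
# size already is a power of two, and the warp thresholds above select 8 warps.
#
#   BLOCK_SIZE, num_warps = calculate_settings(4096)
#   # BLOCK_SIZE == 4096, num_warps == 8 (since 2048 <= 4096 < 8192)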
def compare_version(package: str, operator: Callable, target: str):
try:
pkg = importlib.import_module(package)
except ImportError:
return False
pkg_version = Version(pkg.__version__)
return operator(pkg_version, Version(target))
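# Illustrative example (assumption, not in the original file): check the
# installed torch version against a minimum.
#
#   compare_version("torch", operator.ge, "2.4.0")  # True iff torch >= 2.4.0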
def get_amp_custom_fwd_bwd() -> tuple[Callable, Callable]:
device = infer_device()
if compare_version("torch", operator.ge, "2.4.0"):
return (
functools.partial(torch.amp.custom_fwd, device_type=device),
functools.partial(torch.amp.custom_bwd, device_type=device),
)
if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
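# Sketch of intended usage (assumption, not in the original file): the
# resolved decorators wrap the static methods of a custom autograd Function
# so autocast casting is handled consistently across devices.
#
#   class _ScaleByTwo(torch.autograd.Function):  # hypothetical example class
#       @staticmethod
#       @amp_custom_fwd
#       def forward(ctx, x):
#           return x * 2
#
#       @staticmethod
#       @amp_custom_bwd
#       def backward(ctx, grad_output):
#           return grad_output * 2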
torch_to_triton_dtype = {
torch.float32: tl.float32,
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
}
@triton.jit
def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""
This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
The multiplication is performed in-place on the tensor pointed by X_ptr.
Parameters:
X_ptr: Pointer to the input tensor.
X_stride (int): The stride of the input tensor.
grad_output_ptr: Pointer to the gradient output value.
n_cols (int): The number of columns in the input tensor.
BLOCK_SIZE (int): The block size for Triton operations.
"""
# Get the program ID and convert it to int64 to avoid overflow
program_id = tl.program_id(0).to(tl.int64)
# Locate the start index
X_ptr += program_id * X_stride
# Load the gradient output value
grad_output = tl.load(grad_output_ptr)
# Perform the element-wise multiplication
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
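# Hypothetical host-side launch (illustrative only, not in the original file):
# one Triton program per row scales a (n_rows, n_cols) gradient tensor in
# place by a scalar grad_output.
#
#   n_rows, n_cols = 32, 1024
#   grad = torch.randn(n_rows, n_cols, device="cuda")
#   grad_output = torch.tensor(0.5, device="cuda")
#   BLOCK_SIZE, num_warps = calculate_settings(n_cols)
#   element_mul_kernel[(n_rows,)](
#       grad, grad.stride(0), grad_output, n_cols,
#       BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps,
#   )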
def get_npu_core_count(default: int = 20) -> int:
"""Return NPU vector core count.
Fallback to `default` if Triton runtime or NPU device is unavailable.
"""
try:
utils = triton.runtime.driver.active.utils
props = utils.get_device_properties(0)
return int(props.get("num_vectorcore", default))
except Exception:
return default
def set_large_grf_mode(kernel_args: dict):
"""Set large GRF mode for XPU devices."""
# On XPU triton installed along with pytorch-xpu will be called `pytorch-triton-xpu`,
# triton XPU installed from source will be called `triton`.
if compare_version("pytorch-triton-xpu", operator.ge, "3.6.0") or compare_version("triton", operator.ge, "3.6.0"):
kernel_args["grf_mode"] = "256"
else:
# API was changed in https://github.com/intel/intel-xpu-backend-for-triton/pull/5430
kernel_args["grf_mode"] = "large"
import importlib
from typing import TYPE_CHECKING
# Always-safe imports (independent of 'transformers')
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
from liger_kernel.transformers.dyt import LigerDyT # noqa: F401
from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm # noqa: F401
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
from liger_kernel.transformers.kl_div import LigerKLDIVLoss # noqa: F401
from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
from liger_kernel.transformers.mhc import LigerMHC # noqa: F401
from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
from liger_kernel.transformers.sparsemax import LigerSparsemax # noqa: F401
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
from liger_kernel.transformers.swiglu import LigerExperts # noqa: F401
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401
from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
# Static-only imports for IDEs and type checkers
if TYPE_CHECKING:
from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_exaone4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_pixtral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_5 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_5_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
# Check if 'transformers' is installed
try:
import transformers # noqa: F401
_TRANSFORMERS_AVAILABLE = True
except ImportError:
_TRANSFORMERS_AVAILABLE = False
def is_transformers_available() -> bool:
"""
Returns True if the 'transformers' package is available.
Useful for conditional logic in downstream code.
"""
return _TRANSFORMERS_AVAILABLE
def __getattr__(name: str):
"""
Handles lazy access to transformer-dependent attributes.
If 'transformers' is not installed, raises a user-friendly ImportError.
"""
if not _TRANSFORMERS_AVAILABLE:
raise ImportError(
f"The attribute '{name}' requires the 'transformers' library, which is not installed.\n"
f"Please install it with `pip install transformers` to use this functionality."
)
if name == "AutoLigerKernelForCausalLM":
module = importlib.import_module("liger_kernel.transformers.auto_model")
return getattr(module, name)
monkey_patch_symbols = {
"_apply_liger_kernel",
"_apply_liger_kernel_to_instance",
"apply_liger_kernel_to_falcon_h1",
"apply_liger_kernel_to_gemma",
"apply_liger_kernel_to_gemma2",
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4v_moe",
"apply_liger_kernel_to_gpt_oss",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_internvl",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
"apply_liger_kernel_to_llama4",
"apply_liger_kernel_to_mistral",
"apply_liger_kernel_to_mixtral",
"apply_liger_kernel_to_mllama",
"apply_liger_kernel_to_olmo2",
"apply_liger_kernel_to_olmo3",
"apply_liger_kernel_to_paligemma",
"apply_liger_kernel_to_phi3",
"apply_liger_kernel_to_pixtral",
"apply_liger_kernel_to_qwen2",
"apply_liger_kernel_to_qwen2_5_vl",
"apply_liger_kernel_to_qwen2_vl",
"apply_liger_kernel_to_qwen3",
"apply_liger_kernel_to_qwen3_moe",
"apply_liger_kernel_to_qwen3_5",
"apply_liger_kernel_to_qwen3_5_moe",
"apply_liger_kernel_to_qwen3_next",
"apply_liger_kernel_to_qwen3_vl",
"apply_liger_kernel_to_qwen3_vl_moe",
"apply_liger_kernel_to_smollm3",
"apply_liger_kernel_to_smolvlm",
"apply_liger_kernel_to_hunyuan_v1_dense",
"apply_liger_kernel_to_hunyuan_v1_moe",
"apply_liger_kernel_to_exaone4",
}
if name in monkey_patch_symbols:
module = importlib.import_module("liger_kernel.transformers.monkey_patch")
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
# Shared symbols in all environments
__all__ = [
"is_transformers_available",
"LigerCrossEntropyLoss",
"LigerDyT",
"LigerFusedLinearCrossEntropyLoss",
"LigerFusedLinearJSD",
"LigerGEGLUMLP",
"LigerJSD",
"LigerLayerNorm",
"LigerFusedAddRMSNorm",
"LigerPolyNorm",
"LigerRMSNorm",
"liger_rotary_pos_emb",
"liger_llama4_text_rotary_pos_emb",
"liger_llama4_vision_rotary_pos_emb",
"LigerBlockSparseTop2MLP",
"LigerPhi3SwiGLUMLP",
"LigerQwen3MoeSwiGLUMLP",
"LigerSwiGLUMLP",
"LigerTiledGEGLUMLP",
"LigerTiledSwiGLUMLP",
"LigerTVDLoss",
"LigerKLDIVLoss",
"LigerMHC",
"LigerMultiTokenAttention",
"LigerSoftmax",
"LigerSparsemax",
]
# Add transformer-dependent symbols only if available
if _TRANSFORMERS_AVAILABLE:
__all__.extend(
[
"AutoLigerKernelForCausalLM",
"_apply_liger_kernel",
"_apply_liger_kernel_to_instance",
"apply_liger_kernel_to_falcon_h1",
"apply_liger_kernel_to_gemma",
"apply_liger_kernel_to_gemma2",
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4v_moe",
"apply_liger_kernel_to_gpt_oss",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_internvl",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
"apply_liger_kernel_to_llama4",
"apply_liger_kernel_to_mistral",
"apply_liger_kernel_to_mixtral",
"apply_liger_kernel_to_mllama",
"apply_liger_kernel_to_olmo2",
"apply_liger_kernel_to_olmo3",
"apply_liger_kernel_to_paligemma",
"apply_liger_kernel_to_phi3",
"apply_liger_kernel_to_pixtral",
"apply_liger_kernel_to_qwen2",
"apply_liger_kernel_to_qwen2_5_vl",
"apply_liger_kernel_to_qwen2_vl",
"apply_liger_kernel_to_qwen3",
"apply_liger_kernel_to_qwen3_moe",
"apply_liger_kernel_to_qwen3_5",
"apply_liger_kernel_to_qwen3_5_moe",
"apply_liger_kernel_to_qwen3_next",
"apply_liger_kernel_to_qwen3_vl",
"apply_liger_kernel_to_qwen3_vl_moe",
"apply_liger_kernel_to_smollm3",
"apply_liger_kernel_to_smolvlm",
"apply_liger_kernel_to_hunyuan_v1_dense",
"apply_liger_kernel_to_hunyuan_v1_moe",
"apply_liger_kernel_to_exaone4",
]
)
import inspect
import logging
from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
logger = logging.getLogger(__name__)
def _get_model_config(model_dir, **model_init_kwargs):
config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
return config
class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
"""
This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
if applicable.
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)
# Determine the model type and apply the Liger Kernel if applicable
# Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
model_type = model_config.model_type
_apply_liger_kernel(model_type, **kwargs)
# Filter out kwargs that were passed to the apply_liger_* function, which will cause
# model initialization errors otherwise
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
apply_fn_signature = inspect.signature(apply_fn)
applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
@classmethod
def from_config(cls, config, **kwargs):
model_type = getattr(config, "model_type", None)
        if not model_type:
            logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
            return super().from_config(config, **kwargs)
_apply_liger_kernel(model_type, **kwargs)
# Filter out kwargs that were passed to the apply_liger_* function, which will cause
# model initialization errors otherwise
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
apply_fn_signature = inspect.signature(apply_fn)
applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
return super().from_config(config, **applicable_kwargs)
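# Illustrative usage (checkpoint name and kernel kwargs are examples, not
# prescriptions): kwargs understood by the matching apply_liger_kernel_to_*
# function are consumed by the patch and filtered out before model init.
#
#   model = AutoLigerKernelForCausalLM.from_pretrained(
#       "meta-llama/Meta-Llama-3-8B",  # hypothetical checkpoint
#       rope=True,
#       rms_norm=True,
#   )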
from typing import Optional
import torch
from liger_kernel.ops import LigerCrossEntropyFunction
from liger_kernel.transformers.functional import CrossEntropyOutput
class LigerCrossEntropyLoss(torch.nn.Module):
def __init__(
self,
weight: Optional[torch.FloatTensor] = None,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_token_accuracy: bool = False,
return_predicted_tokens: bool = False,
):
super().__init__()
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
)
assert reduction in {
"mean",
"sum",
"none",
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
        assert softcap is None or softcap > 0, f"softcap must be greater than 0.0 or None. Got: {softcap}"
self.weight = weight
self.ignore_index = ignore_index
self.lse_square_scale = lse_square_scale
self.label_smoothing = label_smoothing
self.reduction = reduction
self.softcap = softcap
self.return_z_loss = return_z_loss
self.return_token_accuracy = return_token_accuracy
self.return_predicted_tokens = return_predicted_tokens
def forward(self, _input: torch.Tensor, target: torch.Tensor):
loss, z_loss, token_accuracy, predicted_tokens = LigerCrossEntropyFunction.apply(
_input,
target,
self.weight,
self.ignore_index,
self.lse_square_scale,
self.label_smoothing,
self.reduction,
self.softcap,
self.return_z_loss,
self.return_token_accuracy,
self.return_predicted_tokens,
)
if not self.return_z_loss and not self.return_token_accuracy and not self.return_predicted_tokens:
return loss
return CrossEntropyOutput(
loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens
)
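# Minimal usage sketch (shapes are illustrative): logits are (BT, V) and
# targets are (BT,), mirroring torch.nn.CrossEntropyLoss on flattened tokens.
#
#   loss_fn = LigerCrossEntropyLoss()
#   logits = torch.randn(8, 128, device="cuda", requires_grad=True)
#   target = torch.randint(0, 128, (8,), device="cuda")
#   loss = loss_fn(logits, target)
#   loss.backward()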
import torch
import torch.nn as nn
from liger_kernel.ops import LigerDyTFunction
class LigerDyT(nn.Module):
def __init__(self, hidden_size, beta=True, init_alpha=0.5):
super().__init__()
self.hidden_size = hidden_size
self.init_alpha = init_alpha
self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
self.gamma = nn.Parameter(torch.ones(hidden_size))
self.beta = None
if beta:
self.beta = nn.Parameter(torch.zeros(hidden_size))
def forward(self, x):
return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
def extra_repr(self):
return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta}"
from liger_kernel.transformers.experimental.embedding import LigerEmbedding # noqa: F401
__all__ = [
"LigerEmbedding",
]
from typing import Optional
import torch
import torch.nn as nn
from liger_kernel.ops import LigerEmbeddingFunction
class LigerEmbedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
if padding_idx is not None:
with torch.no_grad():
self.weight[padding_idx].fill_(0)
def forward(self, indices):
embedded = LigerEmbeddingFunction.apply(self.weight, indices)
if self.padding_idx is not None:
embedded = embedded.clone()
embedded[indices == self.padding_idx] = 0
return embedded
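# Usage sketch (illustrative): a drop-in embedding lookup; rows matching
# padding_idx are zeroed both at init and in the output.
#
#   emb = LigerEmbedding(num_embeddings=100, embedding_dim=32, padding_idx=0).cuda()
#   ids = torch.randint(0, 100, (4, 16), device="cuda")
#   out = emb(ids)  # (4, 16, 32); positions where ids == 0 map to zeros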
from typing import Any
from typing import Callable
from torch.distributed.fsdp import FullyShardedDataParallel
class _FSDPForwardRedirection:
"""
Modified based on
https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
post-forward can be properly executed around the method call.
This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
"""
def __call__(
self,
wrapper_module: FullyShardedDataParallel,
method: Callable,
*args: Any,
**kwargs: Any,
):
"""Reroutes a method call through the `wrapper_module`'s `forward` method.
        Args:
            wrapper_module: The FSDP module whose `forward` should wrap the redirected call.
            method: The method of the wrapped (inner) module that should be called after inputs get
                redirected through the `wrapper_module`'s `forward` method.
            *args: The positional arguments to `method`. They will get passed to a patched
                `forward` method instead.
            **kwargs: The keyword arguments to `method`. They will get passed to a patched
                `forward` method instead.
"""
assert isinstance(wrapper_module, FullyShardedDataParallel)
original_module = wrapper_module._fsdp_wrapped_module
original_forward = original_module.forward
def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling `method`,
            # because the method itself may want to call the real `forward`
original_module.forward = original_forward # type: ignore[method-assign]
# Call the actual method e.g. `.training_step(...)`
out = method(*_args, **_kwargs)
return out
# Patch the original_module's forward so we can redirect the arguments back to the real method
original_module.forward = wrapped_forward # type: ignore[method-assign]
wrapper_output = wrapper_module(*args, **kwargs)
return wrapper_output
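# Usage sketch (attribute paths are hypothetical and depend on the wrapped
# model): call just the backbone, then just the lm_head, of an FSDP-wrapped
# causal LM while still triggering FSDP's root pre/post-forward.
#
#   redirect = _FSDPForwardRedirection()
#   hidden = redirect(fsdp_model, fsdp_model.module.model.forward, input_ids)
#   logits = redirect(fsdp_model, fsdp_model.module.lm_head.forward, hidden)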
from dataclasses import dataclass
from typing import Optional
import torch
from liger_kernel.ops import LigerCrossEntropyFunction
from liger_kernel.ops import LigerDyTFunction
from liger_kernel.ops import LigerFusedAddRMSNormFunction
from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
from liger_kernel.ops import LigerFusedLinearJSDFunction
from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
from liger_kernel.ops import LigerGELUMulFunction
from liger_kernel.ops import LigerGroupNormFunction
from liger_kernel.ops import LigerJSDFunction
from liger_kernel.ops import LigerKLDivLossFunction
from liger_kernel.ops import LigerLayerNormFunction
from liger_kernel.ops import LigerMHCCoeffsFunction
from liger_kernel.ops import LigerMHCPostResFunction
from liger_kernel.ops import LigerMHCPreFunction
from liger_kernel.ops import LigerMultiTokenAttentionFunction
from liger_kernel.ops import LigerPolyNormFunction
from liger_kernel.ops import LigerQwen2VLMRopeFunction
from liger_kernel.ops import LigerRMSNormFunction
from liger_kernel.ops import LigerRopeFunction
from liger_kernel.ops import LigerSiLUMulFunction
from liger_kernel.ops import LigerSoftmaxFunction
from liger_kernel.ops import LigerSparsemaxFunction
from liger_kernel.ops import LigerTVDLossFunction
@dataclass
class CrossEntropyOutput:
loss: torch.Tensor
z_loss: Optional[torch.Tensor] = None
token_accuracy: Optional[torch.Tensor] = None
predicted_tokens: Optional[torch.Tensor] = None
# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
# `size_average` and `reduce` are deprecated placeholders and are not implemented
def liger_cross_entropy(
input,
target,
weight=None,
size_average=None,
ignore_index: int = -100,
reduce=None,
reduction: str = "mean",
label_smoothing: float = 0.0,
lse_square_scale: float = 0.0,
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_token_accuracy: bool = False,
return_predicted_tokens: bool = False,
):
loss, z_loss, token_accuracy, predicted_tokens = LigerCrossEntropyFunction.apply(
input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap,
return_z_loss,
return_token_accuracy,
return_predicted_tokens,
)
if not return_z_loss and not return_token_accuracy and not return_predicted_tokens:
return loss
return CrossEntropyOutput(
loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens
)
def liger_fused_linear_cross_entropy(
input,
weight,
target,
bias=None,
ce_weight=None,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
accum_dtype=None,
use_token_scaling: bool = False,
return_token_accuracy: bool = False,
return_predicted_tokens: bool = False,
):
loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply(
input,
weight,
target,
bias,
ce_weight,
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap,
return_z_loss,
accum_dtype,
use_token_scaling,
return_token_accuracy,
return_predicted_tokens,
)
if not return_z_loss and not return_token_accuracy and not return_predicted_tokens:
return loss
return CrossEntropyOutput(
loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens
)
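# Shape sketch (illustrative): input is (BT, H), weight is (V, H), target is
# (BT,); the fused op avoids materializing the full (BT, V) logits tensor.
#
#   x = torch.randn(16, 256, device="cuda", requires_grad=True)
#   w = torch.randn(1000, 256, device="cuda", requires_grad=True)
#   t = torch.randint(0, 1000, (16,), device="cuda")
#   loss = liger_fused_linear_cross_entropy(x, w, t)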
def liger_fused_linear_jsd(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels=None,
jsd_beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
return LigerFusedLinearJSDFunction.apply(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
temperature,
)
def liger_geglu(a, b):
return LigerGELUMulFunction.apply(a, b)
def liger_group_norm(
X,
affine_scaling_weight,
affine_shifting_bias,
num_channels,
num_groups,
eps,
):
return LigerGroupNormFunction.apply(
X,
affine_scaling_weight,
affine_shifting_bias,
num_channels,
num_groups,
eps,
)
def liger_jsd(
input,
target,
shift_labels=None,
beta: float = 0.5,
ignore_index: int = -100,
):
return LigerJSDFunction.apply(
input,
target,
shift_labels,
beta,
ignore_index,
)
# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
# `size_average` and `reduce` are being deprecated in the torch API and are placeholders here
def liger_kl_div(
input,
target,
size_average: bool = True,
reduce: bool = True,
reduction: str = "mean",
log_target: bool = False,
eps: float = 1e-10,
):
    # Note: the default reduction in torch is `mean`, but it is `batchmean` in Liger
return LigerKLDivLossFunction.apply(
input,
target,
reduction,
log_target,
eps,
)
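# Usage sketch (illustrative): both arguments are in log space here; as noted
# above, reduction="mean" maps to a batchmean-style normalization in Liger.
#
#   q_log = torch.randn(8, 32, device="cuda", requires_grad=True).log_softmax(-1)
#   p_log = torch.randn(8, 32, device="cuda").log_softmax(-1)
#   out = liger_kl_div(q_log, p_log, log_target=True)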
def liger_sparsemax(
input,
dim: int = -1,
):
return LigerSparsemaxFunction.apply(input, dim)
def liger_multi_token_attention(
scores,
weight,
bias=None,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
sparse: bool = False,
):
"""
Functional interface for multi-token attention.
Args:
scores: Input tensor of shape (B, C_in, L, L)
weight: Convolution weight tensor of shape (C_out, C_in // groups, K, K)
bias: Optional bias tensor of shape (C_out,)
stride: Stride for the convolution (default: 1)
padding: Padding for the convolution (default: 0)
dilation: Dilation factor for the convolution (default: 1)
groups: Number of groups for the convolution (default: 1)
sparse: Specifies if input tensors are expected to be sparse (default: False)
Returns:
Output tensor after applying multi-token attention.
"""
return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)
def liger_fused_neighborhood_attention(
query,
key,
value,
kernel_size: int = 7,
dilation: int = 1,
    scale: Optional[float] = None,
):
"""
Liger fused neighborhood attention.
paper: https://arxiv.org/pdf/2504.16922
Args:
query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
kernel_size: Size of the neighborhood window (default: 7)
dilation: Dilation factor for the neighborhood (default: 1)
scale: Scaling factor for attention scores (default: rsqrt(head_dim))
Returns:
Output tensor of shape [batch_size, num_heads, seq_len, head_dim]
"""
return LigerFusedNeighborhoodAttentionFunction.apply(query, key, value, kernel_size, dilation, scale)
def liger_tvd(
input,
target,
shift_labels=None,
reduction: str = "mean",
ignore_index: int = -100,
):
return LigerTVDLossFunction.apply(
input,
target,
shift_labels,
reduction,
ignore_index,
)
def liger_layer_norm(X, W, B, eps):
return LigerLayerNormFunction.apply(X, W, B, eps)
def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
return LigerPolyNormFunction.apply(X, W, B, eps, in_place)
def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)
def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
def liger_swiglu(a, b):
return LigerSiLUMulFunction.apply(a, b)
def liger_softmax(x):
return LigerSoftmaxFunction.apply(x)
def liger_dyt(x, alpha, gamma, beta):
return LigerDyTFunction.apply(x, alpha, gamma, beta)
def liger_mhc_coeffs(
x,
phi,
b,
alpha_pre,
alpha_post,
alpha_res,
*,
allow_fp32: bool = False,
tmax: int = 20,
rms_eps: float = 1e-6,
pre_eps: float = 0.0,
sinkhorn_eps: float = 1e-6,
post_mult: float = 2.0,
):
# Convert config scalars to Python types so they are not included in the
# autograd computation graph (they are not learnable parameters).
return LigerMHCCoeffsFunction.apply(
x,
phi,
b,
alpha_pre,
alpha_post,
alpha_res,
allow_fp32,
int(tmax),
float(rms_eps),
float(pre_eps),
float(sinkhorn_eps),
float(post_mult),
)
def liger_mhc_pre(x, h_pre):
return LigerMHCPreFunction.apply(x, h_pre)
def liger_mhc_post_res(x, f_out, h_post, h_res):
return LigerMHCPostResFunction.apply(x, f_out, h_post, h_res)
def liger_mhc_apply(x, f_out, h_pre, h_post, h_res, *, return_x_in: bool = False):
x_in = liger_mhc_pre(x, h_pre)
x_out = liger_mhc_post_res(x, f_out, h_post, h_res)
if return_x_in:
return x_out, x_in
return x_out
def liger_mhc_forward(
x,
layer,
phi,
b,
alpha_pre,
alpha_post,
alpha_res,
*,
allow_fp32=False,
tmax=20,
rms_eps=1e-6,
pre_eps=0.0,
sinkhorn_eps=1e-6,
post_mult=2.0,
return_coeffs=False,
):
"""High-level helper: compute coeffs, apply pre, run layer, then apply post+res."""
h_pre, h_post, h_res = liger_mhc_coeffs(
x,
phi,
b,
alpha_pre,
alpha_post,
alpha_res,
allow_fp32=allow_fp32,
tmax=tmax,
rms_eps=rms_eps,
pre_eps=pre_eps,
sinkhorn_eps=sinkhorn_eps,
post_mult=post_mult,
)
x_in = liger_mhc_pre(x, h_pre)
layer_dtype = x_in.dtype
if hasattr(layer, "parameters"):
try:
layer_dtype = next(layer.parameters()).dtype
except StopIteration:
layer_dtype = x_in.dtype
if x_in.dtype != layer_dtype:
x_in = x_in.to(layer_dtype)
f_out = layer(x_in)
x_out = liger_mhc_post_res(x, f_out, h_post, h_res)
if return_coeffs:
return x_out, (h_pre, h_post, h_res)
return x_out
import torch
import torch.nn as nn
from liger_kernel.ops import LigerFusedAddRMSNormFunction
class LigerFusedAddRMSNorm(nn.Module):
def __init__(
self,
hidden_size,
eps=1e-6,
offset=0.0,
casting_mode="llama",
init_fn="ones",
in_place=False,
):
super().__init__()
assert init_fn in [
"ones",
"zeros",
], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (eps, offset, casting_mode, in_place)
def forward(self, hidden_states, residual):
return LigerFusedAddRMSNormFunction.apply(
hidden_states,
residual,
self.weight,
self.variance_epsilon,
self.offset,
self.casting_mode,
self.in_place,
)
def extra_repr(self):
return (
f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
)
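# Usage sketch (illustrative; the exact return convention follows
# LigerFusedAddRMSNormFunction): the residual addition is fused into the
# normalization instead of running as a separate kernel.
#
#   norm = LigerFusedAddRMSNorm(hidden_size=64).cuda()
#   h = torch.randn(2, 8, 64, device="cuda")
#   r = torch.randn(2, 8, 64, device="cuda")
#   out = norm(h, r)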
from typing import Optional
import torch
from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
from liger_kernel.transformers.functional import CrossEntropyOutput
class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
def __init__(
self,
ce_weight: Optional[torch.FloatTensor] = None,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
accum_dtype: Optional[torch.dtype] = None,
use_token_scaling: bool = False,
return_token_accuracy: bool = False,
return_predicted_tokens: bool = False,
):
super().__init__()
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
)
assert reduction in {
"mean",
"sum",
"none",
}, f"reduction must be 'mean' or 'sum' or 'none'. Got: {reduction}"
        assert softcap is None or softcap > 0, f"softcap must be greater than 0.0 or None. Got: {softcap}"
self.ce_weight = ce_weight
self.ignore_index = ignore_index
self.lse_square_scale = lse_square_scale
self.label_smoothing = label_smoothing
self.reduction = reduction
self.softcap = softcap
self.return_z_loss = return_z_loss
self.accum_dtype = accum_dtype
self.use_token_scaling = use_token_scaling
self.return_token_accuracy = return_token_accuracy
self.return_predicted_tokens = return_predicted_tokens
def forward(self, lin_weight, _input, target, bias=None):
loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply(
_input,
lin_weight,
target,
bias,
self.ce_weight,
self.ignore_index,
self.lse_square_scale,
self.label_smoothing,
self.reduction,
self.softcap,
self.return_z_loss,
self.accum_dtype,
self.use_token_scaling,
self.return_token_accuracy,
self.return_predicted_tokens,
)
if not self.return_z_loss and not self.return_token_accuracy and not self.return_predicted_tokens:
return loss
return CrossEntropyOutput(
loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens
)
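# Usage sketch (illustrative): note the argument order of forward, which
# takes the linear weight first: forward(lin_weight, _input, target, bias).
#
#   lce = LigerFusedLinearCrossEntropyLoss()
#   lm_head = torch.nn.Linear(256, 1000, bias=False, device="cuda")
#   hidden = torch.randn(16, 256, device="cuda", requires_grad=True)
#   labels = torch.randint(0, 1000, (16,), device="cuda")
#   loss = lce(lm_head.weight, hidden, labels)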
from typing import Optional
import torch
from liger_kernel.ops import LigerFusedLinearJSDFunction
class LigerFusedLinearJSD(torch.nn.Module):
r"""Fusing the last linear layer with generalized JSD
Handle the forward and backward pass of the final linear layer via JSD by avoiding
the materialization of the large logits tensor.
Args:
jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
ignore_index (int): The index to ignore in the target. Default: `-100`
temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
Shape:
- student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension.
- student_weight: :math:`(V, H)`, where V is vocab size.
- teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model.
- teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different.
- shift_labels: :math:`(BT,)`
- Output: a scalar.
Examples:
```python
>>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10)
>>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0)
>>> # generate inputs and weights
>>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True)
>>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda")
>>> # teacher input doesn't require grad, hidden_dim can be different from student's
>>> teacher_input = torch.rand(B * T, H_t, device="cuda")
>>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda")
>>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight)
>>> output.backward()
>>>
>>> # Example with labels for supervised fine-tuning (SFT) context:
>>>
>>> # Assume hidden_states, lm_heads and corresponding labels are given
    >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False)
    >>> student_hidden_states = torch.randn(B, T, H_s, requires_grad=True)
    >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False)
    >>> teacher_hidden_states = torch.randn(B, T, H_t)
    >>> labels = torch.randint(0, V, (B, T), dtype=torch.long)
    >>>
    >>> # Shift so that tokens < n predict n
    >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous()
    >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous()
    >>> shift_labels = labels[..., 1:].contiguous()
    >>>
    >>> # Flatten tokens
    >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, H_s)
    >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, H_t)
    >>> shift_labels = shift_labels.view(-1)
    >>>
    >>> # Calculate loss
    >>> loss_fct = LigerFusedLinearJSD(jsd_beta=0.1)
    >>> loss = loss_fct(
    >>>     shift_student_hidden_states,
    >>>     student_lm_head.weight,
    >>>     shift_teacher_hidden_states,
    >>>     teacher_lm_head.weight,
    >>>     shift_labels
    >>> )
```
"""
def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
super().__init__()
assert temperature != 0, "temperature cannot be 0."
self.jsd_beta = jsd_beta
self.temperature = temperature
self.ignore_index = ignore_index
def forward(
self,
student_input: torch.Tensor,
student_weight: torch.Tensor,
teacher_input: torch.Tensor,
teacher_weight: torch.Tensor,
        shift_labels: Optional[torch.LongTensor] = None,
):
return LigerFusedLinearJSDFunction.apply(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
self.jsd_beta,
self.ignore_index,
self.temperature,
)
import math
from typing import Optional
import torch
import torch.nn as nn
from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
class LigerFusedNeighborhoodAttention(nn.Module):
"""
Liger Fused Neighborhood Attention Module.
Paper: https://arxiv.org/pdf/2504.16922
Fused Neighborhood attention restricts the attention mechanism to a local neighborhood
around each position, reducing computational complexity from O(n²) to O(n*k)
where k is the neighborhood size.
Args:
hidden_size (int): The hidden dimension size
num_heads (int): Number of attention heads
kernel_size (int): Size of the neighborhood window (default: 7)
dilation (int): Dilation factor for the neighborhood (default: 1)
bias (bool): Whether to use bias in linear projections (default: True)
dropout (float): Dropout probability (default: 0.0)
scale (Optional[float]): Scaling factor for attention scores.
If None, uses 1/sqrt(head_dim) (default: None)
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
kernel_size: int = 7,
dilation: int = 1,
bias: bool = True,
dropout: float = 0.0,
scale: Optional[float] = None,
):
super().__init__()
if hidden_size % num_heads != 0:
raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
if kernel_size <= 0:
raise ValueError(f"kernel_size ({kernel_size}) must be positive")
if kernel_size % 2 == 0:
raise ValueError(f"kernel_size ({kernel_size}) must be odd")
if dilation < 1:
raise ValueError(f"dilation ({dilation}) must be positive")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.kernel_size = kernel_size
self.dilation = dilation
self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim)
self.dropout_p = dropout
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
if dropout > 0.0:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of the fused neighborhood attention module.
Args:
hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
Returns:
torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
"""
if attention_mask is not None:
raise NotImplementedError("Attention mask is not yet supported in LigerFusedNeighborhoodAttention")
batch_size, seq_len, hidden_size = hidden_states.shape
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attn_output = LigerFusedNeighborhoodAttentionFunction.apply(
query, key, value, self.kernel_size, self.dilation, self.scale
)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
if self.dropout is not None:
attn_output = self.dropout(attn_output)
output = self.out_proj(attn_output)
return output
def extra_repr(self) -> str:
return (
f"hidden_size={self.hidden_size}, num_heads={self.num_heads}, "
f"head_dim={self.head_dim}, kernel_size={self.kernel_size}, "
f"dilation={self.dilation}, scale={self.scale}, dropout={self.dropout_p}"
)
class LigerFusedNeighborhoodAttentionLayer(nn.Module):
"""
A complete neighborhood attention layer with layer norm and residual connection.
Args:
hidden_size (int): The hidden dimension size
num_heads (int): Number of attention heads
kernel_size (int): Size of the neighborhood window (default: 7)
dilation (int): Dilation factor for the neighborhood (default: 1)
bias (bool): Whether to use bias in linear projections (default: True)
dropout (float): Dropout probability (default: 0.0)
layer_norm_eps (float): Epsilon for layer normalization (default: 1e-5)
scale (Optional[float]): Scaling factor for attention scores (default: None)
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
kernel_size: int = 7,
dilation: int = 1,
bias: bool = True,
dropout: float = 0.0,
layer_norm_eps: float = 1e-5,
scale: Optional[float] = None,
):
super().__init__()
self.attention = LigerFusedNeighborhoodAttention(
hidden_size=hidden_size,
num_heads=num_heads,
kernel_size=kernel_size,
dilation=dilation,
bias=bias,
dropout=dropout,
scale=scale,
)
self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
if dropout > 0.0:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass with residual connection and layer normalization.
Args:
hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
Returns:
torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
"""
normed_hidden_states = self.layer_norm(hidden_states)
attn_output = self.attention(normed_hidden_states, attention_mask)
if self.dropout is not None:
attn_output = self.dropout(attn_output)
output = hidden_states + attn_output
return output
class LigerFusedNeighborhoodAttentionConfig:
"""
Configuration class for Fused Neighborhood Attention.
This can be used to easily configure neighborhood attention parameters
for different model architectures.
"""
def __init__(
self,
hidden_size: int = 768,
num_heads: int = 12,
kernel_size: int = 7,
dilation: int = 1,
bias: bool = True,
dropout: float = 0.0,
layer_norm_eps: float = 1e-5,
scale: Optional[float] = None,
):
self.hidden_size = hidden_size
self.num_heads = num_heads
self.kernel_size = kernel_size
self.dilation = dilation
self.bias = bias
self.dropout = dropout
self.layer_norm_eps = layer_norm_eps
self.scale = scale
def to_dict(self):
return {
"hidden_size": self.hidden_size,
"num_heads": self.num_heads,
"kernel_size": self.kernel_size,
"dilation": self.dilation,
"bias": self.bias,
"dropout": self.dropout,
"layer_norm_eps": self.layer_norm_eps,
"scale": self.scale,
}
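# Usage sketch (illustrative): the config's to_dict() keys match the layer's
# constructor, so a layer can be built directly from a config.
#
#   cfg = LigerFusedNeighborhoodAttentionConfig(hidden_size=256, num_heads=8)
#   layer = LigerFusedNeighborhoodAttentionLayer(**cfg.to_dict()).cuda()
#   x = torch.randn(2, 64, 256, device="cuda")
#   y = layer(x)  # (2, 64, 256), with a residual connection around attention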
import torch.nn as nn
from liger_kernel.ops import LigerGELUMulFunction
class LigerGEGLUMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
# TODO: support exact GELU
# Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh`
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46
    # So we can safely assume we use the tanh approximation form all the time
def forward(self, x):
return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
import torch
import torch.nn as nn
from liger_kernel.ops import LigerGroupNormFunction
class LigerGroupNorm(nn.Module):
def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones"):
"""
A Group Normalization layer.
Args:
num_channels (int): Number of channels in the input tensor.
num_groups (int): Number of groups to divide the channels into.
eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-6.
bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``.
init_fn (str, optional): Initialization function for the learnable parameters. Default: "ones".
"""
super().__init__()
assert init_fn in [
"ones",
"zeros",
], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
assert num_channels % num_groups == 0, (
f"Number of channels {num_channels} must be divisible by num_groups {num_groups}"
)
self.num_channels = num_channels
self.num_groups = num_groups
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels))
self.bias = nn.Parameter(torch.randn(num_channels) if bias else torch.zeros(num_channels))
self.variance_epsilon = eps
def forward(self, hidden_states):
# hidden_states: (batch_size, num_channels, *)
        assert hidden_states.dim() >= 3, f"Input must have at least 3 dimensions, got {hidden_states.dim()}"
assert hidden_states.size(1) == self.num_channels, (
f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"
)
return LigerGroupNormFunction.apply(
hidden_states,
self.weight,
self.bias,
self.num_channels,
self.num_groups,
self.variance_epsilon,
)
def extra_repr(self):
return f"{self.hidden_size}, num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.eps}"
import torch
from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
from liger_kernel.ops import GrpoLossFunction
def triton_grpo_loss(
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask=None,
temperature=0.9,
beta=0.04,
eps_low=0.2,
eps_high=0.4,
inplace=True,
loss_type="dapo",
max_completion_length=None,
importance_sampling_level="token",
reduce=False,
sapo_temperature_pos=1.0,
sapo_temperature_neg=1.05,
vllm_is_ratio=None,
delta=None,
use_bias_correction_kl=False,
):
"""
Triton-optimized GRPO loss function.
Args:
logits: Model logits (B, L+1, V)
old_logp: Old policy log probabilities (B, L) or None
ref_logp: Reference model log probabilities (B, L) or None (required if beta != 0)
completion_ids: Token IDs for completions (B, L)
advantages: Per-sequence advantages (B,)
completion_mask: Mask for valid tokens (B, L) or None
temperature: Temperature for log softmax
beta: KL penalty coefficient
eps_low: Lower clipping bound for importance ratio
eps_high: Upper clipping bound for importance ratio
inplace: Whether to modify logits in-place during backward
loss_type: Loss reduction type ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo")
max_completion_length: Max completion length for dr_grpo loss type; defaults to sequence length if None
importance_sampling_level: "token" or "sequence" importance sampling
reduce: If True, return reduced loss; if False, return per-token loss
vllm_is_ratio: vLLM importance sampling ratio (B, L) or (B, 1) or None.
Used to correct for distribution mismatch when using vLLM for generation.
Applied to PPO loss BEFORE adding KL penalty.
delta: Upper clamp for two-sided clipping (INTELLECT-2). When set, coef_1 is clamped
to max=delta before computing the PPO loss. Only supported for standard PPO loss
types (grpo, bnpo, dr_grpo, dapo, luspo). None means disabled.
use_bias_correction_kl: If True, multiply KL divergence by coef_1 (importance sampling
ratio) for bias-corrected KL estimation (DeepSeek-V3.2). Default False.
Returns:
If reduce=True: (loss, metrics) where metrics = [kl_mean, clip_ratio] or [clip_ratio]
If reduce=False: (per_token_loss, per_token_kl, is_clipped)
"""
assert logits is not None and completion_ids is not None and advantages is not None, (
"must provide logits, completion_ids and advantages"
)
assert importance_sampling_level in ("token", "sequence"), (
f"importance_sampling_level must be 'token' or 'sequence', got {importance_sampling_level}"
)
result = GrpoLossFunction.apply(
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
temperature,
beta,
eps_low,
eps_high,
inplace,
loss_type,
max_completion_length,
reduce,
importance_sampling_level,
sapo_temperature_pos,
sapo_temperature_neg,
vllm_is_ratio,
delta,
use_bias_correction_kl,
)
if not reduce:
# Returns (per_token_loss, per_token_kl, is_clipped) - all (B, L) tensors
return result
# reduce=True: Returns (reduced_loss, kl_mean, clip_ratio) - all scalars
reduced_loss, kl_mean, clip_ratio = result
metrics = []
if beta != 0.0 and kl_mean is not None:
metrics.append(kl_mean)
metrics.append(clip_ratio)
return reduced_loss, metrics
def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
mask = completion_mask
if mask is None:
mask = torch.ones_like(per_token_loss, dtype=per_token_loss.dtype, device=per_token_loss.device)
mask = mask.to(per_token_loss.dtype)
if loss_type == "grpo" or loss_type == "sapo":
# SAPO uses the same normalization as GRPO (per-sequence average)
per_seq = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
return per_seq.mean()
if loss_type == "bnpo":
return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
if loss_type == "dr_grpo":
batch = per_token_loss.shape[0]
max_len = max_completion_length if max_completion_length is not None else per_token_loss.shape[1]
return (per_token_loss * mask).sum() / (batch * max_len)
if loss_type == "dapo" or loss_type == "cispo":
# CISPO uses the same normalization as DAPO
normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask)
return (per_token_loss * mask).sum() / normalizer
if loss_type == "luspo":
# LUSPO: scale each sequence's loss by its valid token count, then average across sequences
return (per_token_loss * mask.sum(-1, keepdim=True)).mean()
raise ValueError(f"Unsupported loss_type '{loss_type}' for Triton GRPO loss.")
def _masked_mean(values, mask):
if mask is None:
mask = torch.ones_like(values, dtype=values.dtype, device=values.device)
mask = mask.to(values.dtype)
return (values * mask).sum() / mask.sum().clamp(min=1.0)
# This is a demo of how to use grpo_loss in GRPOTrainer. The TRL version must be 0.26.2+.
"""
import torch
import trl
from packaging.version import Version
assert Version(trl.__version__) >= Version("0.26.2"), "please pip install trl>=0.26.2"
from trl.extras.profiling import profiling_decorator
@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
    # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)
@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
ref_per_token_logps = inputs["ref_per_token_logps"]
advantages = inputs["advantages"]
old_per_token_logps = inputs["old_per_token_logps"]
# Get vLLM importance sampling ratio if using vLLM with importance sampling correction
vllm_is_ratio = inputs.get("importance_sampling_ratio", None)
per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
logits,
old_per_token_logps,
ref_per_token_logps,
completion_ids,
advantages,
completion_mask,
temperature=self.temperature,
beta=self.beta,
eps_low=self.epsilon_low,
eps_high=self.epsilon_high,
importance_sampling_level=self.importance_sampling_level, # "token" or "sequence"
vllm_is_ratio=vllm_is_ratio, # vLLM distribution correction
)
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
# Log the metrics
mode = "eval" if self.control.should_evaluate else "train"
if self.beta != 0.0:
mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
return loss
trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
trl.GRPOTrainer.compute_loss = compute_loss
trigger = None
"""
# Add this line as the first line of grpo.py in open-r1:
"""
from liger_kernel.transformers.grpo_loss import trigger
"""
from typing import Optional
import torch
from liger_kernel.ops import LigerJSDFunction
class LigerJSD(torch.nn.Module):
r"""The generalized Jensen-Shannon Divergence.
.. math::
JSD(\beta)(P || Q)
= \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
.. note::
As all the other losses in PyTorch, this function expects the first argument,
:attr:`log_q`, to be the predictions, the output of the student model in log-space,
and the second, :attr:`log_p`, to be the observations, the output of the teacher model in log-space.
This differs from the standard mathematical notation :math:`JSD(P || Q)` where
:math:`P` denotes the teacher model and :math:`Q` denotes the student model.
Args:
beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
ignore_index (int): The index to ignore in the target. Default: `-100`
Shape:
- Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size.
- Target: :math:`(BT, V)`, same shape as the input.
- shift_labels (Optional): :math:`(BT,)`
- Output: a scalar.
Examples:
```python
>>> (B, T, V) = (2, 2, 5)
>>> jsd = LigerJSD(beta=0.1)
>>> # input should be a distribution in the log space
>>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
>>> target = torch.randn(B * T, V).log_softmax(dim=-1)
>>> output = jsd(input, target)
>>>
>>> # Example with labels for supervised fine-tuning (SFT) context
>>> # Assume logits and corresponding labels are given
>>> student_logits = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
>>> teacher_logits = torch.randn(B * T, V).log_softmax(dim=-1)
    >>> labels = torch.randint(0, V, (B * T,), dtype=torch.long)
>>> # Shift so that tokens < n predict n
>>> shift_student_logits = student_logits[..., :-1, :].contiguous()
>>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
>>> shift_labels = labels[..., 1:].contiguous()
>>> # Flatten tokens
>>> shift_student_logits = shift_student_logits.view(-1, V)
>>> shift_teacher_logits = shift_teacher_logits.view(-1, V)
>>> shift_labels = shift_labels.view(-1)
>>> # Calculate loss
>>> loss_fct = LigerJSD(beta=0.1)
    >>> loss = loss_fct(shift_student_logits, shift_teacher_logits, shift_labels)
```
"""
def __init__(self, beta: float = 0.5, ignore_index: int = -100):
super().__init__()
self.beta = beta
self.ignore_index = ignore_index
def forward(
self,
log_q: torch.Tensor,
log_p: torch.Tensor,
shift_labels: Optional[torch.LongTensor] = None,
):
return LigerJSDFunction.apply(log_q, log_p, shift_labels, self.beta, self.ignore_index)
import torch.nn as nn
from liger_kernel.ops import LigerKLDivLossFunction
class LigerKLDIVLoss(nn.KLDivLoss):
def __init__(self, eps: float = 1e-10, *args, **kwargs):
super(LigerKLDIVLoss, self).__init__(*args, **kwargs)
self.eps = eps
def forward(self, y_pred, y_true):
return LigerKLDivLossFunction.apply(y_pred, y_true, self.reduction, self.log_target, self.eps)
import torch
import torch.nn as nn
from liger_kernel.ops import LigerLayerNormFunction
class LigerLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6, bias=False, init_fn="ones"):
super().__init__()
assert init_fn in [
"ones",
"zeros",
], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
self.hidden_size = hidden_size
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
self.bias = nn.Parameter(torch.randn(hidden_size) if bias else torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon)
def extra_repr(self):
return f"{self.hidden_size}, eps={self.eps}"
"""
Liger Kernel implementation of Llama4 Rotary Position Embedding (RoPE).
Supports both text and vision RoPE variants with fused operations for optimal performance.
"""
import torch
from liger_kernel.ops import LigerLlama4RopeFunction
def liger_llama4_text_rotary_pos_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Liger-optimized implementation of Llama4 text rotary position embedding.
This implementation uses a fused Triton kernel for complex multiplication,
providing significant performance improvements over the original PyTorch implementation.
Args:
xq (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
xk (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim)
freqs_cis (torch.Tensor): Complex frequency tensor from Llama4TextRotaryEmbedding
Returns:
Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors
"""
# Use fused Triton kernel for complex RoPE
return LigerLlama4RopeFunction.apply(xq, xk, freqs_cis)
def liger_llama4_vision_rotary_pos_emb(
query: torch.Tensor,
key: torch.Tensor,
freqs_ci: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Liger-optimized implementation of Llama4 vision rotary position embedding.
This implementation uses the same fused Triton kernel as text RoPE,
providing performance improvements for vision transformer attention.
Args:
query (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
key (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim)
freqs_ci (torch.Tensor): Complex frequency tensor for 2D positions
Returns:
Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors
"""
# Handle broadcasting for vision RoPE
if freqs_ci.dim() == 3:
try:
# Try the regular 3D expansion
freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
except RuntimeError as e:
if "expand" in str(e) and "4" in str(e):
# The tensor is actually 4D internally, handle it differently
freqs_ci = freqs_ci.squeeze(1) # Remove the middle dimension
freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
else:
raise e
elif freqs_ci.dim() == 4: # (1, seq_len, 1, head_dim//2) - already properly shaped
# Squeeze the middle dimension to get (1, seq_len, head_dim//2)
freqs_ci = freqs_ci.squeeze(2)
elif freqs_ci.dim() == 2: # (seq_len, head_dim//2) - needs expansion
freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
else:
raise ValueError(f"Unexpected freqs_ci shape: {freqs_ci.shape}")
# Use the same fused kernel as text RoPE
return LigerLlama4RopeFunction.apply(query, key, freqs_ci)
# Note: We only patch the functions, not the classes
# The original Llama4TextRotaryEmbedding and Llama4VisionRotaryEmbedding classes remain unchanged
# Convenience functions for monkey patching
def apply_liger_llama4_rope_full(modeling_module):
"""
Apply Liger optimizations to Llama4 RoPE functions.
Args:
modeling_module: The transformers modeling module to patch
"""
# Replace the text RoPE function
modeling_module.apply_rotary_emb = liger_llama4_text_rotary_pos_emb
# Replace the vision RoPE function
modeling_module.vision_apply_rotary_emb = liger_llama4_vision_rotary_pos_emb