Commit 9b0e3a30 authored by cmx

first commit

parent fe5cd1fc
from typing import Optional
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
@triton.jit
def _jsd_kernel(
X_ptr, # input in logspace, X = log Q
X_stride,
Y_ptr, # ground truth in logspace, Y = log P
Y_stride,
loss_ptr,
loss_stride,
dX_ptr,
dX_stride,
label_ptr,
beta: tl.constexpr,
n_non_ignore: int,
ignore_index: tl.constexpr,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_LABEL: tl.constexpr,
):
# Generalized JSD with mixing coefficient beta:
#   JSD(beta)(P || Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M),  M = beta * P + (1 - beta) * Q
#                     = sum(beta * P * Y + (1 - beta) * Q * X - M * log M),  with P = e ^ Y, Q = e ^ X
# beta == 0 and beta == 1 are special-cased below to forward KL(P || Q) and reverse KL(Q || P).
# Gradient w.r.t. X (general case): dL/dX_i = (1 - beta) * Q_i * (X_i - log M_i)
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# Grid-Stride Loop - each kernel processes multiple rows
for row_idx in range(pid, n_rows, num_progs):
X_row_ptr = X_ptr + row_idx * X_stride
Y_row_ptr = Y_ptr + row_idx * Y_stride
loss_row_ptr = loss_ptr + row_idx * loss_stride
dX_row_ptr = dX_ptr + row_idx * dX_stride
should_skip = False
if HAS_LABEL:
label = tl.load(label_ptr + row_idx)
should_skip = label == ignore_index
if should_skip:
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_cols
tl.store(dX_row_ptr + offsets, 0.0, mask=mask)
tl.store(loss_row_ptr + offsets, 0.0, mask=mask)
else:
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_cols
X = tl.load(X_row_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
Y = tl.load(Y_row_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
if beta == 0.0: # forward KL
Y_max = tl.max(Y, axis=0)
Y_shifted = Y - Y_max
Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift
loss = Y_prob * (Y - X)
dX = -Y_prob
elif beta == 1.0: # reverse KL
X_max = tl.max(X, axis=0)
X_shifted = X - X_max
X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift
loss = X_prob * (X - Y)
dX = loss + X_prob
else:
max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
X_shifted = X - max_val
Y_shifted = Y - max_val
# Pre-compute exp(max_val) since it's used twice
exp_max = tl.exp(max_val)
# Compute exp terms with compensation
Q = tl.exp(X_shifted) * exp_max # = exp(X)
P = tl.exp(Y_shifted) * exp_max # = exp(Y)
# Pre-compute common terms
beta_P = beta * P
one_minus_beta_Q = (1 - beta) * Q
M = beta_P + one_minus_beta_Q
log_M = tl.log(M)
loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
dX = one_minus_beta_Q * (X - log_M)
# Pre-compute scaling factor
scale = 1.0 / n_non_ignore
loss = loss * scale
dX = dX * scale
tl.store(loss_row_ptr + offsets, loss, mask=mask)
tl.store(dX_row_ptr + offsets, dX, mask=mask)
def get_optimal_block_size(total_elements):
"""
Calculate the optimal BLOCK_SIZE using compute_default_tiling_strategy.
"""
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=8.0, shapes=((total_elements,),), tiling_dims=(0,)
)
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return block_size
else:
return 2048
def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = get_optimal_block_size(V)
# unreduced, element-wise loss buffer; summed to a scalar after the kernel
loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
dX = torch.empty_like(_input)
if has_label:
n_non_ignore = (shift_labels != ignore_index).sum().item()
else:
n_non_ignore = BT
# Use NPU core count for grid size
num_cores = get_npu_core_count()
grid_size = min(num_cores, n_rows)
_jsd_kernel[(grid_size,)](
X_ptr=_input,
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-2),
loss_ptr=loss,
loss_stride=loss.stride(-2),
dX_ptr=dX,
dX_stride=dX.stride(-2),
label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),
beta=beta,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
n_rows=n_rows,
n_cols=V,
BLOCK_SIZE=BLOCK_SIZE,
HAS_LABEL=has_label,
)
loss = torch.sum(loss)
return loss.to(_input.dtype), dX
def jsd_backward(dX, grad_output):
# If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return dX
else:
return grad_output * dX
class LigerJSDFunction(torch.autograd.Function):
r"""
This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
.. math::
JSD(\beta)(P || Q)
= \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
.. note::
As with all the other losses in PyTorch, this function expects the first argument,
:attr:`_input`, to be the predictions, the output of the student model, in log-space
and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
This differs from the standard mathematical notation :math:`JSD(P || Q)` where
:math:`P` denotes the teacher model and :math:`Q` denotes the student model.
"""
@staticmethod
@ensure_contiguous
def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
shift_labels: Optional[torch.Tensor] = None,
beta: float = 0.5,
ignore_index: int = -100,
) -> torch.Tensor:
"""
Args:
_input (torch.Tensor): predict values with shape (BT, V) in logspace
target (torch.Tensor): ground truth values with shape (BT, V) in logspace
shift_labels (Optional[torch.LongTensor]): indicator of the next predicted vocab id, with shape (BT,), where each value is in [0, V-1] or equals `ignore_index`.
beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
ignore_index (int): the index to ignore. Default: -100
Returns:
loss (torch.Tensor): generalized JSD
"""
has_label = False
if shift_labels is not None:
assert shift_labels.shape == (_input.shape[0],), (
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
)
shift_labels = shift_labels.contiguous()
has_label = True
loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
ctx.save_for_backward(dX)
return loss
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
(dX,) = ctx.saved_tensors
dX = jsd_backward(dX, grad_output)
return (
dX,
None,
None,
None,
None,
)
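# -----------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the kernel code): shapes,
# beta, and the "npu" device below are examples; the device is expected to be
# registered by torch_npu on Ascend hosts.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    BT, V = 8, 4096
    # Student (input) and teacher (target) distributions, both in log-space.
    student_logp = torch.randn(BT, V, device="npu").log_softmax(dim=-1).requires_grad_()
    teacher_logp = torch.randn(BT, V, device="npu").log_softmax(dim=-1)
    # beta=0.5 gives the symmetric JSD; shift_labels=None disables ignore_index masking.
    loss = LigerJSDFunction.apply(student_logp, teacher_logp, None, 0.5, -100)
    loss.backward()
    print(loss.item(), student_logp.grad.shape)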
from typing import Literal
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
_str_to_reduction_mode = {
"none": _REDUCTION_MODE_NONE.value,
"sum": _REDUCTION_MODE_SUM.value,
"mean": _REDUCTION_MODE_MEAN.value,
"batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}
# -----------------------------------------------------------------------------
# Kernels (2D Tiling + Persistent Programs)
# -----------------------------------------------------------------------------
@triton.jit
def _kldiv_kernel_forward(
y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space
gt_ptr, # [B, S], ground truth ptr
loss_ptr, # output ptr: [B, S] if reduction == _REDUCTION_MODE_NONE, else [B]
n_rows, # int, number of rows in the input tensor
n_cols, # int, number of columns in the input tensor
eps,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
log_target: tl.constexpr = False,
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
grid_m = tl.cdiv(n_rows, BLOCK_SIZE_M)
grid_n = tl.cdiv(n_cols, BLOCK_SIZE_N)
total_2d_blocks = grid_m * grid_n
# Persistent-program loop over logical 2D blocks.
for block_idx in tl.range(pid, total_2d_blocks, num_progs):
block_m = block_idx // grid_n
block_n = block_idx % grid_n
offset_m = tl.arange(0, BLOCK_SIZE_M) + block_m * BLOCK_SIZE_M
offset_n = tl.arange(0, BLOCK_SIZE_N) + block_n * BLOCK_SIZE_N
mask_m = offset_m < n_rows
mask_n = offset_n < n_cols
offset = offset_m[:, None] * n_cols + offset_n[None, :]
mask = mask_m[:, None] & mask_n[None, :]
y = tl.load(y_ptr + offset, mask=mask, other=0.0)
y_true = tl.load(gt_ptr + offset, mask=mask, other=0.0)
# KL(y_true || y_pred) with y_pred provided in log-space.
# - log_target=False: y_true is probability space; clamp with eps before log.
# - log_target=True : y_true is log-probability space.
if log_target:
loss = tl.exp(y_true) * (y_true - y)
else:
loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
if reduction == _REDUCTION_MODE_NONE:
tl.store(loss_ptr + offset, loss, mask=mask)
else:
# Multiple block_n tiles may update the same row, so atomic_add is required.
loss_sum = tl.sum(loss, axis=1)
tl.atomic_add(loss_ptr + offset_m, loss_sum, mask=mask_m)
@triton.jit
def _kldiv_kernel_backward(
target_ptr,
new_grads_ptr,
grad_output_ptr,
n_rows,
n_cols,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
log_target: tl.constexpr = False,
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
grid_m = tl.cdiv(n_rows, BLOCK_SIZE_M)
grid_n = tl.cdiv(n_cols, BLOCK_SIZE_N)
total_2d_blocks = grid_m * grid_n
# For reduced losses, grad_output is a scalar. Load it once per program.
if reduction != _REDUCTION_MODE_NONE:
grad_output_scalar = tl.load(grad_output_ptr)
# Persistent-program loop over logical 2D blocks.
for block_idx in tl.range(pid, total_2d_blocks, num_progs):
block_m = block_idx // grid_n
block_n = block_idx % grid_n
offset_m = tl.arange(0, BLOCK_SIZE_M) + block_m * BLOCK_SIZE_M
offset_n = tl.arange(0, BLOCK_SIZE_N) + block_n * BLOCK_SIZE_N
mask_m = offset_m < n_rows
mask_n = offset_n < n_cols
offset = offset_m[:, None] * n_cols + offset_n[None, :]
mask = mask_m[:, None] & mask_n[None, :]
y_true = tl.load(target_ptr + offset, mask=mask, other=0.0)
if log_target:
res = -tl.exp(y_true)
else:
res = y_true * -1
if reduction != _REDUCTION_MODE_NONE:
res = res * grad_output_scalar
else:
grad_output = tl.load(grad_output_ptr + offset, mask=mask, other=0.0)
res = res * grad_output
if reduction == _REDUCTION_MODE_BATCHMEAN:
res = res / n_rows
elif reduction == _REDUCTION_MODE_MEAN:
res = res / (n_rows * n_cols)
tl.store(new_grads_ptr + offset, res, mask=mask)
# -----------------------------------------------------------------------------
# Helper: Call compute_default_tiling_strategy
# -----------------------------------------------------------------------------
def get_optimal_block_size(
n_rows,
dtype_size,
BLOCK_SIZE_N: int,
log_target: bool = False,
is_backward: bool = False,
is_scalar_grad_output: bool = True,
):
"""
Calculate optimal BLOCK_SIZE_M using compute_default_tiling_strategy.
"""
# 1) Set memory multiplier
# Backward is lighter than forward in this op, so we typically use a smaller multiplier.
# If backward also needs to stream a full grad_output tile (i.e., grad_output is not a scalar),
# its memory footprint becomes closer to forward, so we bump the multiplier.
if is_backward:
multiplier = 2.5 if is_scalar_grad_output else 3.0
else:
multiplier = 3.0 if log_target else 6.0
# For bf16/fp16 (dtype_size < 4), compile-time UB overflow was observed on some shapes.
# Clamp to fp32 size for a conservative tiling estimate; this can be refined later.
dtype_size = max(dtype_size, 4)
# 2) Call tiling strategy (tile only dim 0 / rows)
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=dtype_size,
memory_multiplier=multiplier,
shapes=((n_rows, BLOCK_SIZE_N),),
tiling_dims=(0,),
)
# 3) Parse result
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return block_size
else:
return triton.next_power_of_2(min(128, n_rows))
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
BT, V = y_pred.shape
reduction = _str_to_reduction_mode[reduction]
out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
BLOCK_SIZE_N = triton.next_power_of_2(min(128, V))
BLOCK_SIZE_M = get_optimal_block_size(BT, y_pred.element_size(), BLOCK_SIZE_N, log_target=log_target)
num_cores = get_npu_core_count()
total_blocks = triton.cdiv(BT, BLOCK_SIZE_M) * triton.cdiv(V, BLOCK_SIZE_N)
grid = min(num_cores, total_blocks)
_kldiv_kernel_forward[(grid,)](
y_pred,
y_true,
output_tensor,
BT,
V,
eps=eps,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
log_target=log_target,
reduction=reduction,
)
# Final reduction follows PyTorch KLDivLoss semantics.
# Note: In newer PyTorch versions, `mean` is planned to match `batchmean`.
# See: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
return output_tensor.sum() / BT
elif reduction == _REDUCTION_MODE_SUM.value:
return output_tensor.sum(dim=0)
elif reduction == _REDUCTION_MODE_MEAN.value:
return output_tensor.sum() / (BT * V)
else:
return output_tensor
def kldiv_backward_triton(target, grad_output, new_grads, log_target, reduction):
BT, V = target.shape
reduction = _str_to_reduction_mode[reduction]
BLOCK_SIZE_N = triton.next_power_of_2(min(128, V))
# grad_output handling:
# - numel() == 1: use scalar grad_output path in kernel.
# - numel() != 1: stream per-element grad_output tile in kernel.
is_scalar_grad_output = grad_output.numel() == 1
BLOCK_SIZE_M = get_optimal_block_size(
BT,
target.element_size(),
BLOCK_SIZE_N,
log_target=log_target,
is_backward=True,
is_scalar_grad_output=is_scalar_grad_output,
)
num_cores = get_npu_core_count()
total_blocks = triton.cdiv(BT, BLOCK_SIZE_M) * triton.cdiv(V, BLOCK_SIZE_N)
grid = min(num_cores, total_blocks)
_kldiv_kernel_backward[(grid,)](
target,
new_grads,
grad_output,
BT,
V,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
log_target=log_target,
reduction=reduction,
)
return new_grads
class LigerKLDivLossFunction(torch.autograd.Function):
"""
Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
```python
if log_target:
loss = target.exp() * (target - input)
else:
loss = target * (target.log() - input)
```
The loss is then reduced according to the `reduction` parameter,
as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
"""
@staticmethod
@ensure_contiguous
def forward(
ctx,
y_pred: torch.Tensor,
y_true: torch.Tensor,
reduction: REDUCTION_LITERAL = "batchmean",
log_target: bool = False,
eps: float = 1e-10,
) -> torch.Tensor:
"""A forward pass for the KL Divergence Loss.
Args:
ctx: Torch autograd context
y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
eps (float, optional): A small value used to clamp the target away from zero before taking its log (only relevant when `log_target=False`). Defaults to 1e-10.
Returns:
torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
"""
ctx.save_for_backward(y_true)
ctx.reduction = reduction
ctx.log_target = log_target
return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
"""A backward pass for the KL Divergence Loss.
Args:
ctx: Torch autograd context
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
Returns:
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
"""
(y_true,) = ctx.saved_tensors
new_grads = torch.empty_like(y_true)
derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target, ctx.reduction)
return (
derivative,
None,
None,
None,
None,
)
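# -----------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the kernel code): shapes
# and reduction mode are examples; the "npu" device is expected to be registered
# by torch_npu on Ascend hosts.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    BT, V = 8, 4096
    y_pred = torch.randn(BT, V, device="npu").log_softmax(dim=-1).requires_grad_()
    # log_target=False: the ground truth is given as probabilities.
    y_true = torch.randn(BT, V, device="npu").softmax(dim=-1)
    loss = LigerKLDivLossFunction.apply(y_pred, y_true, "batchmean", False, 1e-10)
    loss.backward()
    print(loss.item(), y_pred.grad.shape)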
import torch
import triton
import triton.language as tl
from triton.language.math import rsqrt
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
# -----------------------------------------------------------------------------
# Optimized Forward Kernel - No Tiling (for n_cols <= 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _layer_norm_forward_kernel_no_tiling(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
B_ptr,
Mean_ptr,
Mean_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
eps: tl.constexpr,
n_cols_inv: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
"""
OPTIMIZED NPU layer_norm forward kernel for small n_cols (<= 2048).
Key optimizations:
1. Pre-compute n_cols_inv to avoid repeated scalar division
2. Hoist W and B loads outside the grid-stride loop
3. Minimize per-iteration scalar operations
4. Use vectorized operations for mask handling
5. Optimize cache hints for memory access patterns
6. Reduce type conversions by keeping intermediate results in float32
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# Pre-compute grid stride constants (done once, not per iteration)
grid_stride = num_progs * BLOCK_SIZE_M
num_iterations = tl.cdiv(n_rows, grid_stride)
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
# Load W and B once, outside the grid-stride loop
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0).to(tl.float32)
B_row = tl.load(B_ptr + col_offsets, mask=col_mask, other=0.0).to(tl.float32)
base_row_idx = pid * BLOCK_SIZE_M
# Grid-stride loop over row blocks
for i in range(num_iterations):
row_idx = i * grid_stride + base_row_idx + row_offsets
row_mask = row_idx < n_rows
block_mask = row_mask[:, None] & col_mask[None, :]
X_block_ptr = X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :]
X_rows = tl.load(
X_block_ptr,
mask=block_mask,
other=0.0,
cache_modifier=".cg",
).to(tl.float32)
# Compute mean with vectorized operations
row_sum = tl.sum(X_rows, axis=1)
mean_rows = row_sum * n_cols_inv # Multiplication is faster than division
# Center the data (vectorized operation)
X_centered = X_rows - mean_rows[:, None]
X_centered_masked = tl.where(block_mask, X_centered, 0.0)
var_rows = tl.sum(X_centered_masked * X_centered_masked, axis=1) * n_cols_inv
rstd_rows = rsqrt(var_rows + eps)
Mean_ptr_offset = Mean_ptr + row_idx * Mean_row_stride
RSTD_ptr_offset = RSTD_ptr + row_idx * RSTD_row_stride
tl.store(Mean_ptr_offset, mean_rows, mask=row_mask)
tl.store(RSTD_ptr_offset, rstd_rows, mask=row_mask)
Y_f32 = X_centered * rstd_rows[:, None] * W_row[None, :] + B_row[None, :]
# Store output with coalesced memory access
Y_block_ptr = Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :]
tl.store(Y_block_ptr, Y_f32, mask=block_mask)
# -----------------------------------------------------------------------------
# Forward Kernel - With Tiling (for n_cols > 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _layer_norm_forward_kernel_npu(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
B_ptr,
Mean_ptr,
Mean_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""NPU-optimized layer_norm forward kernel with column blocking."""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
offsets = tl.arange(0, BLOCK_SIZE)
n_cols_inv = 1.0 / n_cols
for row_idx in range(pid, n_rows, num_progs):
Y_row_ptr = Y_ptr + row_idx * Y_row_stride
X_row_ptr = X_ptr + row_idx * X_row_stride
Mean_row_ptr = Mean_ptr + row_idx * Mean_row_stride
RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
row_sum = 0.0
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
row_sum += tl.sum(X_block)
mean = row_sum * n_cols_inv
var_sum = 0.0
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
X_centered = X_block - mean
var_sum += tl.sum(tl.where(mask, X_centered * X_centered, 0.0))
var = var_sum * n_cols_inv
rstd = rsqrt(var + eps)
tl.store(Mean_row_ptr, mean)
tl.store(RSTD_row_ptr, rstd)
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32)
W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
B_block = tl.load(B_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
X_centered = X_block - mean
Y_f32 = X_centered * rstd * W_block + B_block
tl.store(Y_row_ptr + col_offsets, Y_f32.to(X_block.dtype), mask=mask)
# -----------------------------------------------------------------------------
# Optimized Backward Kernel - No Tiling (for n_cols <= 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _layer_norm_backward_kernel_no_tiling(
X_ptr,
X_row_stride,
W_ptr,
Mean_ptr,
Mean_row_stride,
RSTD_ptr,
RSTD_row_stride,
DX_ptr,
DX_row_stride,
DW_scratch_ptr,
DW_scratch_stride,
DB_scratch_ptr,
DB_scratch_stride,
DY_ptr,
DY_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
n_cols_inv: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
"""
OPTIMIZED NPU layer_norm backward kernel for small n_cols (<= 2048).
Key optimizations:
1. Pre-compute n_cols_inv to avoid repeated division
2. Minimize scalar operations in the hot path
3. Reduce redundant mask computations
4. Optimize memory access patterns with better cache hints
5. Keep intermediate results in float32 to reduce conversions
6. Use vectorized operations throughout
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
grid_stride = num_progs * BLOCK_SIZE_M
num_iterations = tl.cdiv(n_rows, grid_stride)
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0).to(tl.float32)
# Per-program accumulators for dW/dB
dW_acc = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
dB_acc = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
base_row_idx = pid * BLOCK_SIZE_M
# Grid-stride loop over row blocks
for i in range(num_iterations):
row_idx = i * grid_stride + base_row_idx + row_offsets
row_mask = row_idx < n_rows
# Pre-compute block mask once
block_mask = row_mask[:, None] & col_mask[None, :]
X_block_ptr = X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :]
DY_block_ptr = DY_ptr + row_idx[:, None] * DY_row_stride + col_offsets[None, :]
Mean_row_ptr = Mean_ptr + row_idx * Mean_row_stride
RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
# Load all required data with appropriate cache hints
# .cg = cache global (read once, don't pollute cache)
X_rows = tl.load(X_block_ptr, mask=block_mask, other=0.0, cache_modifier=".cg").to(tl.float32)
DY_rows = tl.load(DY_block_ptr, mask=block_mask, other=0.0, cache_modifier=".cg").to(tl.float32)
mean_rows = tl.load(Mean_row_ptr, mask=row_mask, other=0.0).to(tl.float32)
rstd_rows = tl.load(RSTD_row_ptr, mask=row_mask, other=0.0).to(tl.float32)
x_hat = (X_rows - mean_rows[:, None]) * rstd_rows[:, None]
wdy = W_row[None, :] * DY_rows
x_hat_wdy_masked = tl.where(block_mask, x_hat * wdy, 0.0)
wdy_masked = tl.where(block_mask, wdy, 0.0)
c1 = tl.sum(x_hat_wdy_masked, axis=1) * n_cols_inv
c2 = tl.sum(wdy_masked, axis=1) * n_cols_inv
DX_f32 = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_rows[:, None]
# Store dX with coalesced memory access
DX_block_ptr = DX_ptr + row_idx[:, None] * DX_row_stride + col_offsets[None, :]
tl.store(DX_block_ptr, DX_f32.to(X_ptr.dtype.element_ty), mask=block_mask)
dW_acc += tl.sum(tl.where(block_mask, DY_rows * x_hat, 0.0), axis=0)
dB_acc += tl.sum(tl.where(block_mask, DY_rows, 0.0), axis=0)
# Write accumulated gradients to scratch buffers
DW_scratch_offset = DW_scratch_ptr + pid * DW_scratch_stride + col_offsets
DB_scratch_offset = DB_scratch_ptr + pid * DB_scratch_stride + col_offsets
tl.store(DW_scratch_offset, dW_acc, mask=col_mask)
tl.store(DB_scratch_offset, dB_acc, mask=col_mask)
# -----------------------------------------------------------------------------
# Backward Kernel - With Tiling (for n_cols > 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _layer_norm_backward_kernel_npu(
X_ptr,
X_row_stride,
W_ptr,
Mean_ptr,
Mean_row_stride,
RSTD_ptr,
RSTD_row_stride,
DX_ptr,
DX_row_stride,
DW_ptr,
DB_ptr,
DY_ptr,
DY_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""NPU-optimized layer_norm backward kernel with column blocking."""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
offsets = tl.arange(0, BLOCK_SIZE)
n_cols_inv = 1.0 / n_cols
for row_idx in range(pid, n_rows, num_progs):
X_row_ptr = X_ptr + row_idx * X_row_stride
DY_row_ptr = DY_ptr + row_idx * DY_row_stride
DX_row_ptr = DX_ptr + row_idx * DX_row_stride
Mean_row_ptr = Mean_ptr + row_idx * Mean_row_stride
RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
mean = tl.load(Mean_row_ptr).to(tl.float32)
rstd = tl.load(RSTD_row_ptr).to(tl.float32)
sum_x_hat_wdy = 0.0
sum_wdy = 0.0
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
DY_block = tl.load(DY_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
x_hat = (X_block - mean) * rstd
wdy = W_block * DY_block
sum_x_hat_wdy += tl.sum(tl.where(mask, x_hat * wdy, 0.0))
sum_wdy += tl.sum(tl.where(mask, wdy, 0.0))
c1 = sum_x_hat_wdy * n_cols_inv
c2 = sum_wdy * n_cols_inv
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
DY_block = tl.load(DY_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
x_hat = (X_block - mean) * rstd
wdy = W_block * DY_block
DX_block = (wdy - (x_hat * c1 + c2)) * rstd
tl.store(DX_row_ptr + col_offsets, DX_block.to(X_ptr.dtype.element_ty), mask=mask)
dW_block = DY_block * x_hat
dB_block = DY_block
tl.atomic_add(DW_ptr + col_offsets, dW_block, mask=mask)
tl.atomic_add(DB_ptr + col_offsets, dB_block, mask=mask)
# -----------------------------------------------------------------------------
# Helper Functions
# -----------------------------------------------------------------------------
def get_optimal_block_size(n_cols, is_forward: bool):
"""
Calculate optimal block size using compute_default_tiling_strategy.
Memory analysis for forward pass (per row):
- Load: X_block, W_block, B_block (3 blocks)
- Store: Y_block, Mean, RSTD (3 blocks)
- Compute: X_centered, Y intermediate (2 blocks)
- Total: conservative estimate 10 blocks of memory
Memory analysis for backward pass (per row):
- Load: X_block, DY_block, W_block, Mean, RSTD, existing_DW, existing_DB (7 blocks)
- Store: DX_block, new_DW, new_DB (3 blocks)
- Compute: x_hat, wdy, DX intermediate, dW_block, dB_block (5 blocks)
- Total: conservative estimate 15 blocks of memory
Args:
n_cols: Number of columns in the tensor
is_forward: Whether this is for forward pass (True) or backward pass (False)
Returns:
Optimal block size
"""
if n_cols <= 2048:
return triton.next_power_of_2(n_cols)
memory_multiplier = 10.0 if is_forward else 15.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=4,
memory_multiplier=memory_multiplier,
shapes=((n_cols,),),
tiling_dims=(0,),
)
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return max(2048, block_size)
else:
return 2048
def _compute_grid_size(n_rows: int, block_size_m: int, num_cores: int) -> int:
"""
Compute the effective grid size for no-tiling kernels.
OPTIMIZATION: Balances parallelism with overhead
- Ensures enough work per program to amortize launch costs
- Avoids launching idle programs
- Caps at 2x core count for hardware concurrency
"""
num_row_blocks = triton.cdiv(n_rows, block_size_m)
return min(num_cores * 2, num_row_blocks)
# -----------------------------------------------------------------------------
# Forward and Backward Functions
# -----------------------------------------------------------------------------
def layer_norm_forward(X, W, B, eps):
"""
NPU-optimized forward pass for LayerNorm.
Args:
X: Input tensor of shape (..., hidden_size)
W: Weight tensor of shape (hidden_size,)
B: Bias tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
Tuple of (output, input, mean, rstd)
"""
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
n_rows, n_cols = X.shape
if X.shape[1] != W.shape[0]:
raise ValueError(
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
f"must match weight size (W.shape[0]={W.shape[0]})"
)
# Get optimal block sizes
BLOCK_SIZE = get_optimal_block_size(n_cols, True)
BLOCK_SIZE_M = 2048 // BLOCK_SIZE
# Allocate output tensors
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
num_cores = get_npu_core_count()
# Choose kernel
if n_cols <= 2048:
grid_size = _compute_grid_size(n_rows, BLOCK_SIZE_M, num_cores)
n_cols_inv = 1.0 / float(n_cols)
_layer_norm_forward_kernel_no_tiling[(grid_size,)](
Y,
Y.stride(0),
X,
X.stride(0),
W,
B,
Mean,
Mean.stride(0),
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
n_cols_inv,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE,
)
else:
grid_size = min(num_cores, n_rows)
_layer_norm_forward_kernel_npu[(grid_size,)](
Y,
Y.stride(0),
X,
X.stride(0),
W,
B,
Mean,
Mean.stride(0),
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
)
return Y.view(*shape), X, Mean, RSTD
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
"""
NPU-optimized backward pass for LayerNorm.
Args:
dY: Gradient of output
X: Input tensor
W: Weight tensor
B: Bias tensor
Mean: Pre-computed mean
RSTD: Pre-computed reciprocal standard deviation
Returns:
Tuple of (input_grad, weight_grad, bias_grad)
"""
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
# Get optimal block sizes
BLOCK_SIZE = get_optimal_block_size(n_cols, False)
BLOCK_SIZE_M = 2048 // BLOCK_SIZE
num_cores = get_npu_core_count()
# Allocate gradient tensors
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
# Choose kernel
if n_cols <= 2048:
grid_size = _compute_grid_size(n_rows, BLOCK_SIZE_M, num_cores)
DW_scratch = torch.empty((grid_size, n_cols), dtype=torch.float32, device=W.device)
DB_scratch = torch.empty((grid_size, n_cols), dtype=torch.float32, device=W.device)
n_cols_inv = 1.0 / float(n_cols)
_layer_norm_backward_kernel_no_tiling[(grid_size,)](
X,
X.stride(0),
W,
Mean,
Mean.stride(0),
RSTD,
RSTD.stride(0),
DX,
DX.stride(0),
DW_scratch,
DW_scratch.stride(0),
DB_scratch,
DB_scratch.stride(0),
dY,
dY.stride(0),
n_rows,
n_cols,
n_cols_inv,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE,
)
DW = DW_scratch.sum(dim=0)
DB = DB_scratch.sum(dim=0)
else:
grid_size = min(num_cores, n_rows)
DW = torch.zeros(n_cols, dtype=torch.float32, device=W.device)
DB = torch.zeros(n_cols, dtype=torch.float32, device=W.device)
_layer_norm_backward_kernel_npu[(grid_size,)](
X,
X.stride(0),
W,
Mean,
Mean.stride(0),
RSTD,
RSTD.stride(0),
DX,
DX.stride(0),
DW,
DB,
dY,
dY.stride(0),
n_rows,
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
)
return DX.view(*shape), DW.to(W.dtype), DB.to(B.dtype)
# -----------------------------------------------------------------------------
# Autograd Function
# -----------------------------------------------------------------------------
class LigerLayerNormFunction(torch.autograd.Function):
"""
OPTIMIZED NPU LayerNorm operation.
Key optimizations for no-tiling kernels:
1. Pre-compute 1/n_cols to avoid scalar division (40.6% → <30% target)
2. Minimize per-iteration scalar operations in grid-stride loops
3. Hoist constant computations outside loops
4. Use vectorized operations throughout
5. Optimize memory access patterns with better cache hints
6. Reduce type conversions by keeping intermediates in float32
7. Improve grid sizing for better work distribution
"""
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, B, eps):
Y, X, Mean, RSTD = layer_norm_forward(X, W, B, eps)
ctx.save_for_backward(X, W, B, Mean, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, W, B, Mean, RSTD = ctx.saved_tensors
DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
return DX, DW, DB, None
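# -----------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the kernel code): the
# hidden size and batch shape are examples; the "npu" device is expected to be
# registered by torch_npu on Ascend hosts.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    hidden = 1024
    x = torch.randn(4, 128, hidden, device="npu", requires_grad=True)
    w = torch.ones(hidden, device="npu", requires_grad=True)
    b = torch.zeros(hidden, device="npu", requires_grad=True)
    y = LigerLayerNormFunction.apply(x, w, b, 1e-6)
    y.sum().backward()
    print(y.shape, x.grad.shape, w.grad.shape, b.grad.shape)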
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
def _cast_and_contiguous(q, k, freqs_complex):
# Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
if k.dtype != q.dtype:
k = k.to(q.dtype)
q = q.to(compute_dtype).contiguous()
k = k.to(compute_dtype).contiguous()
freqs_complex = freqs_complex.contiguous()
return q, k, freqs_complex, compute_dtype
@triton.jit
def _triton_llama4_rope_npu(
q_ptr,
k_ptr,
freqs_complex_ptr,
q_row_stride,
k_row_stride,
q_head_stride,
k_head_stride,
freqs_row_stride,
sl,
bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
imag_sign: tl.constexpr,
):
"""
Llama4 RoPE on Ascend NPU for interleaved complex layout:
- q/k shape: (bs, sl, n_heads, hd)
- freqs_complex_ptr: (sl, hd//2, 2)
"""
pid = tl.program_id(0).to(tl.int64)
batch_idx = pid // sl
seq_idx = pid % sl
if batch_idx >= bs:
return
q_base = q_ptr + pid * q_row_stride
k_base = k_ptr + pid * k_row_stride
freq_base = seq_idx * freqs_row_stride
hd_idx = tl.arange(0, hd)
hd_mask = hd_idx < (hd)
freq_idx = tl.arange(0, hd)
freq_mask = freq_idx < (hd)
freqs_complex = tl.load(freqs_complex_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0)
freqs_complex = freqs_complex.reshape(hd // 2, 2, can_reorder=True)
freqs_real, freqs_imag = tl.split(freqs_complex)
freqs_imag = freqs_imag * imag_sign
# Q heads (chunked for UB)
for qh_block in range(0, n_qh, BLOCK_Q):
qh_idx = tl.arange(0, BLOCK_Q) + qh_block
qh_mask = qh_idx < n_qh
block_mask = qh_mask[:, None] & hd_mask[None, :]
head_ptr = q_base + qh_idx[:, None] * q_head_stride
q_pair = tl.load(
head_ptr + hd_idx[None, :],
mask=block_mask,
other=0.0,
)
q_pair = q_pair.reshape(BLOCK_Q, hd // 2, 2, can_reorder=True)
q_real, q_imag = tl.split(q_pair)
new_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
new_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
pair_idx = tl.arange(0, hd // 2)
real_idx = pair_idx * 2
imag_idx = pair_idx * 2 + 1
pair_mask = pair_idx < (hd // 2)
real_mask = qh_mask[:, None] & pair_mask[None, :]
imag_mask = qh_mask[:, None] & pair_mask[None, :]
# store real
tl.store(
head_ptr + real_idx[None, :],
new_real,
mask=real_mask,
)
# store imag
tl.store(
head_ptr + imag_idx[None, :],
new_imag,
mask=imag_mask,
)
# K heads (chunked for UB)
for kh_block in range(0, n_kh, BLOCK_K):
kh_idx = tl.arange(0, BLOCK_K) + kh_block
kh_mask = kh_idx < n_kh
block_mask = kh_mask[:, None] & hd_mask[None, :]
head_ptr = k_base + kh_idx[:, None] * k_head_stride
k_pair = tl.load(
head_ptr + hd_idx[None, :],
mask=block_mask,
other=0.0,
)
k_pair = k_pair.reshape(BLOCK_K, hd // 2, 2, can_reorder=True)
k_real, k_imag = tl.split(k_pair)
new_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
new_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
pair_idx = tl.arange(0, hd // 2)
real_idx = pair_idx * 2
imag_idx = pair_idx * 2 + 1
pair_mask = pair_idx < (hd // 2)
real_mask = kh_mask[:, None] & pair_mask[None, :]
imag_mask = kh_mask[:, None] & pair_mask[None, :]
# store real
tl.store(
head_ptr + real_idx[None, :],
new_real,
mask=real_mask,
)
# store imag
tl.store(
head_ptr + imag_idx[None, :],
new_imag,
mask=imag_mask,
)
def llama4_rope_forward(q, k, freqs_cis):
"""
Ascend NPU implementation of Llama4 RoPE.
q/k: (bs, sl, n_heads, hd) with interleaved complex last-dim layout.
freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).
"""
original_dtype = q.dtype
bs, sl, n_qh, hd = q.shape
_, _, n_kh, _ = k.shape
if hd % 2 != 0:
raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
if freqs_cis.is_complex():
freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])
if freqs_cis.shape[0] > sl:
freqs_cis = freqs_cis[:sl]
freqs_cis = torch.view_as_real(freqs_cis)
q, k, freqs_cis, compute_dtype = _cast_and_contiguous(q, k, freqs_cis)
# UB tiling strategy: tile heads dimension only
dtype_size = q.element_size()
shapes = ((n_qh, hd), (n_kh, hd))
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.90,
dtype_size=dtype_size,
memory_multiplier=20.0,
shapes=shapes,
tiling_dims=(0, 0),
)
if tile_shapes is not None and len(tile_shapes) == len(shapes):
q_tile_shape, k_tile_shape = tile_shapes
BLOCK_Q, _ = q_tile_shape
BLOCK_K, _ = k_tile_shape
BLOCK_Q = max(BLOCK_Q, 2)
BLOCK_K = max(BLOCK_K, 2)
else:
BLOCK_Q = triton.next_power_of_2(n_qh)
BLOCK_K = triton.next_power_of_2(n_kh)
n_row = bs * sl
_triton_llama4_rope_npu[(n_row,)](
q,
k,
freqs_cis,
q.stride(1),
k.stride(1),
q.stride(2),
k.stride(2),
freqs_cis.stride(0),
sl,
bs,
n_qh,
n_kh,
hd,
BLOCK_Q,
BLOCK_K,
imag_sign=1.0,
)
if compute_dtype != original_dtype:
q = q.to(original_dtype)
k = k.to(original_dtype)
return q, k
def llama4_rope_backward(dq, dk, freqs_cis):
"""
Ascend NPU backward pass of Llama4 RoPE (applies the conjugate rotation via imag_sign=-1).
dq/dk: gradients of shape (bs, sl, n_heads, hd) with interleaved complex last-dim layout.
freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).
"""
original_dtype = dq.dtype
bs, sl, n_qh, hd = dq.shape
_, _, n_kh, _ = dk.shape
if hd % 2 != 0:
raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
if freqs_cis.is_complex():
freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])
if freqs_cis.shape[0] > sl:
freqs_cis = freqs_cis[:sl]
freqs_cis = torch.view_as_real(freqs_cis)
dq, dk, freqs_cis, compute_dtype = _cast_and_contiguous(dq, dk, freqs_cis)
# UB tiling strategy: tile heads dimension only
dtype_size = dq.element_size()
shapes = ((n_qh, hd), (n_kh, hd))
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.90,
dtype_size=dtype_size,
memory_multiplier=20.0,
shapes=shapes,
tiling_dims=(0, 0),
)
if tile_shapes is not None and len(tile_shapes) == len(shapes):
q_tile_shape, k_tile_shape = tile_shapes
BLOCK_Q, _ = q_tile_shape
BLOCK_K, _ = k_tile_shape
BLOCK_Q = max(BLOCK_Q, 2)
BLOCK_K = max(BLOCK_K, 2)
else:
BLOCK_Q = triton.next_power_of_2(n_qh)
BLOCK_K = triton.next_power_of_2(n_kh)
n_row = bs * sl
_triton_llama4_rope_npu[(n_row,)](
dq,
dk,
freqs_cis,
dq.stride(1),
dk.stride(1),
dq.stride(2),
dk.stride(2),
freqs_cis.stride(0),
sl,
bs,
n_qh,
n_kh,
hd,
BLOCK_Q,
BLOCK_K,
imag_sign=-1.0,
)
if compute_dtype != original_dtype:
dq = dq.to(original_dtype)
dk = dk.to(original_dtype)
return dq, dk
class LigerLlama4RopeFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
# BLOCK_SIZE is ignored for Ascend (we auto-tile heads by UB), kept for API compatibility
q_out, k_out = llama4_rope_forward(q, k, freqs_cis)
ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
return q_out, k_out
@staticmethod
def backward(ctx, dq, dk):
(freqs_cis,) = ctx.saved_tensors
dq_out, dk_out = llama4_rope_backward(dq, dk, freqs_cis)
return dq_out, dk_out, None, None
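# -----------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the kernel code): head
# counts, head_dim, and the packed freqs layout below are examples; the "npu"
# device is expected to be registered by torch_npu on Ascend hosts.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    bs, sl, n_qh, n_kh, hd = 2, 16, 8, 2, 128
    q = torch.randn(bs, sl, n_qh, hd, device="npu")
    k = torch.randn(bs, sl, n_kh, hd, device="npu")
    # Packed rotation factors of shape (sl, hd): interleaved (cos, sin) pairs,
    # i.e. the real view of a complex (sl, hd // 2) freqs_cis tensor.
    angles = torch.randn(sl, hd // 2)
    freqs_packed = torch.stack((angles.cos(), angles.sin()), dim=-1).reshape(sl, hd).to("npu")
    q_out, k_out = LigerLlama4RopeFunction.apply(q, k, freqs_packed)
    print(q_out.shape, k_out.shape)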
import torch
import triton
import triton.language as tl
from triton.language.math import rsqrt
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
# -----------------------------------------------------------------------------
# Forward Kernel - No Tiling (for n_cols <= 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _poly_norm_forward_kernel_no_tiling(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr, # weight: [3] for [w0, w1, w2]
B_ptr, # bias: scalar
RSTD_ptr, # cache rstd for backward: shape (n_rows, 3)
RSTD_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
eps,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
"""
NPU-optimized PolyNorm forward kernel for small n_cols (<= 2048).
PolyNorm formula:
y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
where norm(u) = u / sqrt(mean(u²) + ε)
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# Grid-stride loop setup
grid_stride = num_progs * BLOCK_SIZE_M
num_iterations = tl.cdiv(n_rows, grid_stride)
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
# Load weights and bias
w0 = tl.load(W_ptr + 0)
w1 = tl.load(W_ptr + 1)
w2 = tl.load(W_ptr + 2)
b = tl.load(B_ptr)
# Grid-stride loop over row blocks
for i in range(num_iterations):
row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets
row_mask = row_idx < n_rows
block_mask = row_mask[:, None] & col_mask[None, :]
# Load input rows
X_rows = tl.load(
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
cache_modifier=".cg",
)
X_f32 = X_rows.to(tl.float32)
# Compute x³, x², x
X_pow3 = X_f32 * X_f32 * X_f32
X_pow2 = X_f32 * X_f32
X_pow1 = X_f32
# Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
# Out-of-bounds positions were loaded as 0.0, so they do not contribute to the sums below
mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=1) / n_cols
rstd_3 = rsqrt(mean_square_3 + eps)
norm_x3 = X_pow3 * rstd_3[:, None]
# Compute norm(x²)
mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=1) / n_cols
rstd_2 = rsqrt(mean_square_2 + eps)
norm_x2 = X_pow2 * rstd_2[:, None]
# Compute norm(x)
mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=1) / n_cols
rstd_1 = rsqrt(mean_square_1 + eps)
norm_x1 = X_pow1 * rstd_1[:, None]
# Cache rstd values for backward (3 values per row; keep them in float32 to match the RSTD buffer)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride + 0, rstd_3, mask=row_mask)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride + 1, rstd_2, mask=row_mask)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride + 2, rstd_1, mask=row_mask)
# Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
Y_f32 = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
# Store output
tl.store(
Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
Y_f32.to(X_rows.dtype),
mask=block_mask,
)
# -----------------------------------------------------------------------------
# Forward Kernel - With Tiling (for n_cols > 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _poly_norm_forward_kernel_npu(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr, # weight: [3] for [w0, w1, w2]
B_ptr, # bias: scalar
RSTD_ptr, # cache rstd for backward: shape (n_rows, 3)
RSTD_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
NPU-optimized PolyNorm forward kernel with column blocking.
This kernel processes rows using a grid-stride loop pattern:
1. Each program handles multiple rows
2. For each row, we process it in column chunks of BLOCK_SIZE
3. Grid size is limited to NPU core count to avoid resource overflow
Two-pass algorithm per row:
- First pass: compute mean_square and rstd for x³, x², x across all column blocks
- Second pass: apply normalization and affine transformation
PolyNorm formula:
y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
where norm(u) = u / sqrt(mean(u²) + ε)
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
offsets = tl.arange(0, BLOCK_SIZE)
# Load weights and bias
w0 = tl.load(W_ptr + 0)
w1 = tl.load(W_ptr + 1)
w2 = tl.load(W_ptr + 2)
b = tl.load(B_ptr)
# Grid-stride loop over rows
for row_idx in range(pid, n_rows, num_progs):
Y_row_ptr = Y_ptr + row_idx * Y_row_stride
X_row_ptr = X_ptr + row_idx * X_row_stride
RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
# First pass: compute mean_square for x³, x², x
sum_square_3 = 0.0
sum_square_2 = 0.0
sum_square_1 = 0.0
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
# Compute powers
X_pow3 = X_block * X_block * X_block
X_pow2 = X_block * X_block
X_pow1 = X_block
sum_square_3 += tl.sum(X_pow3 * X_pow3)
sum_square_2 += tl.sum(X_pow2 * X_pow2)
sum_square_1 += tl.sum(X_pow1 * X_pow1)
# Compute rstd values
mean_square_3 = sum_square_3 / n_cols
mean_square_2 = sum_square_2 / n_cols
mean_square_1 = sum_square_1 / n_cols
rstd_3 = rsqrt(mean_square_3 + eps)
rstd_2 = rsqrt(mean_square_2 + eps)
rstd_1 = rsqrt(mean_square_1 + eps)
# Store rstd values
tl.store(RSTD_row_ptr + 0, rstd_3)
tl.store(RSTD_row_ptr + 1, rstd_2)
tl.store(RSTD_row_ptr + 2, rstd_1)
# Second pass: normalize and apply affine transformation
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
# Load input
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32)
# Compute powers
X_pow3 = X_block * X_block * X_block
X_pow2 = X_block * X_block
X_pow1 = X_block
# Apply normalization
norm_x3 = X_pow3 * rstd_3
norm_x2 = X_pow2 * rstd_2
norm_x1 = X_pow1 * rstd_1
# Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
Y_f32 = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
# Store result
tl.store(Y_row_ptr + col_offsets, Y_f32.to(X_block.dtype), mask=mask)
# -----------------------------------------------------------------------------
# Backward Kernel - No Tiling (for n_cols <= 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _poly_norm_backward_kernel_no_tiling(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
dW_scratch_ptr, # shape: (n_programs, 3)
dW_scratch_stride,
dB_scratch_ptr, # shape: (n_programs,)
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
"""
NPU-optimized PolyNorm backward kernel for small n_cols (<= 2048).
Backward pass equations:
∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
where:
- D_p = RMS(x^p) = 1/rstd_p
- S_p = sum(grad * x^p) over the row
- d = n_cols
- p ∈ {3, 2, 1}
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# Grid-stride loop setup
grid_stride = num_progs * BLOCK_SIZE_M
num_iterations = tl.cdiv(n_rows, grid_stride)
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
# Load weights
w0 = tl.load(W_ptr + 0).to(tl.float32)
w1 = tl.load(W_ptr + 1).to(tl.float32)
w2 = tl.load(W_ptr + 2).to(tl.float32)
# Each program accumulates its own dW/dB contribution to avoid atomic contention
dW0_acc = 0.0
dW1_acc = 0.0
dW2_acc = 0.0
dB_acc = 0.0
# Grid-stride loop over row blocks
for i in range(num_iterations):
row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets
row_mask = row_idx < n_rows
block_mask = row_mask[:, None] & col_mask[None, :]
# Load input and gradient data
X_rows = tl.load(
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
cache_modifier=".cg",
)
dY_rows = tl.load(
dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
cache_modifier=".cg",
)
# Load cached rstd values (3 values per row)
rstd_3 = tl.load(RSTD_ptr + row_idx * RSTD_row_stride + 0, mask=row_mask, other=0.0).to(tl.float32)
rstd_2 = tl.load(RSTD_ptr + row_idx * RSTD_row_stride + 1, mask=row_mask, other=0.0).to(tl.float32)
rstd_1 = tl.load(RSTD_ptr + row_idx * RSTD_row_stride + 2, mask=row_mask, other=0.0).to(tl.float32)
X_f32 = X_rows.to(tl.float32)
dY_f32 = dY_rows.to(tl.float32)
# Compute powers
X_pow3 = X_f32 * X_f32 * X_f32
X_pow2 = X_f32 * X_f32
X_pow1 = X_f32
# Accumulate bias gradient: dB = sum(dY)
dB_acc += tl.sum(dY_f32)
# Compute gradient w.r.t. input using closed-form formula
# For p=3: ∂L/∂x from w0 * norm(x³)
S_3 = tl.sum(dY_f32 * X_pow3, axis=1) # sum over columns for each row
grad_x_3 = w0 * (
3.0 * X_pow2 * rstd_3[:, None] * dY_f32
- (3.0 / n_cols) * X_pow2 * X_pow3 * (rstd_3[:, None] * rstd_3[:, None] * rstd_3[:, None]) * S_3[:, None]
)
# For p=2: ∂L/∂x from w1 * norm(x²)
S_2 = tl.sum(dY_f32 * X_pow2, axis=1)
grad_x_2 = w1 * (
2.0 * X_pow1 * rstd_2[:, None] * dY_f32
- (2.0 / n_cols) * X_pow1 * X_pow2 * (rstd_2[:, None] * rstd_2[:, None] * rstd_2[:, None]) * S_2[:, None]
)
# For p=1: ∂L/∂x from w2 * norm(x)
S_1 = tl.sum(dY_f32 * X_pow1, axis=1)
grad_x_1 = w2 * (
1.0 * rstd_1[:, None] * dY_f32
- (1.0 / n_cols) * X_pow1 * (rstd_1[:, None] * rstd_1[:, None] * rstd_1[:, None]) * S_1[:, None]
)
# Total gradient
dX_f32 = grad_x_3 + grad_x_2 + grad_x_1
# Store dX
tl.store(
dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
dX_f32.to(X_ptr.dtype.element_ty),
mask=block_mask,
)
# Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
dW0_acc += tl.sum(rstd_3 * S_3)
dW1_acc += tl.sum(rstd_2 * S_2)
dW2_acc += tl.sum(rstd_1 * S_1)
# Write this program's accumulated dW/dB to its dedicated scratch row
tl.store(dW_scratch_ptr + pid * dW_scratch_stride + 0, dW0_acc)
tl.store(dW_scratch_ptr + pid * dW_scratch_stride + 1, dW1_acc)
tl.store(dW_scratch_ptr + pid * dW_scratch_stride + 2, dW2_acc)
tl.store(dB_scratch_ptr + pid, dB_acc)
# -----------------------------------------------------------------------------
# Backward Kernel - With Tiling (for n_cols > 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _poly_norm_backward_kernel_npu(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dB_ptr,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
NPU-optimized PolyNorm backward kernel with column blocking.
Each program processes multiple rows using grid-stride loop.
For each row, we process columns in blocks to avoid UB overflow.
Two-pass algorithm:
- First pass: compute S_p = sum(grad * x^p) for p ∈ {3, 2, 1}
- Second pass: compute gradients dX, dW, dB
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
offsets = tl.arange(0, BLOCK_SIZE)
# Load weights
w0 = tl.load(W_ptr + 0).to(tl.float32)
w1 = tl.load(W_ptr + 1).to(tl.float32)
w2 = tl.load(W_ptr + 2).to(tl.float32)
dw0_acc = 0.0
dw1_acc = 0.0
dw2_acc = 0.0
db_acc = 0.0
# Grid-stride loop over rows
for row_idx in range(pid, n_rows, num_progs):
dY_row_ptr = dY_ptr + row_idx * dY_row_stride
X_row_ptr = X_ptr + row_idx * X_row_stride
dX_row_ptr = dX_ptr + row_idx * dX_row_stride
RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
# Load cached rstd values
rstd_3 = tl.load(RSTD_row_ptr + 0).to(tl.float32)
rstd_2 = tl.load(RSTD_row_ptr + 1).to(tl.float32)
rstd_1 = tl.load(RSTD_row_ptr + 2).to(tl.float32)
# First pass: compute S_p = sum(grad * x^p)
S_3 = 0.0
S_2 = 0.0
S_1 = 0.0
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
# Compute powers
X_pow3 = X_block * X_block * X_block
X_pow2 = X_block * X_block
X_pow1 = X_block
S_3 += tl.sum(dY_block * X_pow3)
S_2 += tl.sum(dY_block * X_pow2)
S_1 += tl.sum(dY_block * X_pow1)
# Second pass: compute gradients
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
# Compute powers
X_pow3 = X_block * X_block * X_block
X_pow2 = X_block * X_block
X_pow1 = X_block
# Compute gradient w.r.t. input using closed-form formula
# For p=3: ∂L/∂x from w0 * norm(x³)
grad_x_3 = w0 * (
3.0 * X_pow2 * rstd_3 * dY_block - (3.0 / n_cols) * X_pow2 * X_pow3 * (rstd_3 * rstd_3 * rstd_3) * S_3
)
# For p=2: ∂L/∂x from w1 * norm(x²)
grad_x_2 = w1 * (
2.0 * X_pow1 * rstd_2 * dY_block - (2.0 / n_cols) * X_pow1 * X_pow2 * (rstd_2 * rstd_2 * rstd_2) * S_2
)
# For p=1: ∂L/∂x from w2 * norm(x)
grad_x_1 = w2 * (1.0 * rstd_1 * dY_block - (1.0 / n_cols) * X_pow1 * (rstd_1 * rstd_1 * rstd_1) * S_1)
# Total gradient
dX_block = grad_x_3 + grad_x_2 + grad_x_1
# Store dX
tl.store(dX_row_ptr + col_offsets, dX_block.to(X_ptr.dtype.element_ty), mask=mask)
dw0_acc += tl.sum(rstd_3 * dY_block * X_pow3)
dw1_acc += tl.sum(rstd_2 * dY_block * X_pow2)
dw2_acc += tl.sum(rstd_1 * dY_block * X_pow1)
db_acc += tl.sum(dY_block)
# Accumulate this program's partial dW/dB atomically: multiple programs process
# disjoint rows concurrently, so a plain store would drop their contributions.
tl.atomic_add(dW_ptr + 0, dw0_acc)
tl.atomic_add(dW_ptr + 1, dw1_acc)
tl.atomic_add(dW_ptr + 2, dw2_acc)
tl.atomic_add(dB_ptr, db_acc)
# -----------------------------------------------------------------------------
# Helper Functions
# -----------------------------------------------------------------------------
def get_optimal_block_size(n_cols, is_forward: bool):
"""
Calculate optimal block size using compute_default_tiling_strategy.
Memory analysis for forward pass (per row):
- Load: X_block (1 block)
- Compute: X_pow3, X_pow2, X_pow1, norm_x3, norm_x2, norm_x1 (6 blocks)
- Total: conservative estimate 8 blocks of memory
Memory analysis for backward pass (per row):
- Load: X_block, dY_block, RSTD (3 blocks)
- Compute: X_pow3, X_pow2, X_pow1, grad_x_3, grad_x_2, grad_x_1 (6 blocks)
- Total: conservative estimate 10 blocks of memory
Args:
n_cols: Number of columns in the tensor
is_forward: Whether this is for forward pass (True) or backward pass (False)
Returns:
Optimal block size
"""
if n_cols <= 2048:
return triton.next_power_of_2(n_cols)
memory_multiplier = 8.0 if is_forward else 10.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.8,
dtype_size=4,
memory_multiplier=memory_multiplier,
shapes=((n_cols,),),
tiling_dims=(0,),
)
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return max(2048, block_size)
else:
return 2048
def _compute_grid_size(n_rows: int, block_size_m: int, num_cores: int) -> int:
"""
Compute the effective grid size for no-tiling kernels.
Limits the grid to the minimum of:
- The number of row blocks actually needed (ceil(n_rows / BLOCK_SIZE_M)), which
prevents launching idle programs that would waste core cycles
- NPU core count, which is the hardware concurrency upper bound
Args:
n_rows: Total number of rows to process
block_size_m: Number of rows each program handles per iteration
num_cores: Number of available NPU cores
Returns:
Effective grid size
"""
num_row_blocks = triton.cdiv(n_rows, block_size_m)
return min(num_cores, num_row_blocks)
# -----------------------------------------------------------------------------
# Forward and Backward Functions
# -----------------------------------------------------------------------------
def poly_norm_forward(X, W, B, eps=1e-6):
"""
PolyNorm Forward Pass
Args:
X: input tensor of shape (*, H) where H is hidden dimension
W: weight tensor of shape (3,) for [w0, w1, w2]
B: bias scalar tensor
eps: epsilon for numerical stability
Returns:
Y: output tensor of same shape as X
X: reshaped input (for backward)
RSTD: cached rstd values (for backward)
"""
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
n_rows, n_cols = X.shape
# Check constraints
assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
assert B.numel() == 1, "Bias must be a scalar"
# Get optimal block sizes
BLOCK_SIZE = get_optimal_block_size(n_cols, True)
BLOCK_SIZE_M = 2048 // BLOCK_SIZE
# RSTD caches the per-row rstd values (3 values per row, one for each power)
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)
# Grid size
num_cores = get_npu_core_count()
# Choose kernel based on n_cols
if n_cols <= 2048:
# Small kernel: use 2D tensor loading
grid_size = _compute_grid_size(n_rows, BLOCK_SIZE_M, num_cores)
_poly_norm_forward_kernel_no_tiling[(grid_size,)](
Y,
Y.stride(0),
X,
X.stride(0),
W,
B,
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE,
)
else:
# Large kernel: use column blocking
grid_size = min(num_cores, n_rows)
_poly_norm_forward_kernel_npu[(grid_size,)](
Y,
Y.stride(0),
X,
X.stride(0),
W,
B,
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
)
return Y.view(*shape), X, RSTD
def poly_norm_backward(dY, X, W, RSTD, in_place):
"""
PolyNorm Backward Pass
Args:
dY: gradient of output
X: input tensor (already reshaped to 2D)
W: weight tensor
RSTD: cached rstd values from forward
in_place: whether to in-place modify dY to store dX (saves memory)
Returns:
dX: gradient w.r.t. input
dW: gradient w.r.t. weight
dB: gradient w.r.t. bias
"""
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
# Get optimal block sizes
BLOCK_SIZE_BACKWARD = get_optimal_block_size(n_cols, False)
BLOCK_SIZE_M = 2048 // BLOCK_SIZE_BACKWARD
# Grid size
num_cores = get_npu_core_count()
# Allocate or reuse gradients
if in_place is True:
dX = dY
else:
dX = torch.zeros_like(dY)
# Choose kernel based on n_cols
if n_cols <= 2048:
# Small kernel: use 2D tensor loading with scratch buffers
grid_size = _compute_grid_size(n_rows, BLOCK_SIZE_M, num_cores)
# Allocate per-program scratch buffers for dW and dB
dW_scratch = torch.empty((grid_size, 3), dtype=torch.float32, device=W.device)
dB_scratch = torch.empty((grid_size,), dtype=torch.float32, device=W.device)
_poly_norm_backward_kernel_no_tiling[(grid_size,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
W,
RSTD,
RSTD.stride(0),
dW_scratch,
dW_scratch.stride(0),
dB_scratch,
n_rows,
n_cols,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_BACKWARD,
)
dW = dW_scratch.sum(dim=0).to(W.dtype)
dB = dB_scratch.sum().to(W.dtype)
else:
# Large kernel: use column blocking with atomic operations
grid_size = min(num_cores, n_rows)
dW = torch.zeros(3, dtype=torch.float32, device=W.device)
dB = torch.zeros(1, dtype=torch.float32, device=W.device)
_poly_norm_backward_kernel_npu[(grid_size,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
W,
RSTD,
RSTD.stride(0),
dW,
dB,
n_rows,
n_cols,
BLOCK_SIZE=BLOCK_SIZE_BACKWARD,
)
dW = dW.to(W.dtype)
dB = dB.squeeze().to(W.dtype)
# Reshape dX back to original shape
dX = dX.view(*shape)
return dX, dW, dB
# -----------------------------------------------------------------------------
# Autograd Function
# -----------------------------------------------------------------------------
class LigerPolyNormFunction(torch.autograd.Function):
"""
PolyNorm Function with forward and backward pass
PolyNorm formula:
y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
where norm(u) = u / sqrt(mean(u²) + ε)
Backward uses closed-form gradient:
∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
"""
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, B, eps=1e-6, in_place=True):
"""
Args:
X: input tensor of shape (B, T, H) or (BxT, H)
W: weight tensor of shape (3,) for [w0, w1, w2]
B: bias scalar
eps: epsilon for numerical stability
in_place: whether to in-place modify grad_output in backward (saves memory)
Returns:
Y: output tensor of same shape as X
"""
Y, X, RSTD = poly_norm_forward(X, W, B, eps)
ctx.in_place = in_place
ctx.save_for_backward(X, W, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output):
"""
Args:
grad_output: gradient of output
Returns:
dX, dW, dB: gradients w.r.t. X, W, B
"""
X, W, RSTD = ctx.saved_tensors
dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.in_place)
return dX, dW, dB, None, None
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import get_npu_core_count
@triton.jit
def _triton_qwen2vl_mrope_npu(
q_ptr,
q_row_stride,
k_ptr,
k_row_stride,
cos,
sin,
sl,
bs: tl.constexpr,
total_rows: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
mrope_section_t: tl.constexpr,
mrope_section_h: tl.constexpr,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
BACKWARD_PASS: tl.constexpr = False,
):
program_id = tl.program_id(0)
num_programs = tl.num_programs(0)
rows_per_program = (total_rows + num_programs - 1) // num_programs
start_row = program_id * rows_per_program
actual_rows = tl.minimum(rows_per_program, total_rows - start_row)
for row_offset in tl.range(0, actual_rows):
pid = start_row + row_offset
t_end = mrope_section_t
h_end = t_end + mrope_section_h
t_cos = cos + pid * hd
h_cos = t_cos + bs * sl * hd
w_cos = h_cos + bs * sl * hd
t_sin = sin + pid * hd
h_sin = t_sin + bs * sl * hd
w_sin = h_sin + bs * sl * hd
q_base = q_ptr + pid * q_row_stride
k_base = k_ptr + pid * k_row_stride
d_idx = tl.arange(0, hd // 2)
d_mask = d_idx < (hd // 2)
pos_mask_t = d_idx < t_end
pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)
text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)
cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))
# Process q heads in chunks to prevent UB overflow
for qh_block in range(0, n_qh, BLOCK_Q):
qh_idx = tl.arange(0, BLOCK_Q) + qh_block
qh_mask = qh_idx < n_qh
block_mask = qh_mask[:, None] & d_mask[None, :]
offsets = qh_idx[:, None] * hd + d_idx[None, :]
q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
if not BACKWARD_PASS:
new_left = q_left * cos_vals - q_right * sin_vals
new_right = q_right * cos_vals + q_left * sin_vals
else:
new_left = q_left * cos_vals + q_right * sin_vals
new_right = q_right * cos_vals - q_left * sin_vals
tl.store(q_base + offsets, new_left, mask=block_mask)
tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
# Process k heads in chunks to prevent UB overflow
for kh_block in range(0, n_kh, BLOCK_K):
kh_idx = tl.arange(0, BLOCK_K) + kh_block
kh_mask = kh_idx < n_kh
block_mask = kh_mask[:, None] & d_mask[None, :]
offsets = kh_idx[:, None] * hd + d_idx[None, :]
k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
if not BACKWARD_PASS:
new_left = k_left * cos_vals - k_right * sin_vals
new_right = k_right * cos_vals + k_left * sin_vals
else:
new_left = k_left * cos_vals + k_right * sin_vals
new_right = k_right * cos_vals - k_left * sin_vals
tl.store(k_base + offsets, new_left, mask=block_mask)
tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
def get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size):
# MROPE forward tiling strategy:
# - cos_vals and sin_vals (covering text, height and width) are loaded once outside the loops (shared): (pad_hd // 2) * 6 = 3 * pad_hd elements in total
# - In q heads loop (peak memory):
# * q_left: BLOCK_Q * (pad_hd // 2) elements
# * q_right: BLOCK_Q * (pad_hd // 2) elements
# * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
# * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
# * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
# - In k heads loop (peak memory):
# * k_left: BLOCK_K * (pad_hd // 2) elements
# * k_right: BLOCK_K * (pad_hd // 2) elements
# * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
# * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
# * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
# - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
# - Plus shared cos/sin: 6 * (pad_hd // 2) = 3 * pad_hd elements
# - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + 3 * pad_hd) * dtype_size * 8 bits
# - Simplified: (2 * BLOCK_SIZE + 3) * pad_hd * dtype_size * 8 bits
# - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
# - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
# - tiling_dims: (0, 0) means first dimension of each shape can be tiled
# - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.90,
dtype_size=dtype_size,
memory_multiplier=3.0,
shapes=shapes,
tiling_dims=(0, 0),
)
if tile_shapes is not None and len(tile_shapes) == len(shapes):
# Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
q_tile_shape, k_tile_shape = tile_shapes
BLOCK_Q, _ = q_tile_shape
BLOCK_K, _ = k_tile_shape
else:
# Fallback to conservative defaults
BLOCK_Q = 2048
BLOCK_K = 2048
return BLOCK_Q, BLOCK_K
def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
# transpose it back to the physical shape because Triton looks at the physical storage
q = q.transpose(1, 2)
k = k.transpose(1, 2)
batch_size, seq_len, n_q_head, head_dim = q.shape
n_kv_head = k.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
n_row = batch_size * seq_len
# ensure tensors passed into the kernel are contiguous
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
dtype_size = q.element_size()
BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
num_cores = get_npu_core_count()
grid_size = min(num_cores, n_row)
_triton_qwen2vl_mrope_npu[(grid_size,)](
q,
q.stride(1),
k,
k.stride(1),
cos,
sin,
seq_len,
batch_size,
n_row,
n_q_head,
n_kv_head,
head_dim,
mrope_section[0],
mrope_section[1],
BLOCK_Q,
BLOCK_K,
BACKWARD_PASS=False,
)
return q.transpose(1, 2), k.transpose(1, 2), cos, sin
def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
dq = dq.transpose(1, 2)
dk = dk.transpose(1, 2)
batch_size, seq_len, n_q_head, head_dim = dq.shape
n_kv_head = dk.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
n_row = batch_size * seq_len
# ensure dq and dk are contiguous
dq = dq.contiguous()
dk = dk.contiguous()
dtype_size = dq.element_size()
BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
num_cores = get_npu_core_count()
grid_size = min(num_cores, n_row)
_triton_qwen2vl_mrope_npu[(grid_size,)](
dq,
dq.stride(1),
dk,
dk.stride(1),
cos,
sin,
seq_len,
batch_size,
n_row,
n_q_head,
n_kv_head,
head_dim,
mrope_section[0],
mrope_section[1],
BLOCK_Q,
BLOCK_K,
BACKWARD_PASS=True,
)
return dq.transpose(1, 2), dk.transpose(1, 2)
class LigerQwen2VLMRopeFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""
q size: (bsz, n_q_head, seq_len, head_dim)
k size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (3, bsz, seq_len, head_dim)
sin size: (3, bsz, seq_len, head_dim)
"""
q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
ctx.save_for_backward(cos, sin)
ctx.mrope_section = mrope_section
return q, k
@staticmethod
def backward(ctx, dq, dk):
"""
dq size: (bsz, n_q_head, seq_len, head_dim)
dk size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (3, bsz, seq_len, head_dim)
sin size: (3, bsz, seq_len, head_dim)
"""
cos, sin = ctx.saved_tensors
mrope_section = ctx.mrope_section
dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
return dq, dk, None, None, None, None
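# Illustrative usage sketch (hypothetical shapes and mrope_section values, assuming an
# Ascend NPU device; shapes follow the docstrings above):
#
#   q = torch.randn(2, 32, 1024, 128, device="npu")   # (bsz, n_q_head, seq_len, head_dim)
#   k = torch.randn(2, 8, 1024, 128, device="npu")    # (bsz, n_kv_head, seq_len, head_dim)
#   cos = torch.randn(3, 2, 1024, 128, device="npu")  # (3, bsz, seq_len, head_dim)
#   sin = torch.randn(3, 2, 1024, 128, device="npu")
#   q_rot, k_rot = LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, [16, 24, 24])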
import torch
import triton
import triton.language as tl
from triton.language.math import rsqrt
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import torch_to_triton_dtype
_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
def torch_dtype_to_triton(dtype):
mapping = {
torch.float32: tl.float32,
torch.bfloat16: tl.bfloat16,
}
return mapping.get(dtype, tl.float32)
# -----------------------------------------------------------------------------
# Forward Kernel - No Tiling (for n_cols <= 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _rms_norm_forward_kernel_no_tiling(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
eps,
offset,
casting_mode: tl.constexpr,
elementwise_affine: tl.constexpr,
X_DTYPE: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
"""
NPU-optimized rms_norm forward kernel for small n_cols (<= 2048).
Performance optimizations:
1. Use 2D vector loading to maximize UB utilization (e.g., (1,2048), (2,1024), (4,512))
2. Process multiple rows at once using 2D indexing
3. Keep data in registers, minimize conversions
4. Use optimal cache policies
Used when n_cols <= 2048 to avoid the overhead of column blocking.
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
if casting_mode == _CASTING_MODE_NONE:
eps = eps.to(X_DTYPE)
offset = offset.to(X_DTYPE)
# Grid-stride loop setup for 2D blocks
grid_stride = num_progs * BLOCK_SIZE_M
num_iterations = tl.cdiv(n_rows, grid_stride)
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
if elementwise_affine:
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
# Grid-stride loop over row blocks
for i in range(num_iterations):
row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets
row_mask = row_idx < n_rows
block_mask = row_mask[:, None] & col_mask[None, :]
# Load multiple rows at once using 2D indexing
X_rows = tl.load(
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
eviction_policy="evict_first",
)
# Compute sum_square for all rows
if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA:
X_rows = X_rows.to(tl.float32)
sum_squares = tl.sum(tl.where(block_mask, X_rows * X_rows, 0.0), axis=1)
# Compute rstd for all rows
mean_squares = sum_squares / n_cols
rstd_rows = rsqrt(mean_squares + eps)
# Store rstd_rows
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd_rows, mask=row_mask)
# Apply casting based on mode
if casting_mode == _CASTING_MODE_GEMMA:
X_rows = X_rows.to(tl.float32)
if elementwise_affine:
W_row_fp32 = W_row.to(tl.float32)
elif casting_mode == _CASTING_MODE_LLAMA:
X_rows = X_rows.to(tl.float32)
# Normalize
X_rows = X_rows * rstd_rows[:, None]
# Cast back for Llama mode before weight multiplication
if casting_mode == _CASTING_MODE_LLAMA:
X_rows = X_rows.to(X_DTYPE)
# Apply weight
if elementwise_affine:
if casting_mode == _CASTING_MODE_GEMMA:
Y_rows = X_rows * (offset + W_row_fp32[None, :])
else:
Y_rows = X_rows * (offset + W_row[None, :])
else:
Y_rows = X_rows
# Cast back for Gemma mode
if casting_mode == _CASTING_MODE_GEMMA:
Y_rows = Y_rows.to(X_DTYPE)
# Store results
tl.store(Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], Y_rows, mask=block_mask)
# -----------------------------------------------------------------------------
# Forward Kernel - With Tiling (for n_cols > 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _rms_norm_forward_kernel_tiled(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
eps,
offset,
casting_mode: tl.constexpr,
elementwise_affine: tl.constexpr,
X_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
NPU-optimized rms_norm forward kernel for large n_cols (> 2048).
This kernel processes rows using a grid-stride loop pattern:
1. Each program handles multiple rows
2. For each row, we process it in column chunks of BLOCK_SIZE
3. Grid size is limited to NPU core count to avoid resource overflow
This solves two problems:
1. UB overflow when n_cols is too large (original kernel used n_cols as BLOCK_SIZE)
2. Efficient multi-row processing within a single kernel launch
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
if casting_mode == _CASTING_MODE_NONE:
eps = eps.to(X_DTYPE)
offset = offset.to(X_DTYPE)
offsets = tl.arange(0, BLOCK_SIZE)
# Grid-stride loop over rows
for row_idx in tl.range(pid, n_rows, num_progs):
Y_row_ptr = Y_ptr + row_idx * Y_row_stride
X_row_ptr = X_ptr + row_idx * X_row_stride
RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
# Accumulator for mean_square computation across all column blocks
sum_square = 0.0
# First pass: accumulate sum of squares
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA:
X_block = X_block.to(tl.float32)
# Accumulate sum of squares (only for valid elements)
sum_square += tl.sum(X_block * X_block)
# Compute rstd for this row
mean_square = sum_square / n_cols
rstd = rsqrt(mean_square + eps)
# Store rstd
tl.store(RSTD_row_ptr, rstd)
# Second pass: normalize and multiply by weight
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
# Load X_block
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".ca")
if elementwise_affine:
W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
# Apply casting based on mode
if casting_mode == _CASTING_MODE_GEMMA:
X_block = X_block.to(tl.float32)
if elementwise_affine:
W_block = W_block.to(tl.float32)
elif casting_mode == _CASTING_MODE_LLAMA:
X_block = X_block.to(tl.float32)
# Normalize
X_block = X_block * rstd
# Cast back for Llama mode before weight multiplication
if casting_mode == _CASTING_MODE_LLAMA:
X_block = X_block.to(X_DTYPE)
# Apply weight
if elementwise_affine:
Y_block = X_block * (offset + W_block)
else:
Y_block = X_block
# Cast back for Gemma mode
if casting_mode == _CASTING_MODE_GEMMA:
Y_block = Y_block.to(X_DTYPE)
# Store result
tl.store(Y_row_ptr + col_offsets, Y_block, mask=mask)
# -----------------------------------------------------------------------------
# Backward Kernel - No Tiling (for n_cols <= 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _rms_norm_backward_kernel_no_tiling(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
offset,
casting_mode: tl.constexpr,
elementwise_affine: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
"""
NPU-optimized rms_norm backward kernel for small n_cols (<= 2048).
Performance optimizations:
1. Keep all data in registers, minimize conversions
2. Reuse X_normalized (X * rstd) for both dX and dW
3. Optimize computation order to reduce register pressure
4. Combine operations where possible
5. Use 2D vector loading to maximize UB utilization (e.g., (1,2048), (2,1024), (4,512))
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# Grid-stride loop setup for 2D blocks
grid_stride = num_progs * BLOCK_SIZE_M
num_iterations = tl.cdiv(n_rows, grid_stride)
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
# Load W once for all iterations
if elementwise_affine:
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
W_offset = W_row + offset
# Grid-stride loop over row blocks
for i in range(num_iterations):
row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets
row_mask = row_idx < n_rows
block_mask = row_mask[:, None] & col_mask[None, :]
dY_rows = tl.load(
dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
eviction_policy="evict_first",
)
X_rows = tl.load(
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
eviction_policy="evict_first",
)
# Load rstd for all rows in the block
rstd_rows = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, mask=row_mask, other=0.0)
# Convert X to fp32 once
X_rows = X_rows.to(tl.float32)
# Compute X_normalized (reused in dX and dW)
X_normalized = X_rows * rstd_rows[:, None]
# Compute m based on casting mode and elementwise_affine
if elementwise_affine:
if casting_mode == _CASTING_MODE_LLAMA:
m_rows = (dY_rows * W_offset[None, :]).to(tl.float32)
# For dW in Llama mode, we need X_normalized in original dtype
X_normalized = X_normalized.to(X_dtype)
elif casting_mode == _CASTING_MODE_GEMMA:
m_rows = dY_rows.to(tl.float32) * W_offset[None, :]
else:
m_rows = dY_rows * W_offset[None, :]
else:
if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA:
m_rows = dY_rows.to(tl.float32)
else:
m_rows = dY_rows
# Compute sum(m * X) for correction factor
sum_m_X = tl.sum(tl.where(block_mask, m_rows * X_rows, 0.0), axis=1)
# Compute correction factor
correction_factors = -(1.0 / n_cols) * rstd_rows * rstd_rows * sum_m_X
# Compute dX = rstd * m + rstd * correction_factor * X
dX_rows = rstd_rows[:, None] * m_rows + rstd_rows[:, None] * correction_factors[:, None] * X_rows
# Store dX
tl.store(dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :], dX_rows.to(X_dtype), mask=block_mask)
if elementwise_affine:
# Compute dW contribution: dY * X_normalized
dW_rows = (dY_rows * X_normalized).to(tl.float32)
# Accumulate to per-program dW buffer
dW_row_ptr = dW_ptr + pid * dW_row_stride
existing_dW = tl.load(dW_row_ptr + col_offsets, mask=col_mask, other=0.0)
new_dW = existing_dW + tl.sum(tl.where(block_mask, dW_rows, 0.0), axis=0)
tl.store(dW_row_ptr + col_offsets, new_dW, mask=col_mask)
# -----------------------------------------------------------------------------
# Backward Kernel - With Tiling (for n_cols > 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _rms_norm_backward_kernel_tiled(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
offset,
casting_mode: tl.constexpr,
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
NPU-optimized rms_norm backward kernel for large n_cols (> 2048).
Each program processes multiple rows using grid-stride loop.
For each row, we process columns in blocks to avoid UB overflow.
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# dW contributions are accumulated into a per-program row of dW_ptr and reduced on the host
num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
offsets = tl.arange(0, BLOCK_SIZE)
# Grid-stride loop over rows
for row_idx in tl.range(pid, n_rows, num_progs):
# Base pointers for this row
dY_row_ptr = dY_ptr + row_idx * dY_row_stride
dX_row_ptr = dX_ptr + row_idx * dX_row_stride
X_row_ptr = X_ptr + row_idx * X_row_stride
RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
# Load rstd for this row
rstd = tl.load(RSTD_row_ptr)
# First pass: compute sum(m * X) for the correction term
sum_m_X = 0.0
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
# Convert to fp32 for computation
X_block = X_block.to(tl.float32)
if elementwise_affine:
W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
W_offset = W_block + offset
# Compute m based on casting mode
if casting_mode == _CASTING_MODE_LLAMA:
m = (dY_block * W_offset).to(tl.float32)
elif casting_mode == _CASTING_MODE_GEMMA:
dY_block = dY_block.to(tl.float32)
m = dY_block * W_offset
else:
m = dY_block * W_offset
else:
# Compute m based on casting mode
if casting_mode == _CASTING_MODE_LLAMA:
m = dY_block.to(tl.float32)
elif casting_mode == _CASTING_MODE_GEMMA:
m = dY_block.to(tl.float32)
else:
m = dY_block
# Accumulate sum(m * X)
sum_m_X += tl.sum(m * X_block)
# Compute the correction factor
correction_factor = -(1.0 / n_cols) * rstd * rstd * sum_m_X
# Second pass: compute gradients
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0)
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0)
X_block = X_block.to(tl.float32)
if elementwise_affine:
W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_offset = W_block + offset
# Compute m based on casting mode
if casting_mode == _CASTING_MODE_LLAMA:
m = (dY_block * W_offset).to(tl.float32)
elif casting_mode == _CASTING_MODE_GEMMA:
dY_block = dY_block.to(tl.float32)
m = dY_block * W_offset
else:
m = dY_block * W_offset
else:
# Compute m based on casting mode
if casting_mode == _CASTING_MODE_LLAMA:
m = dY_block.to(tl.float32)
elif casting_mode == _CASTING_MODE_GEMMA:
m = dY_block.to(tl.float32)
else:
m = dY_block
# Compute dX
dX_block = rstd * m + rstd * correction_factor * X_block
# Store dX
tl.store(dX_row_ptr + col_offsets, dX_block.to(X_dtype), mask=mask)
if elementwise_affine:
# Compute dW contribution (accumulate per program)
if casting_mode == _CASTING_MODE_LLAMA:
dW_block = dY_block * (X_block * rstd).to(X_dtype)
else:
dW_block = dY_block * (X_block * rstd)
# Accumulate into this program's own row of dW_ptr (no atomics needed: each program owns its row)
dW_row_ptr = dW_ptr + pid * dW_row_stride
# Load existing dW, add contribution, store back
existing_dW = tl.load(dW_row_ptr + col_offsets, mask=mask, other=0.0)
new_dW = existing_dW + dW_block.to(tl.float32)
tl.store(dW_row_ptr + col_offsets, new_dW, mask=mask)
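# For reference, both backward kernels compute the same closed-form gradient.
# With m = dY * (W + offset) (or m = dY without elementwise affine), in eager form:
#   dX = rstd * m - (rstd**3 / n_cols) * sum(m * X) * X
# which is exactly dX = rstd * m + rstd * correction_factor * X with
#   correction_factor = -(1 / n_cols) * rstd * rstd * sum(m * X).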
# -----------------------------------------------------------------------------
# Helper Functions
# -----------------------------------------------------------------------------
def get_optimal_block_size(n_cols, is_forward: bool):
"""
Calculate optimal block size for the forward or backward pass using compute_default_tiling_strategy.
Memory analysis for forward pass (per row):
- Load: X_block, W_block (2 blocks)
- Compute: X_block (fp32), Y_block (1-2 blocks)
- Total: conservative estimate 6 blocks of memory
Memory analysis for backward pass (per row):
- Load: dY_block, X_block, W_block (3 blocks)
- Compute: m, dX_block, dW_block (3 blocks)
- Store: dX_block, accumulated dW (2 blocks)
- Total: conservative estimate 8 blocks of memory
Args:
n_cols: Number of columns in the tensor
is_forward: Whether this is for forward pass
Returns:
Optimal block size
"""
if n_cols <= 2048:
return triton.next_power_of_2(n_cols)
memory_multiplier = 6.0 if is_forward else 8.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=4,
memory_multiplier=memory_multiplier,
shapes=((n_cols,),),
tiling_dims=(0,),
)
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return max(2048, block_size)
else:
return 2048
# -----------------------------------------------------------------------------
# Forward and Backward Functions
# -----------------------------------------------------------------------------
_str_to_casting_mode = {
"llama": _CASTING_MODE_LLAMA.value,
"gemma": _CASTING_MODE_GEMMA.value,
"none": _CASTING_MODE_NONE.value,
}
def rms_norm_forward(X, W, eps, offset, casting_mode):
if not isinstance(casting_mode, int):
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
casting_mode = _str_to_casting_mode[casting_mode]
else:
assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
n_rows, n_cols = X.shape
X_DTYPE = torch_dtype_to_triton(X.dtype)
# Get optimal block size for column processing
BLOCK_SIZE = get_optimal_block_size(n_cols, True)
BLOCK_SIZE_M = 2048 // BLOCK_SIZE
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
# RSTD is always fp32 for Llama/Gemma modes
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
if W is not None:
# Check constraints
assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension"
elementwise_affine = True
else:
elementwise_affine = False
# Grid size capped at twice the NPU core count
num_cores = get_npu_core_count()
grid_size = min(num_cores * 2, n_rows)
# Choose kernel based on n_cols
if n_cols <= 2048:
# Use no-tiling kernel for small n_cols
_rms_norm_forward_kernel_no_tiling[(grid_size,)](
Y,
Y.stride(0),
X,
X.stride(0),
W,
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
offset,
casting_mode,
elementwise_affine,
X_DTYPE,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE,
)
else:
# Use tiled kernel for large n_cols
_rms_norm_forward_kernel_tiled[(grid_size,)](
Y,
Y.stride(0),
X,
X.stride(0),
W,
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
offset,
casting_mode,
elementwise_affine,
X_DTYPE,
BLOCK_SIZE=BLOCK_SIZE,
)
return Y.view(*shape), X, RSTD, casting_mode
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, in_place):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
# Get NPU core count for grid size
num_cores = get_npu_core_count()
grid_size = min(num_cores * 2, n_rows)
# Get optimal block size for backward pass
BLOCK_SIZE = get_optimal_block_size(n_cols, False)
BLOCK_SIZE_M = 2048 // BLOCK_SIZE
if W is not None:
# fp32 for numerical stability
_dW = torch.zeros((grid_size, n_cols), dtype=torch.float32, device=W.device)
elementwise_affine = True
else:
_dW = None
elementwise_affine = False
if in_place:
dX = dY
else:
dX = torch.empty_like(dY)
# Choose kernel based on n_cols
if n_cols <= 2048:
# Use no-tiling kernel for small n_cols
_rms_norm_backward_kernel_no_tiling[(grid_size,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
W,
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0) if elementwise_affine else 0,
n_rows,
n_cols,
offset,
casting_mode,
elementwise_affine,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE,
)
else:
# Use tiled kernel for large n_cols
_rms_norm_backward_kernel_tiled[(grid_size,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
W,
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0) if elementwise_affine else 0,
n_rows,
n_cols,
offset,
casting_mode,
elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
)
dX = dX.view(*shape)
if elementwise_affine:
dW = _dW.sum(dim=0).to(W.dtype)
else:
dW = None
return dX, dW
class LigerRMSNormFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
"""
X: (B, T, H) or (BxT, H)
W: (H,)
"""
if isinstance(X, torch.distributed.tensor.DTensor):
# Input tensor is the output of a tensor parallel module and
# needs to be gathered to a local tensor to compute
# RMS norm on each TP worker.
# TODO: support CP.
X = X.full_tensor()
Y, X, RSTD, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
ctx.offset = offset
ctx.casting_mode = casting_mode
ctx.in_place = in_place
ctx.elementwise_affine = W is not None
if W is not None:
ctx.save_for_backward(X, W, RSTD)
else:
ctx.save_for_backward(X, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
"""
Y: (B, T, H) or (BxT, H)
"""
if ctx.elementwise_affine:
X, W, RSTD = ctx.saved_tensors
else:
X, RSTD = ctx.saved_tensors
W = None
if isinstance(dY, torch.distributed.tensor.DTensor):
# Gradients are the output of a tensor parallel module and
# need to be gathered to a local tensor for computing the RMS norm backward.
# TODO: support CP.
dY = dY.full_tensor()
dX, dW = rms_norm_backward(dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.in_place)
return dX, dW, None, None, None, None, None
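# Illustrative usage sketch (hyperparameters are hypothetical; assumes an Ascend NPU
# device). In "llama" casting mode the statistics are computed in fp32 and the normalized
# result is cast back to the input dtype before the weight multiplication, matching the
# kernels above:
#
#   x = torch.randn(2, 1024, 4096, device="npu", dtype=torch.bfloat16)
#   w = torch.ones(4096, device="npu", dtype=torch.bfloat16)
#   y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", True)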
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import get_npu_core_count
@triton.jit
def _triton_rope_npu(
q_ptr,
q_row_stride,
k_ptr,
k_row_stride,
cos,
cos_row_stride,
sin,
sin_row_stride,
sl,
total_rows: tl.constexpr,
cos_bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
BACKWARD_PASS: tl.constexpr = False,
):
program_id = tl.program_id(0)
num_programs = tl.num_programs(0)
rows_per_program = (total_rows + num_programs - 1) // num_programs
start_row = program_id * rows_per_program
actual_rows = tl.minimum(rows_per_program, total_rows - start_row)
for row_offset in tl.range(0, actual_rows):
pid = start_row + row_offset
row_idx = pid % sl
cos_ptr = cos + tl.where(cos_bs == 1, row_idx * cos_row_stride, pid * cos_row_stride)
sin_ptr = sin + tl.where(cos_bs == 1, row_idx * sin_row_stride, pid * sin_row_stride)
# Pre-compute d_idx and cos/sin values outside loops (they don't depend on heads)
d_idx = tl.arange(0, hd // 2)
d_mask = d_idx < (hd // 2) # Always True, but kept for clarity
cos_vals = tl.load(cos_ptr + d_idx, mask=d_mask, other=0)
sin_vals = tl.load(sin_ptr + d_idx, mask=d_mask, other=0)
# Process q heads in chunks to prevent UB overflow
for qh_block in range(0, n_qh, BLOCK_Q):
qh_idx = tl.arange(0, BLOCK_Q) + qh_block
qh_mask = qh_idx < n_qh
# block_mask: qh_mask broadcasted over d_idx dimension
block_mask = qh_mask[:, None]
offsets = qh_idx[:, None] * hd + d_idx[None, :]
q_base = q_ptr + pid * q_row_stride
q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
if not BACKWARD_PASS:
new_left = q_left * cos_vals - q_right * sin_vals
new_right = q_right * cos_vals + q_left * sin_vals
else:
new_left = q_left * cos_vals + q_right * sin_vals
new_right = q_right * cos_vals - q_left * sin_vals
tl.store(q_base + offsets, new_left, mask=block_mask)
tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
# Process k heads in chunks to prevent UB overflow
for kh_block in range(0, n_kh, BLOCK_K):
kh_idx = tl.arange(0, BLOCK_K) + kh_block
kh_mask = kh_idx < n_kh
# block_mask: kh_mask broadcasted over d_idx dimension
block_mask = kh_mask[:, None]
offsets = kh_idx[:, None] * hd + d_idx[None, :]
k_base = k_ptr + pid * k_row_stride
k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
if not BACKWARD_PASS:
new_left = k_left * cos_vals - k_right * sin_vals
new_right = k_right * cos_vals + k_left * sin_vals
else:
new_left = k_left * cos_vals + k_right * sin_vals
new_right = k_right * cos_vals - k_left * sin_vals
tl.store(k_base + offsets, new_left, mask=block_mask)
tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
def get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size):
# Compute tiling strategy based on UB capacity
# ROPE forward tiling strategy (based on optimized ROPE kernel):
# - cos_vals and sin_vals are loaded once outside loops (shared): pad_hd // 2 elements each
# - In q heads loop (peak memory):
# * q_left: BLOCK_Q * (pad_hd // 2) elements
# * q_right: BLOCK_Q * (pad_hd // 2) elements
# * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
# * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
# * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
# - In k heads loop (peak memory):
# * k_left: BLOCK_K * (pad_hd // 2) elements
# * k_right: BLOCK_K * (pad_hd // 2) elements
# * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
# * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
# * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
# - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
# - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
# - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
# - Simplified: (2 * BLOCK_SIZE + 1) * pad_hd * dtype_size * 8 bits
# - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
# - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
# - tiling_dims: (0, 0) means first dimension of each shape can be tiled
# - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.90,
dtype_size=dtype_size,
memory_multiplier=3.0,
shapes=shapes,
tiling_dims=(0, 0),
)
if tile_shapes is not None and len(tile_shapes) == len(shapes):
# Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
q_tile_shape, k_tile_shape = tile_shapes
BLOCK_Q, _ = q_tile_shape
BLOCK_K, _ = k_tile_shape
else:
# Fallback to conservative defaults
BLOCK_Q = 2048
BLOCK_K = 2048
return BLOCK_Q, BLOCK_K
def rope_forward(q, k, cos, sin):
# transpose it back to the physical shape because Triton looks at the physical storage
# note: q and k are not contiguous before the transformation and become contiguous after the transpose
q = q.transpose(1, 2)
k = k.transpose(1, 2)
batch_size, seq_len, n_q_head, head_dim = q.shape
n_kv_head = k.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
n_row = batch_size * seq_len
# ensure tensors passed into the kernel are contiguous. This is a no-op if they are already contiguous
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
cos_batch_size = cos.shape[0]
dtype_size = q.element_size()
BLOCK_Q, BLOCK_K = get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
num_cores = get_npu_core_count()
grid_size = min(num_cores, n_row)
_triton_rope_npu[(grid_size,)](
q,
q.stride(1),
k,
k.stride(1),
cos,
cos.stride(-2),
sin,
sin.stride(-2),
seq_len,
n_row,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
BLOCK_Q,
BLOCK_K,
BACKWARD_PASS=False,
)
return q.transpose(1, 2), k.transpose(1, 2), cos, sin
def rope_backward(dq, dk, cos, sin):
dq = dq.transpose(1, 2)
dk = dk.transpose(1, 2)
batch_size, seq_len, n_q_head, head_dim = dq.shape
cos_batch_size = cos.shape[0]
n_kv_head = dk.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
n_row = batch_size * seq_len
# ensure dq and dk are contiguous
dq = dq.contiguous()
dk = dk.contiguous()
dtype_size = dq.element_size()
BLOCK_Q, BLOCK_K = get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
num_cores = get_npu_core_count()
grid_size = min(num_cores, n_row)
_triton_rope_npu[(grid_size,)](
dq,
dq.stride(1),
dk,
dk.stride(1),
cos,
cos.stride(-2),
sin,
sin.stride(-2),
seq_len,
n_row,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
BLOCK_Q,
BLOCK_K,
BACKWARD_PASS=True,
)
return dq.transpose(1, 2), dk.transpose(1, 2)
class LigerRopeFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
q size: (bsz, n_q_head, seq_len, head_dim)
k size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""
q, k, cos, sin = rope_forward(q, k, cos, sin)
ctx.save_for_backward(cos, sin)
return q, k
@staticmethod
def backward(ctx, dq, dk):
"""
dq size: (bsz, n_q_head, seq_len, head_dim)
dk size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""
cos, sin = ctx.saved_tensors
dq, dk = rope_backward(dq, dk, cos, sin)
return dq, dk, None, None, None, None
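# Illustrative usage sketch (hypothetical shapes, assuming an Ascend NPU device;
# cos/sin follow the rotary-embedding layout documented above):
#
#   q = torch.randn(2, 32, 1024, 128, device="npu")  # (bsz, n_q_head, seq_len, head_dim)
#   k = torch.randn(2, 8, 1024, 128, device="npu")   # (bsz, n_kv_head, seq_len, head_dim)
#   cos = torch.randn(1, 1024, 128, device="npu")    # (1, seq_len, head_dim)
#   sin = torch.randn(1, 1024, 128, device="npu")
#   q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)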
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
@triton.jit
def _softmax_single_block_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
ROWS_PER_BLOCK: tl.constexpr,
):
"""
Single-block softmax forward kernel for small column sizes.
Processes entire row in one block when n_cols <= BLOCK_SIZE.
Uses 2D tensor to process multiple rows simultaneously for better UB utilization.
Args:
Y_ptr: Output tensor pointer
Y_row_stride: Stride for output rows
X_ptr: Input tensor pointer
X_row_stride: Stride for input rows
n_rows: Number of rows to process
n_cols: Number of columns per row
BLOCK_SIZE: Block size for column processing
ROWS_PER_BLOCK: Number of rows to process simultaneously
"""
row_block_start = tl.program_id(0) * ROWS_PER_BLOCK
row_block_step = tl.num_programs(0) * ROWS_PER_BLOCK
row_offsets = tl.arange(0, ROWS_PER_BLOCK)
col_offsets = tl.arange(0, BLOCK_SIZE)
for row_block_idx in tl.range(row_block_start, n_rows, row_block_step):
row_idx = row_block_idx + row_offsets
row_mask = row_idx < n_rows
col_mask = col_offsets < n_cols
# 2D mask: [ROWS_PER_BLOCK, BLOCK_SIZE]
mask = row_mask[:, None] & col_mask[None, :]
# Load 2D block: [ROWS_PER_BLOCK, BLOCK_SIZE]
offsets = row_idx[:, None] * X_row_stride + col_offsets[None, :]
x = tl.load(X_ptr + offsets, mask=mask, other=float("-inf"))
# Compute softmax per row (axis=1)
m = tl.max(x, axis=1)
e = tl.exp(x - m[:, None])
d = tl.sum(e, axis=1)
y = e / d[:, None]
# Store 2D block
offsets = row_idx[:, None] * Y_row_stride + col_offsets[None, :]
tl.store(Y_ptr + offsets, y, mask=mask)
@triton.jit
def _softmax_multi_block_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Multi-block softmax forward kernel using two-pass algorithm.
First pass computes max and sum for numerical stability.
Second pass normalizes and writes output.
Args:
Y_ptr: Output tensor pointer
Y_row_stride: Stride for output rows
X_ptr: Input tensor pointer
X_row_stride: Stride for input rows
n_rows: Number of rows to process
n_cols: Number of columns per row
BLOCK_SIZE: Block size for column processing
"""
row_start = tl.program_id(0)
num_prog = tl.num_programs(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
for row_idx in tl.range(row_start, n_rows, num_prog):
row_start_ptr = X_ptr + row_idx * X_row_stride
m = tl.full((), float("-inf"), tl.float32)
d = tl.zeros((), tl.float32)
for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + col_offsets
mask = idx < n_cols
xblk = tl.load(
row_start_ptr + idx, mask=mask, other=float("-inf"), eviction_policy="evict_first", cache_modifier=".ca"
)
blk_max = tl.max(xblk, axis=0)
new_m = tl.maximum(m, blk_max)
d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
m = new_m
for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + col_offsets
mask = idx < n_cols
xblk = tl.load(
row_start_ptr + idx, mask=mask, other=float("-inf"), eviction_policy="evict_first", cache_modifier=".ca"
)
yblk = tl.exp(xblk - m) / d
tl.store(Y_ptr + row_idx * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
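# Minimal eager sketch of the streaming max/sum update used in the first pass above
# (plain Python/PyTorch, BLOCK is an illustrative block size, x_row is one input row):
#
#   m, d = float("-inf"), 0.0
#   for blk in x_row.split(BLOCK):
#       new_m = max(m, blk.max().item())
#       d = d * math.exp(m - new_m) + torch.exp(blk - new_m).sum().item()
#       m = new_m
#   y_row = torch.exp(x_row - m) / d  # == softmax(x_row)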
@triton.jit
def _softmax_single_block_backward_kernel(
dy_ptr,
dy_stride,
y_ptr,
y_stride,
dx_ptr,
dx_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
ROWS_PER_BLOCK: tl.constexpr,
):
"""
Single-block softmax backward kernel for small column sizes.
Computes gradient: dx = y * (dy - sum(dy * y))
Uses 2D tensor to process multiple rows simultaneously for better UB utilization.
Args:
dy_ptr: Gradient output pointer
dy_stride: Stride for gradient output rows
y_ptr: Forward output pointer
y_stride: Stride for forward output rows
dx_ptr: Gradient input pointer
dx_stride: Stride for gradient input rows
n_rows: Number of rows to process
n_cols: Number of columns per row
BLOCK_SIZE: Block size for column processing
ROWS_PER_BLOCK: Number of rows to process simultaneously
"""
row_block_start = tl.program_id(0) * ROWS_PER_BLOCK
row_block_step = tl.num_programs(0) * ROWS_PER_BLOCK
row_offsets = tl.arange(0, ROWS_PER_BLOCK)
col_offsets = tl.arange(0, BLOCK_SIZE)
for row_block_idx in tl.range(row_block_start, n_rows, row_block_step):
row_idx = row_block_idx + row_offsets
row_mask = row_idx < n_rows
col_mask = col_offsets < n_cols
# 2D mask: [ROWS_PER_BLOCK, BLOCK_SIZE]
mask = row_mask[:, None] & col_mask[None, :]
# Load 2D blocks: [ROWS_PER_BLOCK, BLOCK_SIZE]
dy_offsets = row_idx[:, None] * dy_stride + col_offsets[None, :]
y_offsets = row_idx[:, None] * y_stride + col_offsets[None, :]
dy = tl.load(dy_ptr + dy_offsets, mask=mask, other=0.0)
y = tl.load(y_ptr + y_offsets, mask=mask, other=0.0)
# Compute dot product per row (axis=1)
dot = tl.sum(dy * y, axis=1)
dx = y * (dy - dot[:, None])
# Store 2D block
dx_offsets = row_idx[:, None] * dx_stride + col_offsets[None, :]
tl.store(dx_ptr + dx_offsets, dx, mask=mask)
@triton.jit
def _softmax_multi_block_backward_kernel(
dy_ptr,
dy_stride,
y_ptr,
y_stride,
dx_ptr,
dx_stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Multi-block softmax backward kernel using two-pass algorithm.
Computes gradient: dx = y * (dy - sum(dy * y))
Args:
dy_ptr: Gradient output pointer
dy_stride: Stride for gradient output rows
y_ptr: Forward output pointer
y_stride: Stride for forward output rows
dx_ptr: Gradient input pointer
dx_stride: Stride for gradient input rows
n_rows: Number of rows to process
n_cols: Number of columns per row
BLOCK_SIZE: Block size for column processing
"""
row_start = tl.program_id(0)
num_prog = tl.num_programs(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
for row_idx in tl.range(row_start, n_rows, num_prog):
dy_start_ptr = dy_ptr + row_idx * dy_stride
y_start_ptr = y_ptr + row_idx * y_stride
acc = 0.0
for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + col_offsets
mask = idx < n_cols
dy_blk = tl.load(dy_start_ptr + idx, mask=mask, other=0.0, eviction_policy="evict_first")
y_blk = tl.load(
y_start_ptr + idx, mask=mask, other=0.0, eviction_policy="evict_first", cache_modifier=".ca"
)
acc += tl.sum(dy_blk * y_blk, axis=0)
for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + col_offsets
mask = idx < n_cols
dy_blk = tl.load(dy_start_ptr + idx, mask=mask, other=0.0)
y_blk = tl.load(y_start_ptr + idx, mask=mask, other=0.0, cache_modifier=".ca")
dx_blk = y_blk * (dy_blk - acc)
tl.store(dx_ptr + row_idx * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
def _softmax_forward(x):
*batch, n_cols = x.shape
x2d = x.contiguous().view(-1, n_cols)
n_rows = x2d.shape[0]
MAX_FUSED_BLOCK_SIZE = 8192
BLOCK_SIZE = triton.next_power_of_2(n_cols)
BLOCK_SIZE = min(BLOCK_SIZE, MAX_FUSED_BLOCK_SIZE)
y2d = torch.empty_like(x2d)
num_cores = get_npu_core_count()
if n_cols <= BLOCK_SIZE:
# Calculate optimal ROWS_PER_BLOCK to utilize UB efficiently
# Target: ROWS_PER_BLOCK * BLOCK_SIZE <= MAX_FUSED_BLOCK_SIZE
ROWS_PER_BLOCK = min(MAX_FUSED_BLOCK_SIZE // BLOCK_SIZE, 32)
ROWS_PER_BLOCK = triton.next_power_of_2(ROWS_PER_BLOCK)
# Calculate number of programs needed
num_row_blocks = (n_rows + ROWS_PER_BLOCK - 1) // ROWS_PER_BLOCK
num_programs = min(num_cores, num_row_blocks)
_softmax_single_block_forward_kernel[(num_programs,)](
y2d, y2d.stride(0), x2d, x2d.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, ROWS_PER_BLOCK=ROWS_PER_BLOCK
)
multi_block_launch = False
else:
num_programs = min(num_cores, n_rows)
ROWS_PER_BLOCK = 1 # Not used in multi-block
_softmax_multi_block_forward_kernel[(num_programs,)](
y2d, y2d.stride(0), x2d, x2d.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE
)
multi_block_launch = True
return y2d.view(*batch, n_cols), BLOCK_SIZE, ROWS_PER_BLOCK, multi_block_launch
def _softmax_backward(
dy: torch.Tensor,
y: torch.Tensor,
BLOCK_SIZE: int,
ROWS_PER_BLOCK: int,
multi_block_launch: bool,
) -> torch.Tensor:
*batch, n_cols = dy.shape
dy2d = dy.contiguous().view(-1, n_cols)
y2d = y.contiguous().view(-1, n_cols)
n_rows = dy2d.shape[0]
dx2d = torch.empty_like(dy2d)
num_cores = get_npu_core_count()
if not multi_block_launch and n_cols <= BLOCK_SIZE:
num_row_blocks = (n_rows + ROWS_PER_BLOCK - 1) // ROWS_PER_BLOCK
num_programs = min(num_cores, num_row_blocks)
_softmax_single_block_backward_kernel[(num_programs,)](
dy2d,
dy2d.stride(0),
y2d,
y2d.stride(0),
dx2d,
dx2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
ROWS_PER_BLOCK=ROWS_PER_BLOCK,
)
else:
num_programs = min(num_cores, n_rows)
_softmax_multi_block_backward_kernel[(num_programs,)](
dy2d,
dy2d.stride(0),
y2d,
y2d.stride(0),
dx2d,
dx2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
)
return dx2d.view(*batch, n_cols)
class LigerSoftmaxFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, input_: torch.Tensor):
y, BLOCK_SIZE, ROWS_PER_BLOCK, multi_block_launch = _softmax_forward(input_)
ctx.save_for_backward(y)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.ROWS_PER_BLOCK = ROWS_PER_BLOCK
ctx.multi_block_launch = multi_block_launch
return y
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output):
(y,) = ctx.saved_tensors
dx = _softmax_backward(
grad_output,
y,
ctx.BLOCK_SIZE,
ctx.ROWS_PER_BLOCK,
ctx.multi_block_launch,
)
return dx
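# Illustrative usage sketch (assumes an Ascend NPU device; softmax is applied along the
# last dimension, matching the kernels above):
#
#   x = torch.randn(8, 4096, device="npu", requires_grad=True)
#   y = LigerSoftmaxFunction.apply(x)
#   y.sum().backward()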
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
@triton.jit
def _sparsemax_forward_kernel(
x_ptr,
x_stride_row,
sorted_x_ptr,
sorted_x_stride_row,
o_ptr,
o_stride_row,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Sparsemax forward kernel for rows where n_cols <= BLOCK_SIZE.
Args:
x_ptr: pointer to input tensor [n_rows, n_cols], fp32.
x_stride_row: row stride of x.
sorted_x_ptr: pointer to x sorted descending along last dim, fp32.
sorted_x_stride_row: row stride of sorted_x.
o_ptr: pointer to output tensor [n_rows, n_cols].
o_stride_row: row stride of o.
n_rows: number of rows (constexpr).
n_cols: number of columns (constexpr).
BLOCK_SIZE: tile size >= n_cols (constexpr).
"""
pid_row = tl.program_id(0)
num_progs = tl.num_programs(0)
for row in tl.range(pid_row, n_rows, num_progs):
ptr_x_data_row = x_ptr + row * x_stride_row
ptr_sorted_x_data_row = sorted_x_ptr + row * sorted_x_stride_row
ptr_output_row = o_ptr + row * o_stride_row
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < n_cols
z_sorted_block = tl.load(
ptr_sorted_x_data_row + offs,
mask=mask,
other=-float("inf"),
cache_modifier=".cg",
).to(tl.float32)
z_valid = tl.where(mask, z_sorted_block, 0.0)
cssv = tl.cumsum(z_valid, 0)
r = (offs + 1).to(tl.float32)
t_vec = (cssv - 1.0) / r
support = (z_sorted_block > t_vec) & mask
k_int = tl.sum(support.to(tl.int32), 0)
k_clamped_int = tl.maximum(k_int, 1)
k = k_clamped_int.to(tl.float32)
s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)
tau = (s - 1.0) / k
x_block = tl.load(
ptr_x_data_row + offs,
mask=mask,
other=0.0,
cache_modifier=".cg",
).to(tl.float32)
y = tl.maximum(x_block - tau, 0.0)
tl.store(
ptr_output_row + offs,
y.to(ptr_output_row.dtype.element_ty),
mask=mask,
cache_modifier=".cs",
)
@triton.jit
def _sparsemax_forward_tiled_kernel(
x_ptr,
x_stride_row,
sorted_x_ptr,
sorted_x_stride_row,
o_ptr,
o_stride_row,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Sparsemax forward kernel for rows where n_cols > BLOCK_SIZE (tiled).
Args:
x_ptr: pointer to input tensor [n_rows, n_cols], fp32.
x_stride_row: row stride of x.
sorted_x_ptr: pointer to x sorted descending along last dim, fp32.
sorted_x_stride_row: row stride of sorted_x.
o_ptr: pointer to output tensor [n_rows, n_cols].
o_stride_row: row stride of o.
n_rows: number of rows (constexpr).
n_cols: number of columns (constexpr).
BLOCK_SIZE: tile size < n_cols (constexpr).
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
for row in tl.range(pid, n_rows, num_progs):
sorted_row_ptr = sorted_x_ptr + row * sorted_x_stride_row
x_row_ptr = x_ptr + row * x_stride_row
out_row_ptr = o_ptr + row * o_stride_row
offs = tl.arange(0, BLOCK_SIZE)
# ------------------------------------------------------------------
# Pass 1: find tau from sorted data
# Since data is sorted descending, support is a contiguous prefix,
# so k = sum(support) — no need for max(support_r), saves one reduction.
# ------------------------------------------------------------------
running_sum = tl.zeros((), tl.float32)
k = tl.zeros((), tl.int32)
sum_support = tl.zeros((), tl.float32)
for tile in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
idx = tile * BLOCK_SIZE + offs
mask = idx < n_cols
z = tl.load(sorted_row_ptr + idx, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32)
cssv = tl.cumsum(z, axis=0) + running_sum
r = (idx + 1).to(tl.float32)
t = (cssv - 1.0) / r
support = (z > t) & mask
k += tl.sum(support.to(tl.int32), axis=0)
sum_support += tl.sum(tl.where(support, z, 0.0), axis=0)
running_sum += tl.sum(z, axis=0)
tau = (sum_support - 1.0) / tl.maximum(k, 1).to(tl.float32)
# ------------------------------------------------------------------
# Pass 2: write output y = max(x - tau, 0)
# ------------------------------------------------------------------
for tile in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
idx = tile * BLOCK_SIZE + offs
mask = idx < n_cols
x = tl.load(x_row_ptr + idx, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32)
y = tl.maximum(x - tau, 0.0)
tl.store(out_row_ptr + idx, y.to(out_row_ptr.dtype.element_ty), mask=mask, cache_modifier=".cs")
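# Worked example (illustrative): for one row z = [1.5, 1.0, 0.1], already sorted
# descending, the thresholds t_j = (cumsum(z)_j - 1) / j are [0.5, 0.75, 0.533...],
# so the support is the first two entries (k = 2), tau = (1.5 + 1.0 - 1) / 2 = 0.75,
# and the output max(z - tau, 0) = [0.75, 0.25, 0.0] sums to 1 as expected.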
@triton.jit
def _sparsemax_backward_kernel(
o_ptr,
go_ptr,
gi_ptr,
stride,
n_rows: tl.constexpr,
n_cols: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Sparsemax backward kernel for rows where n_cols <= BLOCK_SIZE.
Args:
o_ptr: pointer to forward output [n_rows, n_cols], fp32.
go_ptr: pointer to upstream gradient [n_rows, n_cols].
gi_ptr: pointer to input gradient output [n_rows, n_cols].
stride: common row stride for o, go, gi.
n_rows: number of rows (constexpr).
n_cols: number of columns (constexpr).
BLOCK_SIZE: tile size >= n_cols (constexpr).
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
for row in tl.range(pid, n_rows, num_progs):
o_row = o_ptr + row * stride
go_row = go_ptr + row * stride
gi_row = gi_ptr + row * stride
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < n_cols
o_val = tl.load(o_row + offs, mask=mask, other=0.0).to(tl.float32)
go_val = tl.load(go_row + offs, mask=mask, other=0.0).to(tl.float32)
supp = (o_val > 0.0) & mask
go_sum = tl.sum(tl.where(supp, go_val, 0.0), axis=0)
supp_cnt = tl.sum(supp.to(tl.float32), axis=0)
gi_val = tl.where(
supp,
go_val - go_sum / tl.maximum(supp_cnt, 1.0),
0.0,
)
tl.store(gi_row + offs, gi_val.to(gi_row.dtype.element_ty), mask=mask)
@triton.jit
def _sparsemax_backward_tiled_kernel(
o_ptr, go_ptr, gi_ptr, stride, n_rows: tl.constexpr, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
"""Sparsemax backward kernel for rows where n_cols > BLOCK_SIZE (tiled).
Args:
o_ptr: pointer to forward output [n_rows, n_cols], fp32.
go_ptr: pointer to upstream gradient [n_rows, n_cols].
gi_ptr: pointer to input gradient output [n_rows, n_cols].
stride: common row stride for o, go, gi.
n_rows: number of rows (constexpr).
n_cols: number of columns (constexpr).
BLOCK_SIZE: tile size < n_cols (constexpr).
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
for row in tl.range(pid, n_rows, num_progs):
o_row = o_ptr + row * stride
go_row = go_ptr + row * stride
gi_row = gi_ptr + row * stride
offs = tl.arange(0, BLOCK_SIZE)
supp_cnt = tl.zeros((), tl.float32)
go_sum = tl.zeros((), tl.float32)
for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
offs_iter = i * BLOCK_SIZE + offs
mask_iter = offs_iter < n_cols
o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
supp = o_val > 0
go_sum += tl.sum(tl.where(supp, go_val, 0.0))
supp_cnt += tl.sum(supp.to(tl.float32))
for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
offs_iter = i * BLOCK_SIZE + offs
mask_iter = offs_iter < n_cols
o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
supp = o_val > 0
gi_val = tl.where(
supp,
go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
0.0,
)
tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".cs")
def sparsemax_forward(x, dim):
if dim < 0:
dim += x.dim()
x_sw = x.transpose(dim, -1).contiguous()
n_cols = x_sw.size(-1)
n_rows = x_sw.numel() // n_cols
x_flat = x_sw.view(n_rows, n_cols)
x_flat_fp32 = x_flat if x_flat.dtype == torch.float32 else x_flat.float()
x_sorted_flat = torch.sort(x_flat_fp32, dim=-1, descending=True).values
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=4,
memory_multiplier=12.0,
shapes=((n_cols,),),
tiling_dims=(0,),
)
if tile_shapes and len(tile_shapes) > 0:
BLOCK_SIZE = tile_shapes[0][0]
else:
BLOCK_SIZE = 2048
out_flat = torch.empty_like(x_flat_fp32)
grid = (min(n_rows, get_npu_core_count()),)
if n_cols <= BLOCK_SIZE:
# non-tiled kernel: single load covers whole row
_sparsemax_forward_kernel[grid](
x_flat_fp32,
x_flat_fp32.stride(0),
x_sorted_flat,
x_sorted_flat.stride(0),
out_flat,
out_flat.stride(0),
n_rows,
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
)
else:
# tiled kernel: compute tau and write output in one fused kernel
_sparsemax_forward_tiled_kernel[grid](
x_flat_fp32,
x_flat_fp32.stride(0),
x_sorted_flat,
x_sorted_flat.stride(0),
out_flat,
out_flat.stride(0),
n_rows,
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
)
y = out_flat.view(x_sw.shape).transpose(dim, -1)
return y, out_flat
def sparsemax_backward(
grad_out: torch.Tensor,
out_flat: torch.Tensor,
dim: int,
) -> torch.Tensor:
if dim < 0:
dim += grad_out.dim()
grad_sw = grad_out.transpose(dim, -1).contiguous()
n_cols = grad_sw.size(-1)
n_rows = grad_sw.numel() // n_cols
go_flat = grad_sw.view(n_rows, n_cols)
dx_flat = torch.empty_like(go_flat).contiguous()
grid = (min(n_rows, get_npu_core_count()),)
# use single-pass kernel when feasible
if n_cols <= 4096:
BLOCK_SIZE = triton.next_power_of_2(n_cols)
_sparsemax_backward_kernel[grid](
out_flat,
go_flat,
dx_flat,
out_flat.stride(0),
n_rows,
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
)
else:
# use tiling strategy for very large n_cols: ~8 live buffers at peak -> memory_multiplier=8.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=4,
memory_multiplier=8.0,
shapes=((n_cols,),),
tiling_dims=(0,),
)
if tile_shapes and len(tile_shapes) > 0:
BLOCK_SIZE = tile_shapes[0][0]
else:
BLOCK_SIZE = 2048
_sparsemax_backward_tiled_kernel[grid](
out_flat,
go_flat,
dx_flat,
out_flat.stride(0),
n_rows,
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
)
dx = dx_flat.view_as(grad_sw).transpose(dim, -1)
return dx
class LigerSparsemaxFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, x: torch.Tensor, dim: int):
y, out_flat = sparsemax_forward(x, dim)
ctx.save_for_backward(out_flat)
ctx.dim = dim
return y
@staticmethod
@ensure_contiguous
def backward(ctx, grad_out: torch.Tensor):
(out_flat,) = ctx.saved_tensors
dx = sparsemax_backward(grad_out, out_flat, ctx.dim)
return dx, None
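# Hedged usage sketch: how LigerSparsemaxFunction is expected to be called. The "npu"
# device string is an assumption about the target platform; use the appropriate device
# on other backends.
def _sparsemax_usage_example():
    x = torch.randn(4, 128, device="npu", requires_grad=True)
    y = LigerSparsemaxFunction.apply(x, -1)  # sparsemax along the last dimension
    y.sum().backward()                       # gradient flows through sparsemax_backward
    return y, x.grad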
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import get_npu_core_count
# -----------------------------------------------------------------------------
# Kernels (High-performance 1D Flatten Implementation)
# -----------------------------------------------------------------------------
@triton.jit
def _swiglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# Grid-Stride Loop
start_idx = pid * BLOCK_SIZE
stride = num_progs * BLOCK_SIZE
for idx in tl.range(start_idx, total_elements, stride):
offsets = idx + tl.arange(0, BLOCK_SIZE)
mask = offsets < total_elements
a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
res = (a_val * tl.sigmoid(a_val)) * b_val
tl.store(c_ptr + offsets, res, mask=mask)
@triton.jit
def _swiglu_backward_kernel_flat(dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
start_idx = pid * BLOCK_SIZE
stride = num_progs * BLOCK_SIZE
for idx in tl.range(start_idx, total_elements, stride):
offsets = idx + tl.arange(0, BLOCK_SIZE)
mask = offsets < total_elements
dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
b = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
sig_a = tl.sigmoid(a)
silu_a = a * sig_a
term1 = silu_a * (1.0 - sig_a) + sig_a
db = dc * silu_a
da = dc * b * term1
tl.store(da_ptr + offsets, da, mask=mask)
tl.store(db_ptr + offsets, db, mask=mask)
# -----------------------------------------------------------------------------
# Helper: Call compute_default_tiling_strategy
# -----------------------------------------------------------------------------
def get_optimal_block_size(total_elements, is_backward=False):
"""
Calculate optimal Block Size using compute_default_tiling_strategy
"""
# 1. Set Memory Multiplier
# Forward is lighter, Backward requires more memory for intermediate variables
# 8.0 and 12.0 are empirical values based on Atlas 800I A2 UB (192KB)
multiplier = 12.0 if is_backward else 8.0
# 2. Call calculation function
# Treat input as 1D (total_elements,), only tiling on dim 0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
)
# 3. Parse result
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return max(256, block_size)
else:
return 2048
def swiglu_forward(a, b):
if not a.is_contiguous():
a = a.contiguous()
if not b.is_contiguous():
b = b.contiguous()
total_elements = a.numel()
c = torch.empty_like(a)
block_size = get_optimal_block_size(total_elements, is_backward=False)
num_cores = get_npu_core_count()
grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
_swiglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size)
return c
def swiglu_backward(a, b, dc):
if not dc.is_contiguous():
dc = dc.contiguous()
if not a.is_contiguous():
a = a.contiguous()
if not b.is_contiguous():
b = b.contiguous()
total_elements = dc.numel()
grad_a = torch.empty_like(a)
grad_b = torch.empty_like(b)
block_size = get_optimal_block_size(total_elements, is_backward=True)
num_cores = get_npu_core_count()
grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
_swiglu_backward_kernel_flat[(grid_size,)](dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size)
return grad_a, grad_b
class LigerSiLUMulFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
c = swiglu_forward(a, b)
ctx.save_for_backward(a, b)
return c
@staticmethod
def backward(ctx, dc):
a, b = ctx.saved_tensors
grad_a, grad_b = swiglu_backward(a, b, dc)
return grad_a, grad_b
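# Hedged usage sketch: LigerSiLUMulFunction computes silu(a) * b, which can be sanity
# checked against the eager torch.nn.functional.silu(a) * b. Shapes and the "npu"
# device are illustrative assumptions.
def _swiglu_usage_example():
    a = torch.randn(8, 4096, device="npu", requires_grad=True)
    b = torch.randn(8, 4096, device="npu", requires_grad=True)
    c = LigerSiLUMulFunction.apply(a, b)
    ref = torch.nn.functional.silu(a) * b  # eager reference
    return torch.allclose(c, ref, atol=1e-5, rtol=1e-5)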
from typing import Literal
from typing import Optional
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
MAX_FUSED_SIZE = 65536 // 4
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
@triton.jit
def _tv_distance_kernel(
p_ptr,
p_stride,
q_ptr,
q_stride,
loss_ptr,
loss_stride,
grads_ptr,
grads_stride,
label_ptr,
ignore_index: tl.constexpr,
n_cols, # V
total_rows: tl.constexpr, # BT
BLOCK_SIZE: tl.constexpr,
HAS_LABEL: tl.constexpr,
reduction: tl.constexpr = "batchmean",
):
thread_id = tl.program_id(0)
num_threads = tl.num_programs(0)
for pid in tl.range(thread_id, total_rows, num_threads):
p_row_ptr = p_ptr + pid * p_stride
q_row_ptr = q_ptr + pid * q_stride
loss_row_ptr = loss_ptr + pid * loss_stride
grads_row_ptr = grads_ptr + pid * grads_stride
label_row_ptr = label_ptr + pid
base_offsets = tl.arange(0, BLOCK_SIZE)
should_skip = False
if HAS_LABEL:
label = tl.load(label_row_ptr)
if label == ignore_index:
should_skip = True
if should_skip:
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + base_offsets
mask = offsets < n_cols
tl.store(grads_row_ptr + offsets, 0.0, mask=mask)
if reduction == "none":
tl.store(loss_row_ptr + offsets, 0.0, mask=mask)
else:
loss_sum = 0.0
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + base_offsets
mask = offsets < n_cols
p = tl.load(p_row_ptr + offsets, mask=mask, other=0.0)
q = tl.load(q_row_ptr + offsets, mask=mask, other=0.0)
# TVD(P || Q) = 0.5 * |P - Q|
tv_loss = 0.5 * tl.abs(p - q)
grad_res = tl.where(p > q, 0.5, -0.5)
tl.store(grads_row_ptr + offsets, grad_res, mask=mask)
if reduction == "none":
tl.store(loss_row_ptr + offsets, tv_loss, mask=mask)
else:
loss_sum += tl.sum(tv_loss, axis=0)
if reduction != "none":
tl.store(loss_row_ptr, loss_sum)
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
BT, V = p.shape
# TVD forward tiling strategy
# - Live buffers per tile in the main loop (loss and grad computation):
#   * p: BLOCK_SIZE elements
#   * q: BLOCK_SIZE elements
#   * tv_loss: BLOCK_SIZE elements
#   * grad_res: BLOCK_SIZE elements
#   * plus the loss_sum accumulator when reduction != "none"
# - Allowing for other shared buffers and the HAS_LABEL path, a conservative peak
#   estimate is 5 * BLOCK_SIZE * dtype_size * 8 bits, hence memory_multiplier=5.0.
# - shapes: ((V,),)
# - tiling_dims: (0,) means the first dimension of the shape can be tiled
# - Returns: ((block_size,),)
shapes = ((V,),)
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.80,
# Intermediate values in the TVD calculation are promoted to float32, so use the float32 size directly.
dtype_size=4,
memory_multiplier=5.0,
shapes=shapes,
tiling_dims=(0,),
)
if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
# Strategy returns ((block_size,),)
BLOCK_SIZE = tile_shapes[0][0]
else:
# Fallback to desired block size if no best practice found (no tiling needed)
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
num_cores = get_npu_core_count()
grid = (min(num_cores, BT),)
out_size = (BT, V) if reduction == "none" else (BT,)
# Accumulating the loss and gradients in BF16 on NPU introduces precision errors, so accumulate in float32.
output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
grads = torch.empty_like(p, dtype=torch.float32)
n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
_tv_distance_kernel[grid](
p,
p.stride(0),
q,
q.stride(0),
output_tensor,
output_tensor.stride(0),
grads,
grads.stride(0),
shift_labels if has_label else torch.empty(1, device=p.device),
ignore_index,
V,
BT,
BLOCK_SIZE=BLOCK_SIZE,
HAS_LABEL=has_label,
reduction=reduction,
)
if reduction == "batchmean":
return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
elif reduction == "sum":
return output_tensor.sum(dim=0), grads
elif reduction == "mean":
return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
else:
return output_tensor, grads
def tvd_backward_triton(grad_output, grads):
# If this is the last layer, grad_output is 1.0, so the multiplication can be skipped.
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return grads
return grads * grad_output
class LigerTVDLossFunction(torch.autograd.Function):
"""
Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
"""
@staticmethod
@ensure_contiguous
def forward(
ctx,
p: torch.Tensor,
q: torch.Tensor,
shift_labels: Optional[torch.Tensor] = None,
reduction: REDUCTION_LITERAL = "batchmean",
ignore_index: int = -100,
) -> torch.Tensor:
"""A forward pass for the Total Variation Distance Loss.
Args:
ctx: Torch autograd context
p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
Returns:
torch.Tensor: The computed Total Variation Distance Loss.
"""
has_label = False
if shift_labels is not None:
assert shift_labels.shape == (p.shape[0],), (
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
)
shift_labels = shift_labels.contiguous()
has_label = True
loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
ctx.save_for_backward(grads)
return loss
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
"""A backward pass for the Total Variation Distance Loss.
Args:
ctx: Torch autograd context
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
Returns:
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
"""
(grads,) = ctx.saved_tensors
grads = tvd_backward_triton(grad_output, grads)
return grads, None, None, None, None
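# Hedged usage sketch: calling the TVD loss on two row-wise probability distributions
# and comparing against the eager formula 0.5 * |P - Q| with "batchmean" reduction.
# The softmax inputs, shapes, and "npu" device are illustrative assumptions.
def _tvd_usage_example():
    p = torch.softmax(torch.randn(16, 32000, device="npu"), dim=-1)
    q = torch.softmax(torch.randn(16, 32000, device="npu"), dim=-1)
    loss = LigerTVDLossFunction.apply(p, q, None, "batchmean", -100)
    ref = 0.5 * (p - q).abs().sum(dim=-1).mean()  # eager TVD, batchmean reduction
    return loss, ref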
"""
Unified Buffer (UB) Manager for Ascend NPU.
This module provides UB capacity detection and tiling strategy computation
for running Triton kernels on Ascend NPU. It automatically calculates
optimal block sizes based on UB capacity constraints to prevent UB overflow.
"""
import os
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import triton
from liger_kernel.utils import is_npu_available
def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
"""
Normalize tiling dimension specification to a set of dimension indices.
Args:
tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions).
Returns:
Set of dimension indices that can be tiled.
"""
if isinstance(tiling_dim, int):
return {tiling_dim}
elif isinstance(tiling_dim, tuple):
return set(tiling_dim)
else:
return set()
def _default_strategy(
ub_capacity_bits: int,
safety_margin: float,
dtype_size: int,
memory_multiplier: float,
shapes: Tuple[Tuple[int, ...], ...],
tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
"""
Default tiling strategy: calculate maximum safe block size based on UB capacity.
This is a unified strategy function that works for all kernels by abstracting
the memory calculation as: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
Args:
ub_capacity_bits: UB capacity in bits
safety_margin: Safety margin as a float (e.g., 0.80 for 80%)
dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32)
memory_multiplier: Memory multiplier for estimating peak memory usage
shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
- For ROPE: ((n_q_head, hd), (n_kv_head, hd))
- For GEGLU: ((n_cols,),)
tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
Each element can be:
- int: single dimension index (e.g., 0 for first dimension)
- tuple of ints: multiple dimensions that can be tiled together
- For ROPE: (0, 0) means first dimension of each shape can be tiled
- For GEGLU: (0,) means first dimension of the shape can be tiled
Length must match len(shapes).
Returns:
Tuple of maximum safe block sizes, one for each shape.
Each element is a power of 2.
Note:
For each shape, fixed dimensions (non-tiling) are multiplied together to get unit_param.
The final block size is computed in compute_default_tiling_strategy by taking
min(desired_block_size, max_safe_block_size) where desired_block_size = triton.next_power_of_2(original_dim).
"""
if not shapes or not tiling_dims:
return ()
# Calculate max_safe_block_size for each tiling dimension
max_safe_sizes = []
for shape, tiling_dim in zip(shapes, tiling_dims):
# Normalize tiling_dim to a set of dimension indices
tiling_dim_set = _normalize_tiling_dims(tiling_dim)
# Validate tiling dimensions are within shape bounds
if not tiling_dim_set:
raise ValueError(
f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
)
if any(dim_idx < 0 or dim_idx >= len(shape) for dim_idx in tiling_dim_set):
raise ValueError(
f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
f"All dimension indices must be in range [0, {len(shape)})."
)
# Calculate unit_param: product of fixed (non-tiling) dimensions
unit_param = 1.0
for dim_idx, dim_size in enumerate(shape):
if dim_idx not in tiling_dim_set:
if dim_size <= 0:
# Invalid dimension size, use conservative default
unit_param = 1.0
break
unit_param *= float(dim_size)
# Ensure unit_param is at least 1.0
if unit_param <= 0:
unit_param = 1.0
# Calculate maximum safe block size based on UB capacity
# Memory: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
SAFE_UB_CAPACITY_BITS = int(ub_capacity_bits * safety_margin)
# Solve: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 <= SAFE_UB_CAPACITY_BITS
# BLOCK_SIZE <= SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)
max_block_size = int(SAFE_UB_CAPACITY_BITS // (memory_multiplier * unit_param * dtype_size * 8))
max_block_size = max(1, max_block_size)
# Find largest power of 2 <= max_block_size
# Use triton.next_power_of_2(max_block_size + 1) // 2 to get the largest power of 2 <= max_block_size
safe_block_size = triton.next_power_of_2(max_block_size + 1) // 2
max_safe_sizes.append(safe_block_size)
return tuple(max_safe_sizes)
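# Hedged worked example of the formula above (all numbers are illustrative): with a
# 192 KiB UB (1_572_864 bits), safety_margin=0.9, dtype_size=4 (fp32),
# memory_multiplier=8.0 and a 1-D shape where unit_param == 1:
#   max_block_size = int(1_572_864 * 0.9 // (8.0 * 1 * 4 * 8)) = 5529
#   safe_block_size = largest power of 2 <= 5529 = 4096
def _default_strategy_example() -> int:
    ub_capacity_bits = 192 * 1024 * 8  # assumed Atlas 800I A2-style UB size
    max_block_size = int(ub_capacity_bits * 0.9 // (8.0 * 1.0 * 4 * 8))
    return triton.next_power_of_2(max_block_size + 1) // 2  # -> 4096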
class UBManager:
"""
Unified Buffer Manager for Ascend NPU.
Provides UB capacity detection and management for Ascend NPU devices.
The UB capacity is used by tiling strategy functions to calculate optimal block sizes.
"""
def __init__(self, ub_capacity_bits: Optional[int] = None):
"""
Initialize UB Manager.
Args:
ub_capacity_bits: UB capacity in bits. If None, will be detected automatically.
"""
self._npu_model = self._detect_npu_model()
self._ub_capacity_bits = ub_capacity_bits or self._detect_ub_capacity()
@property
def ub_capacity_bits(self) -> int:
"""Get UB capacity in bits."""
return self._ub_capacity_bits
@property
def ub_capacity_bytes(self) -> int:
"""Get UB capacity in bytes."""
return self._ub_capacity_bits // 8
@property
def npu_model(self) -> str:
"""Get detected NPU model name."""
return self._npu_model
def _detect_npu_model(self) -> str:
"""Detect NPU model from device properties."""
if not is_npu_available():
return "unknown"
try:
dev_props = torch.npu.get_device_properties(0)
# Try to get model name from device properties
return dev_props.name
except Exception:
pass
return "default"
def _detect_ub_capacity(self) -> int:
"""
Detect UB capacity from environment variable or get_soc_spec.
Returns:
UB capacity in bits.
Raises:
RuntimeError: If UB capacity cannot be detected and no environment variable is set.
"""
# Check environment variable first (in bits)
env_capacity = os.getenv("ASCEND_UB_CAPACITY_BITS")
if env_capacity is not None:
try:
capacity_bits = int(env_capacity)
if capacity_bits > 0:
return capacity_bits
except ValueError:
pass
# Try to get from get_soc_spec (returns bytes, convert to bits)
if is_npu_available():
try:
from tbe.common.platform import get_soc_spec
from tbe.common.platform import set_current_compile_soc_info
# Set current SOC info for get_soc_spec to work correctly
device = getattr(torch, "npu")
soc_info = device.get_device_name(device.current_device())
set_current_compile_soc_info(soc_info)
# Query UB size (get_soc_spec returns size in bytes)
ub_size_bytes = get_soc_spec("UB_SIZE")
if ub_size_bytes is None or ub_size_bytes <= 0:
raise ValueError(f"Invalid UB_SIZE from get_soc_spec: {ub_size_bytes}")
# Convert bytes to bits
ub_capacity_bits = ub_size_bytes * 8
return ub_capacity_bits
except ImportError:
raise RuntimeError(
"Cannot import tbe.common.platform.get_soc_spec. "
"Please ensure CANN environment variables are sourced "
"(e.g., source /usr/local/Ascend/ascend-toolkit/set_env.sh)"
)
except Exception as e:
raise RuntimeError(
f"Failed to detect UB capacity from get_soc_spec: {e}. "
"Please set ASCEND_UB_CAPACITY_BITS environment variable as fallback."
)
# If NPU is not available, raise error
raise RuntimeError(
"NPU is not available and UB capacity cannot be detected. "
"Please set ASCEND_UB_CAPACITY_BITS environment variable."
)
# Global singleton instance
_ub_manager: Optional[UBManager] = None
def get_ub_manager() -> UBManager:
"""Get global UB manager instance."""
global _ub_manager
if _ub_manager is None:
_ub_manager = UBManager()
return _ub_manager
def compute_default_tiling_strategy(
safety_margin: float = 0.80,
dtype_size: Optional[int] = None,
memory_multiplier: Optional[float] = None,
shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
) -> Optional[Tuple[Tuple[int, ...], ...]]:
"""
Compute tiling strategy using the default strategy function.
This function directly calls the default strategy and computes the final
tiling result. All kernels use the same unified strategy function, so
there's no need for kernel_name-based lookup.
Args:
safety_margin: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32).
Must be provided. If None or <= 0, defaults to 4 (float32).
memory_multiplier: Memory multiplier for estimating peak memory usage.
- For GEGLU: typically 10.0 for backward, 4.0 for forward
- For ROPE: typically 3.0
If None, defaults to 10.0 (conservative estimate).
shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
- For ROPE: ((n_q_head, hd), (n_kv_head, hd))
- For GEGLU: ((n_cols,),)
Can pass original shapes (will handle padding internally) or padded shapes.
tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
Each element can be:
- int: single dimension index (e.g., 0 for first dimension)
- tuple of ints: multiple dimensions that can be tiled together
- For ROPE: (0, 0) means first dimension of each shape can be tiled
- For GEGLU: (0,) means first dimension of the shape can be tiled
Length must match len(shapes). Cannot be empty.
Returns:
Tuple of tiled shapes with same structure as input shapes.
Tiling dimensions are replaced with computed block sizes (power of 2),
while non-tiling dimensions are padded to next power of 2.
- For ROPE: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
- For GEGLU: ((block_size,),)
Returns None if shapes or tiling_dims is None or empty.
Examples:
>>> # ROPE forward
>>> strategy = compute_default_tiling_strategy(
... safety_margin=0.90,
... dtype_size=4,
... memory_multiplier=3.0,
... shapes=((32, 128), (32, 128)),
... tiling_dims=(0, 0)
... )
>>> # Returns: ((block_size_q, 128), (block_size_kv, 128))
>>> # GEGLU forward
>>> strategy = compute_default_tiling_strategy(
... safety_margin=0.80,
... dtype_size=2,
... memory_multiplier=7.0,
... shapes=((4096,),),
... tiling_dims=(0,)
... )
>>> # Returns: ((block_size,),)
"""
ub_manager = get_ub_manager()
if shapes is None or not shapes or tiling_dims is None or not tiling_dims:
return None
if len(shapes) != len(tiling_dims):
return None
if dtype_size is None or dtype_size <= 0:
dtype_size = 4 # Default to float32
if memory_multiplier is None or memory_multiplier <= 0:
memory_multiplier = 10.0 # Default conservative estimate
# Call strategy to get max_safe_block_size for each shape
max_supported = _default_strategy(
ub_manager.ub_capacity_bits,
safety_margin,
dtype_size,
memory_multiplier,
shapes,
tiling_dims,
)
if not max_supported or len(max_supported) != len(shapes):
return None
# Build result: same structure as shapes, with tiling dims replaced by computed block sizes
result = []
for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
result_shape = list(shape)
# Normalize tiling_dim to a set of dimension indices
tiling_dim_set = _normalize_tiling_dims(tiling_dim)
# Validate tiling dimensions are within shape bounds
if not tiling_dim_set:
raise ValueError(
f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
)
if any(dim_idx < 0 or dim_idx >= len(result_shape) for dim_idx in tiling_dim_set):
raise ValueError(
f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
f"All dimension indices must be in range [0, {len(result_shape)})."
)
# Replace tiling dimensions with computed block sizes
# For each tiling dimension, compute: min(desired, max_safe)
for dim_idx in tiling_dim_set:
original_dim = result_shape[dim_idx]
desired = triton.next_power_of_2(original_dim)
final_val = min(desired, max_safe)
final_val = max(1, final_val) # Ensure at least 1
result_shape[dim_idx] = final_val
# Pad non-tiling dimensions to next power of 2
for dim_idx, dim_size in enumerate(result_shape):
if dim_idx not in tiling_dim_set:
result_shape[dim_idx] = triton.next_power_of_2(dim_size)
result.append(tuple(result_shape))
return tuple(result)
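# Hedged usage sketch: forcing a specific UB capacity via the ASCEND_UB_CAPACITY_BITS
# fallback documented above, which only takes effect if set before the singleton
# UBManager is first created. The capacity value and shape are illustrative.
def _tiling_strategy_usage_example():
    os.environ["ASCEND_UB_CAPACITY_BITS"] = str(192 * 1024 * 8)
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.9,
        dtype_size=4,
        memory_multiplier=8.0,
        shapes=((32000,),),
        tiling_dims=(0,),
    )
    return tile_shapes  # e.g. ((4096,),) under the assumptions above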
"""
Vendor registry for Liger-Kernel multi-backend support.
This module defines VendorInfo and the registry for vendor registration.
Each vendor registers itself by calling register_vendor() in its __init__.py.
"""
from dataclasses import dataclass
from typing import Optional
# Dynamically get backends package path to avoid hardcoding
_BACKENDS_PACKAGE = __name__.rsplit(".", 1)[0] # "liger_kernel.ops.backends"
@dataclass
class VendorInfo:
"""
Information about a chip vendor and its supported device.
Attributes:
vendor: Vendor name (e.g., "ascend", "intel", "nvidia")
device: Device type this vendor supports (e.g., "npu", "xpu")
"""
vendor: str
device: str
@property
def module_path(self) -> str:
"""Auto-generated module path based on vendor name."""
return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops"
# Registry mapping device types to their vendor info
# Vendors register themselves via register_vendor()
VENDOR_REGISTRY: dict[str, VendorInfo] = {}
def register_vendor(vendor_info: VendorInfo) -> None:
"""
Register a vendor's info in the global registry.
This should be called in each vendor's __init__.py to register itself.
Args:
vendor_info: VendorInfo instance to register
"""
VENDOR_REGISTRY[vendor_info.device] = vendor_info
def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
"""
Get the VendorInfo for a given device type.
Args:
device: Device type (e.g., "npu", "xpu")
Returns:
VendorInfo if found, None otherwise
"""
return VENDOR_REGISTRY.get(device)
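# Hedged usage sketch of the registry API above, using the Ascend vendor/device names
# already mentioned in VendorInfo's docstring. A real vendor package would make the
# register_vendor() call from its own __init__.py rather than from a helper like this.
def _register_ascend_example() -> Optional[VendorInfo]:
    register_vendor(VendorInfo(vendor="ascend", device="npu"))
    info = get_vendor_for_device("npu")
    # info.module_path resolves to f"{_BACKENDS_PACKAGE}._ascend.ops"
    return info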
import operator
from typing import Optional
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import is_hip
from liger_kernel.utils import infer_device
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import tanh
else:
from triton.language.math import tanh
@triton.jit
def liger_cross_entropy_kernel(
X_ptr,
X_stride,
Y_ptr,
Y_stride,
weight_ptr,
loss_ptr,
z_loss_ptr,
loss_stride,
token_accuracy_ptr,
token_accuracy_stride,
predicted_tokens_ptr,
predicted_tokens_stride,
n_cols,
n_non_ignore,
sum_non_ignore_weight,
weight_sum,
ignore_index,
lse_square_scale: tl.constexpr,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
RETURN_TOKEN_ACCURACY: tl.constexpr,
RETURN_PREDICTED_TOKENS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
HAS_GRADIENTS: tl.constexpr,
):
"""
This kernel computes both cross entropy loss and the gradient of the input.
We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
Parameters:
X_ptr: Pointer to input tensor.
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
weight_ptr: Pointer to weight tensor.
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
token_accuracy_stride (int): The stride of the token accuracy tensor.
predicted_tokens_ptr: Pointer to tensor to store the per-token predicted class index (argmax). No operation if RETURN_PREDICTED_TOKENS is 0.
predicted_tokens_stride (int): The stride of the predicted tokens tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (float): The number of non-ignored elements in the batch.
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
weight_sum (float): The sum of weight tensor.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
RETURN_PREDICTED_TOKENS (int): The boolean value to decide whether to store per-token predicted indices to predicted_tokens_ptr or not. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
"""
# https://github.com/triton-lang/triton/issues/1058
# If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
program_id = tl.program_id(0).to(tl.int64)
# 1. Load Y_ptr first because if the target is ignore_index, we can return right away
Y_ptr += program_id * Y_stride
y = tl.load(Y_ptr)
# 2. locate the start index
X_ptr += program_id * X_stride
if y == ignore_index:
# set all X_ptr as 0
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
# For ignored tokens, set token accuracy to 0
if RETURN_TOKEN_ACCURACY:
token_accuracy_ptr += program_id * token_accuracy_stride
tl.store(token_accuracy_ptr, 0.0)
if RETURN_PREDICTED_TOKENS:
predicted_tokens_ptr += program_id * predicted_tokens_stride
tl.store(predicted_tokens_ptr, -1)
return
loss_ptr += program_id * loss_stride
if RETURN_Z_LOSS:
z_loss_ptr += program_id * loss_stride
if RETURN_TOKEN_ACCURACY:
token_accuracy_ptr += program_id * token_accuracy_stride
if RETURN_PREDICTED_TOKENS:
predicted_tokens_ptr += program_id * predicted_tokens_stride
if HAS_WEIGHT:
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
# 3. [Online softmax] first pass: find max + sum
m = float("-inf") # m is the max value. use the notation from the paper
d = 0.0 # d is the sum. use the notation from the paper
argmax_idx = 0 # Track the index of the maximum value for token accuracy / predicted tokens computation
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
if HAS_SOFTCAPPING:
ori_X_y = softcap * tanh(ori_X_y / softcap)
# Label smoothing is a general case of normal cross entropy
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
scaled_x_sum = 0.0
eps = label_smoothing / n_cols
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets,
mask=X_offsets < n_cols,
other=float("-inf"),
# Ensure float32 precision for softmax calculation
).cast(tl.float32)
if HAS_SOFTCAPPING:
X_block = softcap * tanh(X_block / softcap)
block_max = tl.max(X_block)
# Track argmax for accuracy / predicted tokens computation
if RETURN_TOKEN_ACCURACY or RETURN_PREDICTED_TOKENS:
# Find the index of the maximum value in this block
is_max_mask = X_block == block_max
# Mask out invalid indices with a value larger than n_cols
masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
# Get the first (smallest) index where max occurs
current_block_argmax_idx = tl.min(masked_offsets)
is_new_max = block_max > m
argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
if HAS_WEIGHT:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
else:
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new
# log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
# = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
# = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
lse = m + tl.log(d)
# 4. [Online Softmax] Second pass: compute gradients
# For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# For label smoothing:
# dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
# = dx_i - (1 - label_smoothing) / N
# With Z loss:
# dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
# dx_y = dx_i - (1 - label_smoothing) / N
# For 'sum' reduction, no normalization is applied:
# dx_y = softmax(x_y) - 1
# dx_i = softmax(x_i), for i ≠ y
if HAS_GRADIENTS:
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets,
mask=X_offsets < n_cols,
other=float("-inf"),
# Ensure float32 precision for softmax calculation
).cast(tl.float32)
if HAS_SOFTCAPPING:
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate
if not HAS_WEIGHT:
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# special handle dx_y
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / n_non_ignore
else:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
softmax_X = tl.exp(X_block - m) / d
# derivative of original_loss
dloss_ori = (1 - label_smoothing) * softmax_X
# specially handle dx_y
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
dloss_ori = dloss_ori * weight_y
# derivative of smooth_loss
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
# derivative of z-loss
dz_loss = 2 * lse_square_scale * lse * softmax_X
# reduction scale
if reduction == "mean":
dloss_ori = dloss_ori / sum_non_ignore_weight
dloss_smooth = dloss_smooth / sum_non_ignore_weight
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
dz_loss = dz_loss / n_non_ignore
# derivative of total_loss
X_block = dloss_ori + dloss_smooth + dz_loss
# chain rule softcapping
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
X_block = X_block * (1 - intermediate * intermediate)
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
tl.debug_barrier()
# 5. Calculate the loss
# loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
# = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
# = X_y - m - log d = X_y - lse
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
# So we can safely calculate log (softmax(X_y)) without overflow
loss = lse - ori_X_y
if HAS_WEIGHT:
loss = weight_y * loss
# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
# = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
# By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
# = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
if HAS_WEIGHT:
smooth_loss = scaled_x_sum + eps * lse * weight_sum
else:
smooth_loss = scaled_x_sum + label_smoothing * lse
loss = loss * (1 - label_smoothing) + smooth_loss
# An auxiliary loss, z_loss
# Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
z_loss = lse_square_scale * lse * lse
# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
if HAS_WEIGHT:
loss = loss / sum_non_ignore_weight
else:
loss = loss / n_non_ignore
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
z_loss = z_loss / n_non_ignore
loss += z_loss
tl.store(loss_ptr, loss)
if RETURN_Z_LOSS:
tl.store(z_loss_ptr, z_loss)
if RETURN_TOKEN_ACCURACY:
# Store 1.0 if prediction is correct, 0.0 otherwise
is_correct = tl.where(argmax_idx == y, 1.0, 0.0)
tl.store(token_accuracy_ptr, is_correct)
if RETURN_PREDICTED_TOKENS:
tl.store(predicted_tokens_ptr, argmax_idx)
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting the limit to 65536 as in the LayerNorm tutorial is faster because of less register spilling.
# The optimal maximum block size depends on your hardware, your kernel, and your dtype;
# the sizes below are the best we found by manually tuning on XPU and NPU.
if infer_device() == "xpu":
MAX_FUSED_SIZE = 4096
elif infer_device() == "npu":
MAX_FUSED_SIZE = 2048
else:
MAX_FUSED_SIZE = 65536 // 2
def cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap,
return_z_loss,
return_token_accuracy=False,
return_predicted_tokens=False,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_token_accuracy, bool), (
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
)
assert isinstance(return_predicted_tokens, bool), (
f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}"
)
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
# unreduced loss
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
token_accuracy_1d = (
torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
)
predicted_tokens_1d = (
torch.full((n_rows,), -1, dtype=torch.int64, device=_input.device) if return_predicted_tokens else None
)
target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
assert (target * target_mask).max() < _input.shape[-1], (
f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
)
assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
sum_non_ignore_weight = n_non_ignore
weight_sum = 0.0
if weight is not None:
assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
assert torch.is_floating_point(weight), (
f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
)
sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
weight_sum = weight.sum().item()
# ensure weight is contiguous
if weight.stride(-1) != 1:
weight = weight.contiguous()
# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
_input = _input.contiguous()
if target.stride(-1) != 1:
target = target.contiguous()
# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
liger_cross_entropy_kernel[(n_rows,)](
X_ptr=_input,
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight, # dummy if None
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
token_accuracy_ptr=token_accuracy_1d,
token_accuracy_stride=token_accuracy_1d.stride(-1)
if return_token_accuracy
else 0, # always 1 if accuracy is enabled
predicted_tokens_ptr=predicted_tokens_1d,
predicted_tokens_stride=predicted_tokens_1d.stride(-1)
if return_predicted_tokens
else 0, # always 1 if predicted tokens is enabled
n_cols=V,
n_non_ignore=n_non_ignore,
sum_non_ignore_weight=sum_non_ignore_weight,
ignore_index=ignore_index,
weight_sum=weight_sum,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
RETURN_TOKEN_ACCURACY=return_token_accuracy,
RETURN_PREDICTED_TOKENS=return_predicted_tokens,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
HAS_GRADIENTS=_input.requires_grad,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
num_warps=32 if not is_hip() else 16,
)
if reduction == "none":
loss = loss_1d
z_loss = z_loss_1d if return_z_loss else None
token_accuracy = token_accuracy_1d if return_token_accuracy else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
# For accuracy, we compute the mean across all non-ignored tokens
token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
return loss, z_loss, token_accuracy, predicted_tokens, _input
def cross_entropy_backward(_input, grad_output):
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
pass
# If reduction is 'none'
elif grad_output.ndim > 0:
_input = _input * grad_output.unsqueeze(dim=1)
# If reduction is ['mean', 'sum'], grad_output is just a scalar
# We use a Triton kernel instead of a PyTorch operation because modifying the input in-place
# for gradient storage, and running backward multiple times, causes anomalies with PyTorch but not with Triton.
else:
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
element_mul_kernel[(n_rows,)](
_input,
_input.stride(-2),
grad_output,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
return _input
class LigerCrossEntropyFunction(torch.autograd.Function):
"""
This class implements a custom autograd function for the Liger Cross Entropy loss.
It overrides the forward and backward methods of the torch.autograd.Function class.
"""
@staticmethod
def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.FloatTensor],
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_token_accuracy: bool = False,
return_predicted_tokens: bool = False,
):
"""
The forward pass of the Liger Cross Entropy loss.
Parameters:
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
ignore_index (int): The index to ignore in the target.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy, predicted_tokens) instead of (loss, None, None, None). Default: `False`
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
Returns:
tuple: A tuple with the computed losses, accuracy, and predicted tokens: (loss, z_loss, token_accuracy, predicted_tokens). z_loss, token_accuracy, and predicted_tokens are None if not requested.
"""
input_requires_grad = _input.requires_grad
loss, z_loss, token_accuracy, predicted_tokens, _input = cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap,
return_z_loss,
return_token_accuracy,
return_predicted_tokens,
)
# TODO: investigate
# If we don't detach the _input tensor, the memory will double.
# Not sure why, but it seems there is a moment when both the gradient and the value exist at different locations.
if input_requires_grad:
ctx.save_for_backward(_input.detach())
ctx.return_z_loss = return_z_loss
ctx.return_token_accuracy = return_token_accuracy
ctx.return_predicted_tokens = return_predicted_tokens
return loss, z_loss, token_accuracy, predicted_tokens
@staticmethod
def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4):
"""
The backward pass of the Liger Cross Entropy loss.
Parameters:
ctx : The context object with saved tensors.
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
grad_output4 (tensor): No use. Gradient for predicted_tokens (not used as predicted_tokens is only for metrics).
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
if ctx.return_z_loss:
del grad_output2 # z_loss is only for logging
if ctx.return_token_accuracy:
del grad_output3 # token_accuracy is only for metrics
if ctx.return_predicted_tokens:
del grad_output4 # predicted_tokens is only for metrics
(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
return (
_input,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
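# Hedged usage sketch: LigerCrossEntropyFunction on flattened (BT, V) logits, compared
# against the eager PyTorch loss. Shapes and the "npu" device are illustrative. Note the
# kernel reuses the logits storage for gradients, so the reference is computed first.
def _cross_entropy_usage_example():
    logits = torch.randn(32, 32000, device="npu", requires_grad=True)
    target = torch.randint(0, 32000, (32,), device="npu")
    ref = torch.nn.functional.cross_entropy(logits, target)  # eager reference, reduction="mean"
    loss, _, _, _ = LigerCrossEntropyFunction.apply(logits, target, None)
    loss.backward()
    return loss, ref, logits.grad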
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import infer_device
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import tanh
else:
from triton.language.math import tanh
# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
# for bn in [1024, 2048, 4096]
# for ns in [1,2,4]
# for nw in [4, 8, 16, 32]
# ],
# key=['N'])
@triton.jit
def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
mask = col < N
row_id = tl.cast(tl.program_id(1), tl.int64)
X += row_id * N
Y += row_id * N
alpha = tl.load(Alpha).to(tl.float32)
gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
tanh_x = tanh(alpha * x)
y = tanh_x * gamma
if HAVE_BETA:
beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
y += beta
tl.store(Y + col, y, mask=mask)
# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
# for bn in [1024, 2048, 4096]
# for ns in [1,2,4]
# for nw in [4, 8, 16]
# ],
# key=['N'])
@triton.jit
def _dyt_bwd_kernel(
DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
):
col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
mask = col < N
start_row_id = tl.cast(tl.program_id(1), tl.int64)
alpha = tl.load(Alpha).to(tl.float32)
da = 0.0
gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAVE_BETA:
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
for row_id in range(start_row_id, M, tl.num_programs(1)):
x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
tanh_x = tanh(alpha * x)
if HAVE_BETA:
db += dy
dg += dy * tanh_x
tmp = (1 - tanh_x * tanh_x) * dy * gamma
da += tl.sum(x * tmp, 0)
dx = alpha * tmp
tl.store(DX + row_id * N + col, dx, mask=mask)
tl.store(DG + start_row_id * N + col, dg, mask=mask)
if HAVE_BETA:
tl.store(DB + start_row_id * N + col, db, mask=mask)
tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
def liger_dyt_fwd(x, alpha, gamma, beta):
assert x.is_contiguous()
HAVE_BETA = True if beta is not None else False
input_shape = x.shape
x = x.view(-1, input_shape[-1])
M, N = x.shape
y = torch.empty_like(x)
if N >= 4096:
kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
else:
kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
_dyt_fwd_kernel[(grid)](
x,
y,
alpha,
gamma,
beta,
HAVE_BETA,
N,
**kwargs,
)
return y.view(input_shape)
def liger_dyt_bwd(dy, x, alpha, gamma, beta):
assert dy.is_contiguous()
input_shape = x.shape
x = x.view(-1, input_shape[-1])
M, N = x.shape
HAVE_BETA = True if beta is not None else False
device = infer_device()
if device == "cuda":
NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
elif device == "xpu":
NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
elif device == "npu":
    NUM_SMS = get_npu_core_count()
else:
    raise RuntimeError(f"liger_dyt_bwd: unsupported device '{device}'")
da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
dx = torch.empty_like(dy)
kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
_dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
if HAVE_BETA:
db = db.sum(0).to(x.dtype)
dg = dg.sum(0).to(gamma.dtype)
da = da.sum().to(x.dtype).unsqueeze(0)
return dx.view(input_shape), da, dg, db
class LigerDyTFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, x, alpha, gamma, beta):
y = liger_dyt_fwd(x, alpha, gamma, beta)
ctx.save_for_backward(x, alpha, gamma, beta)
return y
@staticmethod
@ensure_contiguous
def backward(ctx, dy):
x, alpha, gamma, beta = ctx.saved_tensors
dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
return dx, dalpha, dgamma, dbeta
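# Hedged reference sketch: DyT computes y = gamma * tanh(alpha * x) (+ beta), which is
# what the fused kernels above implement. Shapes, parameter values, and the "npu"
# device are illustrative assumptions.
def _dyt_reference_example():
    x = torch.randn(4, 64, 2048, device="npu", requires_grad=True)
    alpha = torch.full((1,), 0.5, device="npu", requires_grad=True)
    gamma = torch.ones(2048, device="npu", requires_grad=True)
    beta = torch.zeros(2048, device="npu", requires_grad=True)
    y = LigerDyTFunction.apply(x, alpha, gamma, beta)
    ref = torch.tanh(alpha * x) * gamma + beta  # eager reference
    return torch.allclose(y, ref, atol=1e-5, rtol=1e-5)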
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import ensure_contiguous
@triton.jit
def embedding_forward_kernel(
embeddings_ptr,
indices_ptr,
output_ptr,
n_elements,
embedding_dim: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < n_elements
indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
mask_n = offsets_n < embedding_dim
embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
embeddings = tl.load(
embeddings_ptr + embedding_offsets,
mask=mask_m[:, None] & mask_n[None, :],
other=0.0,
)
output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :])
@triton.jit
def embedding_backward_kernel(
grad_output_ptr,
grad_weight_ptr,
indices_ptr,
n_elements,
embedding_dim: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < n_elements
indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
mask_n = offsets_n < embedding_dim
grad_output = tl.load(
grad_output_ptr + offsets_m[:, None] * embedding_dim + offsets_n[None, :],
mask=mask_m[:, None] & mask_n[None, :],
other=0.0,
)
grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
tl.atomic_add(
grad_weight_ptr + grad_weight_offsets,
grad_output,
mask=mask_m[:, None] & mask_n[None, :],
)
class LigerEmbeddingFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
ori_shape = indices.shape
indices = indices.view(-1)
output = torch.empty(
indices.shape[0],
embeddings.shape[1],
device=indices.device,
dtype=embeddings.dtype,
)
n_elements = indices.numel()
embedding_dim = embeddings.shape[1]
BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
grid = (
triton.cdiv(n_elements, BLOCK_SIZE_M),
triton.cdiv(embedding_dim, BLOCK_SIZE_N),
)
embedding_forward_kernel[grid](
embeddings,
indices,
output,
n_elements,
embedding_dim=embedding_dim,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
ctx.save_for_backward(indices, embeddings)
return output.view(*ori_shape, -1)
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor):
indices, embedding_table = ctx.saved_tensors
grad_output = grad_output.contiguous().view(-1, embedding_table.shape[1])
grad_weight = torch.zeros_like(embedding_table)
n_elements = indices.numel()
embedding_dim = embedding_table.shape[1]
BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
grid = (
triton.cdiv(n_elements, BLOCK_SIZE_M),
triton.cdiv(embedding_dim, BLOCK_SIZE_N),
)
embedding_backward_kernel[grid](
grad_output,
grad_weight,
indices,
n_elements,
embedding_dim=embedding_dim,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
return grad_weight, None
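# Hedged usage sketch: the fused embedding lookup compared against the eager
# torch.nn.functional.embedding result. Vocabulary size, embedding dim, and the "npu"
# device are illustrative assumptions.
def _embedding_usage_example():
    weight = torch.randn(32000, 1024, device="npu", requires_grad=True)
    ids = torch.randint(0, 32000, (2, 128), device="npu")
    out = LigerEmbeddingFunction.apply(weight, ids)
    ref = torch.nn.functional.embedding(ids, weight)  # eager reference
    out.sum().backward()  # accumulates into weight.grad via atomic adds
    return torch.allclose(out, ref), weight.grad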
import torch
import triton
import triton.language as tl
def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
values_per_item = 8 // bits
packed_shape = packed.shape
if len(packed_shape) == 1:
original_row_dim = packed_shape[0] * values_per_item
unpacked_shape = (original_row_dim,)
else:
original_row_dim = packed_shape[0] * values_per_item
unpacked_shape = (original_row_dim, *packed_shape[1:])
unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
for i in range(values_per_item):
start = i * packed_shape[0]
end = start + packed_shape[0]
# slot i occupies bits [2*i, 2*i + 1] of each packed byte
mask = 3 << (2 * i)
unpacked[start:end] = (packed & mask) >> (2 * i)
# values were shifted by +1 before packing (see pack_weights); undo the shift
unpacked = unpacked.to(torch.int32) - 1
return unpacked
def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
intweights += 1
original_shape = intweights.shape
values_per_item = 8 // bits
row_dim = (original_shape[0] + values_per_item - 1) // values_per_item
if len(original_shape) == 1:
packed_tensor_shape = (row_dim,)
else:
packed_tensor_shape = (row_dim, *original_shape[1:])
packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
unpacked = intweights.to(torch.uint8)
def lshift(t: torch.Tensor, bits: int):
return t << bits
it = min(values_per_item, (original_shape[0] // row_dim) + 1)
for i in range(it):
start = i * row_dim
end = min(start + row_dim, original_shape[0])
# place slot i into bits [2*i, 2*i + 1] of the packed bytes
packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)
return packed
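def _pack_unpack_roundtrip_example():
    # Illustrative sketch (not called anywhere in the module): packs a small ternary
    # weight matrix with values in {-1, 0, 1} and checks that unpacking recovers it.
    # Assumes the first dimension is a multiple of 4, since 2-bit packing stores
    # 4 values per uint8 along that dimension; sizes here are arbitrary.
    w = torch.randint(-1, 2, (8, 3), dtype=torch.int32)
    packed = pack_weights(w.clone(), bits=2)  # clone: pack_weights shifts its input in place (+= 1)
    assert packed.shape == (2, 3) and packed.dtype == torch.uint8
    restored = unpack_weights(packed, bits=2)
    assert torch.equal(restored, w)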
def get_autotune_config():
return [
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
},
num_stages=4,
num_warps=4,
),
]
@triton.autotune(
configs=get_autotune_config(),
key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
a_ptr,
b_ptr,
c_ptr,
M,
N,
K: tl.constexpr,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
# We want K / 4 to be divisible by BLOCK_SIZE_K so that the multiplication can be aligned
tl.static_assert(
K % (4 * BLOCK_SIZE_K) == 0,
"K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K",
)
# determine the block id in the 1D grid, pid <=> blockId in cuda
pid = tl.program_id(axis=0)
# number of blocks we would need in the M dimension
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# number of blocks we would need in the N dimension
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# blocks are grouped along the M dimension. num_pid_in_group computes how many blocks are grouped together,
# and group_id calculates the group to which the current block (pid) belongs.
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
# pid of the first block in the group that the current block belongs to
first_pid_m = group_id * GROUP_SIZE_M
# pid_m : pid of the block along the M dimension of the output matrix, and pid_n : pid of the block along the N dimension of the output matrix
# remember that the grid of blocks is 1D, but we compute pid_m and pid_n to locate where block pid lands in the output matrix
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# offs_am represents the indices of elements within the block of matrix A along the M dimension
# offs_bn represents the indices of elements within the block of matrix B along the N dimension
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
"""
This part of the code generates pointers to the specific blocks of matrices A and B that the current thread block will process.
As described in the PyTorch documentation, a stride refers to the step size needed to move from one element to the next along a given dimension:
For matrix A: stride_am = A.stride(0) = K (stride along the rows), and stride_ak = A.stride(1) = 1 (stride along the columns).
For matrix B: stride_bk = B.stride(0) = N (stride along the rows), and stride_bn = B.stride(1) = 1 (stride along the columns).
Now, let's break down the pointer generation:
offs_am[:, None] creates a column of shape [BLOCK_SIZE_M, 1], which represents the row indices of matrix A that this block is processing. It is multiplied by K (the number of columns in matrix A) since A is stored in row-major order. So, the element at position (i, j) in A is located at index i*K + j in memory.
offs_k[None, :] creates a row vector of shape [1, BLOCK_SIZE_K] representing the column indices of the block, i.e., a range from 0 to BLOCK_SIZE_K - 1. This is used to compute the positions of the columns within the block.
When combined, the result has the shape [BLOCK_SIZE_M, BLOCK_SIZE_K], where each entry (i, j) points to the element in matrix A at position (i, j) for the current block.
The same logic is applied to matrix B, but the resulting shape is [BLOCK_SIZE_K, BLOCK_SIZE_N], representing the block of matrix B that the thread block will work on.
"""
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# An accumulator matrix is initialized with zeros. It stores the intermediate results of the block matrix multiplication.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
"""
We split the loop into two layers. The outer loop runs 4 times, and each iteration focuses on a specific portion of matrix A.
For example, when i = 0, we are only concerned with the blocks of matrix A that cover its first K // 4 columns (block indices 0 to K // (4 * BLOCK_SIZE_K) - 1).
Since matrix B is packed, its first dimension is effectively divided by 4. So, while we process the first segment of matrix A,
we still iterate over the entire first dimension of matrix B.
In each of the 4 iterations of the outer loop, we go through the full blocks of matrix B, but what changes is the data we extract.
Matrix B elements contain 4 weights, all packed into an int8 format, and during each iteration of the outer loop,
we extract a different weight by using bitwise shifting operations. This way, we access a unique weight on each pass.
"""
for i in range(4):
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j
# load the block of matrix A
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
# load the block of matrix B
b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0)
# when i = 0 for example, we only care about the first 2 bits of the elements of the matrix B, so we use the mask 00000011 to mask the other bits
mask = 3 << (2 * i)
# we shift the results after the mask
b = (b_uint8 & mask) >> (2 * i)
# During the packing of the weights, it's easier to pack 0, 1, 2 than -1, 0, 1, so we add 1 to the weight tensor when packing, and we subtract it here
tensor_full = tl.full((1,), 1, dtype=tl.int8)
# We accumulate the result of multiplication of the blocks along the K dimension on int32 to avoid any overflows or underflows.
accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)
# we move the pointers: for a_ptrs we move horizontally along the second dimension -> we use stride_ak=1
# for b_ptrs we move vertically, along the rows -> we use stride_bk=N
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator
# These lines compute the offsets into matrix C where the result of this block’s computation should be stored.
# stride_cm = N & stride_cn = 1
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
# we do a boundary check to ensure only elements within matrix bounds are stored
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def matmul(a, b):
assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions: the weight matrix needs to be packed"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
_, N = b.shape
# c is in int32 to avoid any overflows or underflows
c = torch.empty((M, N), device=a.device, dtype=torch.int32)
grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
matmul_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
)
return c
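# --- Usage sketch (illustrative) ---
# Hypothetical end-to-end call on a CUDA device: pack a ternary weight matrix of
# shape (K, N) into (K // 4, N) uint8, then multiply int8 activations of shape
# (M, K) against it. The kernel's static assert requires K to be a multiple of
# 4 * BLOCK_SIZE_K for every autotune config, so K = 4096 is used; all sizes and
# the device are assumptions. The reference is computed in fp32, which is exact
# for integer accumulations of this magnitude.
if __name__ == "__main__":
    M, K, N = 256, 4096, 1024
    a = torch.randint(-8, 8, (M, K), dtype=torch.int8, device="cuda")
    w = torch.randint(-1, 2, (K, N), dtype=torch.int32, device="cuda")
    packed_w = pack_weights(w.clone(), bits=2)  # (K // 4, N), uint8
    c = matmul(a, packed_w)                     # (M, N), int32 accumulator
    ref = (a.float() @ w.float()).round().to(torch.int32)
    print("max abs diff vs. reference:", (c - ref).abs().max().item())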
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import set_large_grf_mode
from liger_kernel.ops.utils import torch_to_triton_dtype
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
@triton.jit
def _fused_add_rms_norm_forward_kernel(
Y_ptr,
Y_row_stride,
S_ptr, # output residual
S_row_stride,
X_ptr,
X_row_stride,
R_ptr, # input residual
R_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
offset,
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
BLOCK_SIZE: tl.constexpr,
):
"""
This kernel computes the following:
1. hidden_states = residual + hidden_states
2. residual = hidden_states
3. hidden_states = rmsnorm(hidden_states)
This is a commonly used pattern in the decoder layers of LLMs.
Some examples:
1. https://github.com/huggingface/transformers/blob/0dc2df5ddafe3cb5824ad24e85beba13e0aa6726/src/transformers/models/qwen3/modeling_qwen3.py#L271
2. https://github.com/huggingface/transformers/blob/0dc2df5ddafe3cb5824ad24e85beba13e0aa6726/src/transformers/models/llama4/modeling_llama4.py#L393
This kernel is adapted from the rms_norm forward kernel to additionally perform the residual addition in the forward pass; the backward kernel is adapted analogously.
"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
Y_ptr += row_idx * Y_row_stride
S_ptr += row_idx * S_row_stride
X_ptr += row_idx * X_row_stride
R_ptr += row_idx * R_row_stride
RSTD_ptr += row_idx * RSTD_row_stride
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
R_row = tl.load(R_ptr + col_offsets, mask=mask, other=0)
S_row = X_row + R_row
tl.store(S_ptr + col_offsets, S_row, mask=mask)
S_row_dtype = S_row.dtype
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
# On Llama, only rstd is computed on fp32
if casting_mode == _CASTING_MODE_LLAMA:
S_row = S_row.to(tl.float32)
# Gemma computes everything on fp32, and then casts back the output to the original dtype
if casting_mode == _CASTING_MODE_GEMMA:
W_row = W_row.to(tl.float32)
S_row = S_row.to(tl.float32)
if casting_mode == _CASTING_MODE_NONE:
eps = eps.to(S_row_dtype)
offset = offset.to(S_row_dtype)
mean_square = tl.sum(S_row * S_row, axis=0) / n_cols
rstd = rsqrt(mean_square + eps)
# Caching rstd for the backward pass has minimal memory overhead, since it is a
# single value per row (much smaller than X_row), and it saves recomputing
# 4 operations (*, sum, /, sqrt) in the backward pass.
tl.store(RSTD_ptr, rstd)
S_row = S_row * rstd
# On Llama, the multiplication with the weight is done on the original dtype
if casting_mode == _CASTING_MODE_LLAMA:
S_row = S_row.to(S_row_dtype)
Y_row = S_row * (offset + W_row)
if casting_mode == _CASTING_MODE_GEMMA:
Y_row = Y_row.to(S_row_dtype)
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
@triton.jit
def _fused_add_rms_norm_backward_kernel(
dY_ptr,
dY_row_stride,
dS_out_ptr,
dS_out_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
rows_per_program: tl.constexpr,
casting_mode: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
has_dS_out: tl.constexpr,
):
"""
This kernel is adapted from the rms_norm backward kernel to support the residual addition in the
backward pass, for the following code pattern:
1. hidden_states = residual + hidden_states
2. residual = hidden_states
3. hidden_states = rmsnorm(hidden_states)
The gradients of hidden_states and residual come out to be exactly the same: both equal the sum of
the gradient of hidden_states from step 3 and the gradient of residual from step 2. The backward
computation otherwise follows the rms_norm backward kernel.
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_row = W_row + offset
for row_idx in range(row_start, row_end):
dy_base = dY_ptr + row_idx * dY_row_stride
dx_base = dX_ptr + row_idx * dX_row_stride
x_base = X_ptr + row_idx * X_row_stride
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
# Get cached rms
rstd_row = tl.load(rstd_base)
X_row = X_row.to(tl.float32)
# Different backward graphs for different casting modes
if casting_mode == _CASTING_MODE_LLAMA:
m = (dY_row * W_row).to(tl.float32)
elif casting_mode == _CASTING_MODE_GEMMA:
dY_row = dY_row.to(tl.float32)
m = dY_row * W_row
else:
m = dY_row * W_row
dX_row = rstd_row * m
if has_dS_out:
ds_base = dS_out_ptr + row_idx * dS_out_row_stride
dS_out_row = tl.load(ds_base + col_offsets, mask=mask, other=0.0)
dX_row += (rstd_row) * (
-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
) + dS_out_row
else:
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
# calculate the gradient of W
if casting_mode == _CASTING_MODE_LLAMA:
dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
else:
# here X_row is already in fp32 (see previous if block)
dW_row += dY_row * (X_row * rstd_row)
tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
_str_to_casting_mode = {
"llama": _CASTING_MODE_LLAMA.value,
"gemma": _CASTING_MODE_GEMMA.value,
"none": _CASTING_MODE_NONE.value,
}
def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode):
if not isinstance(casting_mode, int):
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
casting_mode = _str_to_casting_mode[casting_mode]
else:
assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
R = R.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
S = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
# RSTD caches rstd for each row
# RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
# Check constraints.
assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between X.shape[1] and W.shape[0]"
# XPU-specific optimization
kernel_args = {}
if X.device.type == "xpu":
set_large_grf_mode(kernel_args)
# TODO: add _block_fused_add_rms_norm_forward_kernel
_fused_add_rms_norm_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
S,
S.stride(0),
X,
X.stride(0),
R,
R.stride(0),
W,
W.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
offset,
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
)
return Y.view(*shape), S.view(*shape), RSTD, BLOCK_SIZE, num_warps, casting_mode
def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
dS_out = dS_out.view(-1, dim)
S = S.view(-1, dim)
n_rows, n_cols = dY.shape
sm_count = 1
if S.device.type == "cuda":
sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
elif S.device.type == "xpu":
sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
elif S.device.type == "npu":
sm_count = get_npu_core_count()
# accumulate dW in fp32 for numerical stability
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
if n_cols > BLOCK_SIZE:
raise RuntimeError("This fused add RMS norm doesn't support feature dims larger than BLOCK_SIZE (64KB limit).")
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
if in_place is True:
dX = dY
else:
dX = torch.empty_like(dY)
# XPU-specific optimization
kernel_args = {}
if S.device.type == "xpu":
set_large_grf_mode(kernel_args)
# TODO: add _block_fused_add_rms_norm_backward_kernel
_fused_add_rms_norm_backward_kernel[grid](
dY,
dY.stride(0),
dS_out,
dS_out.stride(0),
dX,
dX.stride(0),
S,
S.stride(0),
torch_to_triton_dtype[S.dtype],
W,
W.stride(0),
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
offset,
rows_per_program,
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
has_dS_out=dS_out is not None,
**kernel_args, # XPU-specific optimization
)
dX = dX.view(*shape)
dW = _dW.sum(dim=0).to(W.dtype)
return dX, dX, dW # dR is equal to dX
class LigerFusedAddRMSNormFunction(torch.autograd.Function):
"""
Performs a fused operation that first adds a residual tensor to the hidden_states tensor (`X`), then applies RMSNorm (Root Mean Square Normalization) to the result using the weight tensor `W`, with optional offset and casting mode.
This class implements the following sequence, commonly used in transformer decoder layers:
1. hidden_states = residual + hidden_states
2. residual = hidden_states (after addition)
3. hidden_states = rmsnorm(hidden_states)
Both the normalized hidden_states and the updated residual are returned as outputs.
Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
`(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
In addition, different models cast their inputs at different places during RMSNorm computation. For
example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
support the following casting modes (they match HuggingFace Transformers' implementations):
- 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
- 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
- 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
The `in_place` option determines whether to modify dY in-place to store dX, which saves memory when enabled.
"""
@staticmethod
@ensure_contiguous
def forward(ctx, X, R, W, eps, offset=0.0, casting_mode="llama", in_place=False):
"""
X: (B, T, H) or (BxT, H)
W: (H,)
"""
# TODO: add row_mode
Y, S, RSTD, BLOCK_SIZE, num_warps, casting_mode = fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode)
ctx.offset = offset
ctx.casting_mode = casting_mode
ctx.in_place = in_place
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(S, W, RSTD)
return Y, S
@staticmethod
@ensure_contiguous
def backward(ctx, dY, dS_out):
"""
Y: (B, T, H) or (BxT, H)
"""
S, W, RSTD = ctx.saved_tensors
dX, dR, dW = fused_add_rms_norm_backward(
dY,
dS_out,
S,
W,
RSTD,
ctx.offset,
ctx.casting_mode,
ctx.BLOCK_SIZE,
ctx.num_warps,
ctx.in_place,
)
return dX, dR, dW, None, None, None, None, None
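# --- Usage sketch (illustrative) ---
# Minimal forward/backward through the fused add + RMSNorm autograd Function.
# Shapes, dtype, and the "cuda" device are assumptions for demonstration; in
# practice this Function is wrapped by a module in the model patches.
if __name__ == "__main__":
    B, T, H = 2, 8, 64
    hidden = torch.randn(B, T, H, device="cuda", requires_grad=True)
    residual = torch.randn(B, T, H, device="cuda", requires_grad=True)
    weight = torch.randn(H, device="cuda", requires_grad=True)
    # args: X, R, W, eps, offset, casting_mode, in_place
    y, new_residual = LigerFusedAddRMSNormFunction.apply(hidden, residual, weight, 1e-6, 0.0, "llama", False)
    (y.sum() + new_residual.sum()).backward()
    print(y.shape, hidden.grad.shape, residual.grad.shape, weight.grad.shape)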
import torch
import triton
from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
from liger_kernel.ops.utils import amp_custom_bwd
from liger_kernel.ops.utils import amp_custom_fwd
from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import is_hip
from liger_kernel.utils import infer_device
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting the limit to 65536, as in the LayerNorm tutorial, is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2
def fused_linear_cross_entropy_forward(
_input,
weight,
target,
ce_weight=None,
bias=None,
ignore_index=-100,
lse_square_scale=0.0,
label_smoothing=0.0,
reduction="mean",
softcap=None,
return_z_loss=False,
accum_dtype=None,
use_token_scaling=False,
return_token_accuracy=False,
return_predicted_tokens=False,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_token_accuracy, bool), (
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
)
assert isinstance(return_predicted_tokens, bool), (
f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}"
)
device = _input.device
input_requires_grad = _input.requires_grad
# inputs have shape: BT x H
# materialized activations will have shape: BT x V
# the increase in memory = BT x V
# reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
# for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
# inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
# for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
BT, H = _input.shape
V = weight.shape[0]
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
grad_input = torch.zeros_like(_input, device=device)
# loss is accumulated in fp32; gradient accumulators use accum_dtype if provided, otherwise the original dtype
if input_requires_grad:
if accum_dtype is None:
grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
else:
grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
else:
grad_weight = None
grad_bias = None
loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
predicted_tokens_1d = torch.full((BT,), -1, dtype=torch.int64, device=device) if return_predicted_tokens else None
# TODO: evaluate how CUDA synchronization caused by .item() affects the speed
target_mask = target != ignore_index
total_n_non_ignore = target_mask.sum().item()
total_sum_non_ignore_ce_weight = total_n_non_ignore
ce_weight_sum = 0.0
if ce_weight is not None:
assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
assert torch.is_floating_point(ce_weight), (
f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
)
total_sum_non_ignore_ce_weight = (
torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
)
ce_weight_sum = ce_weight.sum().item()
if ce_weight.stride(-1) != 1:
ce_weight = ce_weight.contiguous()
for chunk_id in range(num_chunks):
start_idx = chunk_id * chunk_size
end_idx = min((chunk_id + 1) * chunk_size, BT)
_input_chunk = _input[start_idx:end_idx] # chunk_size x H
# when doing matmul, use the original precision
logits_chunk = _input_chunk @ weight.t() # chunk_size x V
if bias is not None:
logits_chunk = logits_chunk + bias
target_chunk = target[start_idx:end_idx] # chunk_size,
n_rows = logits_chunk.shape[0]
# Compute predicted probabilities for token scaling if needed
if use_token_scaling:
# Compute softmax probabilities for scaling
# We need to compute this before the cross entropy kernel modifies logits_chunk
logits_for_softmax = logits_chunk.detach().clone() # Detach to avoid gradient flow
if softcap is not None:
logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
# Compute softmax to get predicted probabilities
probs = torch.softmax(logits_for_softmax, dim=-1)
# Get predicted probabilities for token scaling, handling ignored targets
valid_target_mask = target_chunk != ignore_index
valid_targets = target_chunk[valid_target_mask]
if len(valid_targets) > 0:
# Gather probabilities only for valid targets
valid_probs = probs[valid_target_mask]
pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
# Create full tensor with zeros for ignored targets
pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
pred_probs[valid_target_mask] = pred_probs_valid
else:
# All targets are ignored
pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
# Store the scaling factors
scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow
# unreduced loss
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
predicted_tokens_1d_slice = predicted_tokens_1d[start_idx:end_idx] if return_predicted_tokens else None
# ensure _input and target are contiguous
logits_chunk = logits_chunk.contiguous()
target_chunk = target_chunk.contiguous()
# Here we calculate the gradient of logits_chunk in place so we can save memory.
liger_cross_entropy_kernel[(n_rows,)](
X_ptr=logits_chunk,
X_stride=logits_chunk.stride(-2),
Y_ptr=target_chunk,
Y_stride=target_chunk.stride(-1), # always 1
weight_ptr=ce_weight,
loss_ptr=loss_1d_slice,
z_loss_ptr=z_loss_1d_slice,
loss_stride=loss_1d_slice.stride(-1), # always 1
token_accuracy_ptr=token_accuracy_1d_slice,
token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
if return_token_accuracy
else 0, # always 1 if accuracy is enabled
predicted_tokens_ptr=predicted_tokens_1d_slice,
predicted_tokens_stride=predicted_tokens_1d_slice.stride(-1)
if return_predicted_tokens
else 0, # always 1 if predicted tokens is enabled
n_cols=V,
n_non_ignore=total_n_non_ignore,
sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
weight_sum=ce_weight_sum,
ignore_index=ignore_index,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
RETURN_TOKEN_ACCURACY=return_token_accuracy,
RETURN_PREDICTED_TOKENS=return_predicted_tokens,
HAS_WEIGHT=True if ce_weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
HAS_GRADIENTS=input_requires_grad,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
# Apply token scaling if requested
if use_token_scaling:
loss_1d_slice = loss_1d_slice * scaling_factors
if return_z_loss:
z_loss_1d_slice = z_loss_1d_slice * scaling_factors
loss_1d[start_idx:end_idx] = loss_1d_slice
if return_z_loss:
z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
if return_token_accuracy:
token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
if return_predicted_tokens:
predicted_tokens_1d[start_idx:end_idx] = predicted_tokens_1d_slice
grad_logits_chunk = logits_chunk # chunk_size x V
# Apply token scaling to gradients if requested
if use_token_scaling:
# Expand scaling factors to match gradient dimensions
scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
if input_requires_grad:
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
if grad_weight is not None and input_requires_grad:
grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
if bias is not None and input_requires_grad:
torch.add(
input=grad_bias,
other=grad_logits_chunk.sum(dim=0),
out=grad_bias,
alpha=1.0,
)
# Note: reduction == "none" may need extra calculations in the backward pass (per-token grad_output).
if reduction == "none":
# Return per-token losses
loss = loss_1d
z_loss = z_loss_1d if return_z_loss else None
token_accuracy = token_accuracy_1d if return_token_accuracy else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
# For accuracy, we compute the mean across all non-ignored tokens
token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
# Cast back to original dtype
grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
return loss, z_loss, token_accuracy, predicted_tokens, grad_input, grad_weight, grad_bias
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
BT, H = grad_input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
element_mul_kernel[(n_rows,)](
grad_input,
grad_input.stride(-2),
grad_output,
H,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
# handle grad_weight
if grad_weight is not None:
V, H = grad_weight.shape
n_rows = V
element_mul_kernel[(n_rows,)](
grad_weight,
grad_weight.stride(-2),
grad_output,
H,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
if grad_bias is not None:
V = grad_bias.shape[0]
n_rows = V
element_mul_kernel[(n_rows,)](
grad_bias,
grad_bias.stride(-1),
grad_output,
1,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
return grad_input, grad_weight, grad_bias
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
@staticmethod
@amp_custom_fwd
def forward(
ctx,
_input,
weight,
target,
bias=None,
ce_weight=None,
ignore_index=-100,
lse_square_scale=0.0,
label_smoothing=0.0,
reduction="mean",
softcap=None,
return_z_loss: bool = False,
accum_dtype=None,
use_token_scaling: bool = False,
return_token_accuracy: bool = False,
return_predicted_tokens: bool = False,
):
"""
Fusing the last linear layer with cross-entropy loss
Reference: https://github.com/mgmalek/efficient_cross_entropy
Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
for the backward pass.
_input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
target: (B*T) where each value is in [0, V-1]
weight: (V, H) where V is the number of classes
bias: (V) where V is the number of classes
ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
ignore_index: the index to ignore in the target
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction: reduction to apply
accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
Default: False.
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
"""
loss, z_loss, token_accuracy, predicted_tokens, grad_input, grad_weight, grad_bias = (
fused_linear_cross_entropy_forward(
_input=_input,
weight=weight,
target=target,
bias=bias,
ce_weight=ce_weight,
ignore_index=ignore_index,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap,
return_z_loss=return_z_loss,
accum_dtype=accum_dtype,
use_token_scaling=use_token_scaling,
return_token_accuracy=return_token_accuracy,
return_predicted_tokens=return_predicted_tokens,
)
)
# detach the precomputed gradients and store them for the backward pass
ctx.save_for_backward(
grad_input.detach(),
grad_weight.detach() if grad_weight is not None else None,
grad_bias.detach() if grad_bias is not None else None,
)
ctx.return_z_loss = return_z_loss
ctx.return_token_accuracy = return_token_accuracy
ctx.return_predicted_tokens = return_predicted_tokens
return loss, z_loss, token_accuracy, predicted_tokens
@staticmethod
@amp_custom_bwd
def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4):
if ctx.return_z_loss:
del grad_output2 # z_loss is only for logging
if ctx.return_token_accuracy:
del grad_output3 # token_accuracy is only for metrics
if ctx.return_predicted_tokens:
del grad_output4 # predicted_tokens is only for metrics
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
grad_output, grad_input, grad_weight, grad_bias
)
return (
grad_input,
grad_weight,
None,
grad_bias,
None,
None,
None,
None,
None,
None,
None,
None,
None, # use_token_scaling
None, # return_token_accuracy
None, # return_predicted_tokens
)
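# --- Usage sketch (illustrative) ---
# Minimal call of the fused linear + cross-entropy Function. Shapes, dtype, and
# the "cuda" device are assumptions for this example; with default arguments only
# `loss` is a tensor, and the other three outputs are None.
if __name__ == "__main__":
    BT, H, V = 512, 256, 1000
    hidden = torch.randn(BT, H, device="cuda", requires_grad=True)
    lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)
    target = torch.randint(0, V, (BT,), device="cuda")
    loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply(
        hidden, lm_head_weight, target
    )
    loss.backward()
    print(loss.item(), hidden.grad.shape, lm_head_weight.grad.shape)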