from typing import Optional
import torch
import triton
from liger_kernel.ops.jsd import _jsd_kernel
from liger_kernel.ops.utils import amp_custom_bwd
from liger_kernel.ops.utils import amp_custom_fwd
from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import is_hip
from liger_kernel.utils import infer_device
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
def fused_linear_jsd_forward(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
has_label,
temperature,
):
device = student_input.device
dtype = student_input.dtype
# inputs have shape: BT x H
# materialized activations will have shape: BT x V
# the increase in memory = BT x V
# reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
# for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
# inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
# for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
BT, H = student_input.shape
V = student_weight.shape[0]
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
grad_input = torch.zeros_like(student_input)
# we use fp32 for loss accumulator
loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
if has_label:
n_non_ignore = (shift_labels != ignore_index).sum().item()
else:
n_non_ignore = BT
for chunk_id in range(num_chunks):
start_idx = chunk_id * chunk_size
end_idx = min((chunk_id + 1) * chunk_size, BT)
# chunk both inputs, shape: chunk_size x H
student_input_chunk = student_input[start_idx:end_idx]
teacher_input_chunk = teacher_input[start_idx:end_idx]
# shape: chunk_size x V
        # From the logits through the final JSD loss, we compute in FP32
        # to preserve numerical stability.
student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
chunk_n_rows = student_logits_chunk.shape[0]
# unreduced loss
        loss_1d_slice = loss_1d[start_idx:end_idx]  # shape: chunk_size x V
# log-softmax with temperature
student_logits_chunk = student_logits_chunk / temperature
teacher_logits_chunk = teacher_logits_chunk / temperature
student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1)
teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1)
# ensure _input and target are contiguous
student_prob_chunk = student_prob_chunk.contiguous()
teacher_prob_chunk = teacher_prob_chunk.contiguous()
# Here we calculate the gradient of prob_chunk in place so we can save memory.
_jsd_kernel[(chunk_n_rows,)](
X_ptr=student_prob_chunk,
X_stride=student_prob_chunk.stride(-2),
Y_ptr=teacher_prob_chunk,
Y_stride=teacher_prob_chunk.stride(-2),
loss_ptr=loss_1d_slice,
loss_stride=loss_1d_slice.stride(-2),
dX_ptr=student_prob_chunk,
dX_stride=student_prob_chunk.stride(-2),
label_ptr=(
shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
), # dummy ptr if no label
beta=jsd_beta,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
n_cols=V,
BLOCK_SIZE=BLOCK_SIZE,
HAS_LABEL=has_label,
)
loss_1d[start_idx:end_idx] = loss_1d_slice
        # the kernel above wrote d(loss)/d(log_softmax) into student_prob_chunk in place;
        # convert it to d(loss)/d(logits) via the log-softmax backward, shape: chunk_size x V
student_logits_chunk = (
student_prob_chunk
- torch.softmax(student_logits_chunk, dim=-1)
* student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
) / temperature
        # now backpropagate to the grad w.r.t. the `lm_head` input and the grad
        # w.r.t. the `lm_head` weight, which should be computed in the original dtype
student_logits_chunk = student_logits_chunk.to(dtype)
grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight
if grad_weight is not None:
grad_weight.add_(student_logits_chunk.t() @ student_input_chunk)
loss = torch.sum(loss_1d)
return loss, grad_input, grad_weight
def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
# If JSD is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
        # We use a Triton kernel instead of a PyTorch operation because modifying
        # inputs in place for gradient storage, and running backward through them
        # multiple times, triggers anomaly checks in PyTorch but not in Triton.
BT, H = grad_input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
element_mul_kernel[(n_rows,)](
grad_input,
grad_input.stride(-2),
grad_output,
H,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
# handle grad_weight
if grad_weight is not None:
V, H = grad_weight.shape
n_rows = V
element_mul_kernel[(n_rows,)](
grad_weight,
grad_weight.stride(-2),
grad_output,
H,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
return grad_input, grad_weight
class LigerFusedLinearJSDFunction(torch.autograd.Function):
"""
    Fuses the last linear layer with the generalized JSD loss.
    Handles the forward and backward pass of the final linear layer via JSD by avoiding
    the materialization of the large logits tensor. Since JSD is the last layer, we can
    compute the gradient during the forward pass.
"""
@staticmethod
@amp_custom_fwd
def forward(
ctx,
student_input: torch.Tensor,
student_weight: torch.Tensor,
teacher_input: torch.Tensor,
teacher_weight: torch.Tensor,
shift_labels: Optional[torch.Tensor] = None,
jsd_beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
"""
Args:
student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size
teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
ignore_index (int): the index to ignore. Default: -100
temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
Returns:
loss (torch.Tensor): generalized JSD
"""
has_label = False
if shift_labels is not None:
assert shift_labels.shape == (teacher_input.shape[0],), (
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
)
shift_labels = shift_labels.contiguous()
has_label = True
loss, grad_input, grad_weight = fused_linear_jsd_forward(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
has_label,
temperature,
)
# downcast to dtype and store for backward
ctx.save_for_backward(
grad_input.detach(),
grad_weight.detach() if grad_weight is not None else None,
)
return loss
@staticmethod
@amp_custom_bwd
def backward(ctx, grad_output):
(grad_input, grad_weight) = ctx.saved_tensors
grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
return (grad_input, grad_weight, None, None, None, None, None, None)
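# Illustrative usage sketch (not part of the original source): toy shapes, and a
# Triton-capable device is assumed since the forward launches Triton kernels.
# Wrapped in a function so importing this module does not run it.
def _example_fused_linear_jsd():
    BT, H, V = 8, 16, 32  # batch*seq, hidden size, vocab size (toy values)
    device = infer_device()
    student_input = torch.randn(BT, H, device=device, requires_grad=True)
    student_weight = torch.randn(V, H, device=device, requires_grad=True)
    teacher_input = torch.randn(BT, H, device=device)
    teacher_weight = torch.randn(V, H, device=device)
    # shift_labels, jsd_beta, ignore_index, and temperature keep their defaults
    loss = LigerFusedLinearJSDFunction.apply(student_input, student_weight, teacher_input, teacher_weight)
    loss.backward()  # populates student_input.grad and student_weight.grad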
import math
from typing import Optional
import torch
import triton
import triton.language as tl
from liger_kernel.ops.softmax import _softmax_backward
from liger_kernel.ops.softmax import _softmax_forward
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous
@triton.jit
def _neighborhood_mask_kernel(
mask_ptr,
seq_len: tl.constexpr,
kernel_size: tl.constexpr,
dilation: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr,
num_warps: tl.constexpr,
):
"""
Generate a neighborhood attention mask for a given sequence.
This kernel creates a binary mask that defines which positions in a sequence
can attend to each other based on a neighborhood window with optional dilation.
Each row of the mask corresponds to a query position, and each column indicates
whether that key position is within the allowed neighborhood.
The neighborhood is defined as positions within kernel_size//2 * dilation distance
from the center position. When dilation > 1, only positions at multiples of the
dilation factor are included in the neighborhood.
Args:
mask_ptr: Pointer to the output mask tensor [seq_len, seq_len]
seq_len: Length of the input sequence
kernel_size: Size of the neighborhood window (must be odd)
dilation: Dilation factor for the neighborhood pattern
BLOCK_SIZE: Block size for processing (compile-time constant)
num_stages: Number of pipeline stages (compile-time constant)
num_warps: Number of warps (compile-time constant)
Grid: (seq_len,)
Each program processes one row of the mask matrix.
"""
row_id = tl.program_id(0)
center = row_id
half_kernel = kernel_size // 2
start = tl.maximum(0, center - half_kernel * dilation)
end = tl.minimum(seq_len, center + half_kernel * dilation + 1)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < seq_len
valid_neighbors = (col_offsets >= start) & (col_offsets < end)
if dilation > 1:
relative_pos = col_offsets - center
valid_dilation = (relative_pos % dilation) == 0
valid_neighbors = valid_neighbors & valid_dilation
mask_values = tl.where(valid_neighbors & mask, 1.0, 0.0)
base_offset = row_id * seq_len
tl.store(mask_ptr + base_offset + col_offsets, mask_values, mask=mask)
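# Illustrative reference (not part of the original source): the same neighborhood
# mask built in plain PyTorch, handy for cross-checking the Triton kernel on
# small sequence lengths.
def _reference_neighborhood_mask(seq_len: int, kernel_size: int, dilation: int) -> torch.Tensor:
    idx = torch.arange(seq_len)
    rel = idx[None, :] - idx[:, None]  # key position minus query (center) position
    valid = rel.abs() <= (kernel_size // 2) * dilation
    if dilation > 1:
        valid &= (rel % dilation) == 0  # keep only positions on the dilation grid
    return valid.float()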
@triton.jit
def _fused_neighborhood_attention_qk_kernel(
Q_ptr,
K_ptr,
QK_ptr,
mask_ptr,
q_batch_stride,
q_head_stride,
q_seq_stride,
q_dim_stride,
k_batch_stride,
k_head_stride,
k_seq_stride,
k_dim_stride,
qk_batch_stride,
qk_head_stride,
qk_seq_stride,
qk_seq2_stride,
batch_size: tl.constexpr,
num_heads: tl.constexpr,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
scale: tl.constexpr,
kernel_size: tl.constexpr,
dilation: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
num_stages: tl.constexpr,
num_warps: tl.constexpr,
):
"""
Compute Q @ K^T with neighborhood masking and scaling.
This kernel performs the first stage of neighborhood attention by computing
the attention scores between queries and keys, applying scaling, and masking
positions outside the neighborhood window. The result is a matrix of attention
scores ready for softmax normalization.
The computation is tiled across sequence dimensions for memory efficiency.
Each tile computes a block of the attention score matrix by iterating over
the head dimension and accumulating dot products.
Args:
Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim]
K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim]
QK_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, seq_len]
mask_ptr: Pointer to neighborhood mask [seq_len, seq_len]
q_*_stride: Strides for query tensor
k_*_stride: Strides for key tensor
qk_*_stride: Strides for output tensor
batch_size: Number of batches
num_heads: Number of attention heads
seq_len: Sequence length
head_dim: Dimension of each attention head
scale: Scaling factor for attention scores (typically 1/sqrt(head_dim))
kernel_size: Size of the neighborhood window
dilation: Dilation factor for the neighborhood
BLOCK_SIZE_M: Block size for sequence dimension (rows)
BLOCK_SIZE_N: Block size for sequence dimension (cols)
BLOCK_SIZE_K: Block size for head dimension
num_stages: Number of pipeline stages
num_warps: Number of warps
Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N))
Each program computes a tile of the attention score matrix.
"""
batch_head_id = tl.program_id(0)
tile_m = tl.program_id(1)
tile_n = tl.program_id(2)
batch_id = batch_head_id // num_heads
head_id = batch_head_id % num_heads
row_start = tile_m * BLOCK_SIZE_M
col_start = tile_n * BLOCK_SIZE_N
row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_start in range(0, head_dim, BLOCK_SIZE_K):
k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offsets < head_dim
q_ptrs = (
Q_ptr
+ batch_id * q_batch_stride
+ head_id * q_head_stride
+ row_offsets[:, None] * q_seq_stride
+ k_offsets[None, :] * q_dim_stride
)
q_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0)
k_ptrs = (
K_ptr
+ batch_id * k_batch_stride
+ head_id * k_head_stride
+ col_offsets[:, None] * k_seq_stride
+ k_offsets[None, :] * k_dim_stride
)
k_mask = (col_offsets[:, None] < seq_len) & k_mask[None, :]
k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0)
acc += tl.dot(q_chunk, tl.trans(k_chunk))
acc = acc * scale
mask_ptrs = mask_ptr + row_offsets[:, None] * seq_len + col_offsets[None, :]
valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < seq_len)
neighborhood_mask = tl.load(mask_ptrs, mask=valid_mask, other=0.0)
acc = tl.where(neighborhood_mask > 0.0, acc, float("-inf"))
qk_ptrs = (
QK_ptr
+ batch_id * qk_batch_stride
+ head_id * qk_head_stride
+ row_offsets[:, None] * qk_seq_stride
+ col_offsets[None, :] * qk_seq2_stride
)
tl.store(qk_ptrs, acc, mask=valid_mask)
@triton.jit
def _fused_neighborhood_attention_av_kernel(
Attn_ptr,
V_ptr,
Out_ptr,
attn_batch_stride,
attn_head_stride,
attn_seq_stride,
attn_seq2_stride,
v_batch_stride,
v_head_stride,
v_seq_stride,
v_dim_stride,
out_batch_stride,
out_head_stride,
out_seq_stride,
out_dim_stride,
batch_size: tl.constexpr,
num_heads: tl.constexpr,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
num_stages: tl.constexpr,
num_warps: tl.constexpr,
):
"""
Compute Attention @ V to produce the final output.
This kernel performs the second stage of neighborhood attention by multiplying
the normalized attention weights with the value matrix. The computation is
tiled for memory efficiency, with each tile computing a block of the output.
Args:
Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len]
V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim]
Out_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, head_dim]
attn_*_stride: Strides for attention weights tensor
v_*_stride: Strides for value tensor
out_*_stride: Strides for output tensor
batch_size: Number of batches
num_heads: Number of attention heads
seq_len: Sequence length
head_dim: Dimension of each attention head
BLOCK_SIZE_M: Block size for sequence dimension (rows)
BLOCK_SIZE_N: Block size for head dimension (cols)
BLOCK_SIZE_K: Block size for sequence dimension (reduction)
num_stages: Number of pipeline stages
num_warps: Number of warps
Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
Each program computes a tile of the output matrix.
"""
batch_head_id = tl.program_id(0)
tile_m = tl.program_id(1)
tile_n = tl.program_id(2)
batch_id = batch_head_id // num_heads
head_id = batch_head_id % num_heads
row_start = tile_m * BLOCK_SIZE_M
col_start = tile_n * BLOCK_SIZE_N
row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_start in range(0, seq_len, BLOCK_SIZE_K):
k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offsets < seq_len
attn_ptrs = (
Attn_ptr
+ batch_id * attn_batch_stride
+ head_id * attn_head_stride
+ row_offsets[:, None] * attn_seq_stride
+ k_offsets[None, :] * attn_seq2_stride
)
attn_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
attn_chunk = tl.load(attn_ptrs, mask=attn_mask, other=0.0)
v_ptrs = (
V_ptr
+ batch_id * v_batch_stride
+ head_id * v_head_stride
+ k_offsets[:, None] * v_seq_stride
+ col_offsets[None, :] * v_dim_stride
)
v_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim)
v_chunk = tl.load(v_ptrs, mask=v_mask, other=0.0)
acc += tl.dot(attn_chunk, v_chunk)
out_ptrs = (
Out_ptr
+ batch_id * out_batch_stride
+ head_id * out_head_stride
+ row_offsets[:, None] * out_seq_stride
+ col_offsets[None, :] * out_dim_stride
)
valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
tl.store(out_ptrs, acc, mask=valid_mask)
@triton.jit
def _fused_neighborhood_attention_grad_qk_kernel(
grad_attn_ptr,
K_ptr,
grad_Q_ptr,
grad_attn_batch_stride,
grad_attn_head_stride,
grad_attn_seq_stride,
grad_attn_seq2_stride,
k_batch_stride,
k_head_stride,
k_seq_stride,
k_dim_stride,
grad_q_batch_stride,
grad_q_head_stride,
grad_q_seq_stride,
grad_q_dim_stride,
batch_size: tl.constexpr,
num_heads: tl.constexpr,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
scale: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
num_stages: tl.constexpr,
num_warps: tl.constexpr,
):
"""
Compute gradient with respect to queries: grad_Q = grad_attn @ K * scale.
This kernel computes the gradient of the loss with respect to the query tensor
by multiplying the gradient of attention weights with the key tensor. The
computation follows the chain rule for the attention mechanism.
Args:
grad_attn_ptr: Pointer to gradient of attention weights [batch_size, num_heads, seq_len, seq_len]
K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim]
grad_Q_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
grad_attn_*_stride: Strides for gradient attention tensor
k_*_stride: Strides for key tensor
grad_q_*_stride: Strides for gradient query tensor
batch_size: Number of batches
num_heads: Number of attention heads
seq_len: Sequence length
head_dim: Dimension of each attention head
scale: Scaling factor applied to attention scores
BLOCK_SIZE_M: Block size for sequence dimension (rows)
BLOCK_SIZE_N: Block size for head dimension (cols)
BLOCK_SIZE_K: Block size for sequence dimension (reduction)
num_stages: Number of pipeline stages
num_warps: Number of warps
Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
Each program computes a tile of the query gradient matrix.
"""
batch_head_id = tl.program_id(0)
tile_m = tl.program_id(1)
tile_n = tl.program_id(2)
batch_id = batch_head_id // num_heads
head_id = batch_head_id % num_heads
row_start = tile_m * BLOCK_SIZE_M
col_start = tile_n * BLOCK_SIZE_N
row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_start in range(0, seq_len, BLOCK_SIZE_K):
k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offsets < seq_len
grad_attn_ptrs = (
grad_attn_ptr
+ batch_id * grad_attn_batch_stride
+ head_id * grad_attn_head_stride
+ row_offsets[:, None] * grad_attn_seq_stride
+ k_offsets[None, :] * grad_attn_seq2_stride
)
grad_attn_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
grad_attn_chunk = tl.load(grad_attn_ptrs, mask=grad_attn_mask, other=0.0)
k_ptrs = (
K_ptr
+ batch_id * k_batch_stride
+ head_id * k_head_stride
+ k_offsets[:, None] * k_seq_stride
+ col_offsets[None, :] * k_dim_stride
)
k_mask_2d = k_mask[:, None] & (col_offsets[None, :] < head_dim)
k_chunk = tl.load(k_ptrs, mask=k_mask_2d, other=0.0)
acc += tl.dot(grad_attn_chunk, k_chunk)
acc = acc * scale
grad_q_ptrs = (
grad_Q_ptr
+ batch_id * grad_q_batch_stride
+ head_id * grad_q_head_stride
+ row_offsets[:, None] * grad_q_seq_stride
+ col_offsets[None, :] * grad_q_dim_stride
)
valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
tl.store(grad_q_ptrs, acc, mask=valid_mask)
@triton.jit
def _fused_neighborhood_attention_grad_k_kernel(
grad_attn_ptr,
Q_ptr,
grad_K_ptr,
grad_attn_batch_stride,
grad_attn_head_stride,
grad_attn_seq_stride,
grad_attn_seq2_stride,
q_batch_stride,
q_head_stride,
q_seq_stride,
q_dim_stride,
grad_k_batch_stride,
grad_k_head_stride,
grad_k_seq_stride,
grad_k_dim_stride,
batch_size: tl.constexpr,
num_heads: tl.constexpr,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
scale: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
num_stages: tl.constexpr,
num_warps: tl.constexpr,
):
"""
Compute gradient with respect to keys: grad_K = grad_attn^T @ Q * scale.
This kernel computes the gradient of the loss with respect to the key tensor
by multiplying the transpose of the gradient of attention weights with the
query tensor. The computation follows the chain rule for the attention mechanism.
Args:
grad_attn_ptr: Pointer to gradient of attention weights [batch_size, num_heads, seq_len, seq_len]
Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim]
grad_K_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
grad_attn_*_stride: Strides for gradient attention tensor
q_*_stride: Strides for query tensor
grad_k_*_stride: Strides for gradient key tensor
batch_size: Number of batches
num_heads: Number of attention heads
seq_len: Sequence length
head_dim: Dimension of each attention head
scale: Scaling factor applied to attention scores
BLOCK_SIZE_M: Block size for sequence dimension (rows)
BLOCK_SIZE_N: Block size for head dimension (cols)
BLOCK_SIZE_K: Block size for sequence dimension (reduction)
num_stages: Number of pipeline stages
num_warps: Number of warps
Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
Each program computes a tile of the key gradient matrix.
"""
batch_head_id = tl.program_id(0)
tile_m = tl.program_id(1)
tile_n = tl.program_id(2)
batch_id = batch_head_id // num_heads
head_id = batch_head_id % num_heads
row_start = tile_m * BLOCK_SIZE_M
col_start = tile_n * BLOCK_SIZE_N
row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_start in range(0, seq_len, BLOCK_SIZE_K):
k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offsets < seq_len
q_ptrs = (
Q_ptr
+ batch_id * q_batch_stride
+ head_id * q_head_stride
+ k_offsets[:, None] * q_seq_stride
+ col_offsets[None, :] * q_dim_stride
)
q_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim)
q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0)
grad_attn_T_ptrs = (
grad_attn_ptr
+ batch_id * grad_attn_batch_stride
+ head_id * grad_attn_head_stride
+ row_offsets[:, None] * grad_attn_seq2_stride
+ k_offsets[None, :] * grad_attn_seq_stride
)
grad_attn_T_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
grad_attn_T_chunk = tl.load(grad_attn_T_ptrs, mask=grad_attn_T_mask, other=0.0)
acc += tl.dot(grad_attn_T_chunk, q_chunk)
acc = acc * scale
grad_k_ptrs = (
grad_K_ptr
+ batch_id * grad_k_batch_stride
+ head_id * grad_k_head_stride
+ row_offsets[:, None] * grad_k_seq_stride
+ col_offsets[None, :] * grad_k_dim_stride
)
valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
tl.store(grad_k_ptrs, acc, mask=valid_mask)
@triton.jit
def _fused_neighborhood_attention_grad_v_kernel(
Attn_ptr,
grad_output_ptr,
grad_V_ptr,
attn_batch_stride,
attn_head_stride,
attn_seq_stride,
attn_seq2_stride,
grad_out_batch_stride,
grad_out_head_stride,
grad_out_seq_stride,
grad_out_dim_stride,
grad_v_batch_stride,
grad_v_head_stride,
grad_v_seq_stride,
grad_v_dim_stride,
batch_size: tl.constexpr,
num_heads: tl.constexpr,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
num_stages: tl.constexpr,
num_warps: tl.constexpr,
):
"""
Compute gradient with respect to values: grad_V = Attn^T @ grad_output.
This kernel computes the gradient of the loss with respect to the value tensor
by multiplying the transpose of the attention weights with the gradient of the
output. The computation follows the chain rule for the attention mechanism.
Args:
Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len]
grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim]
grad_V_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim]
attn_*_stride: Strides for attention weights tensor
grad_out_*_stride: Strides for gradient output tensor
grad_v_*_stride: Strides for gradient value tensor
batch_size: Number of batches
num_heads: Number of attention heads
seq_len: Sequence length
head_dim: Dimension of each attention head
BLOCK_SIZE_M: Block size for sequence dimension (rows)
BLOCK_SIZE_N: Block size for head dimension (cols)
BLOCK_SIZE_K: Block size for sequence dimension (reduction)
num_stages: Number of pipeline stages
num_warps: Number of warps
Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N))
Each program computes a tile of the value gradient matrix.
"""
batch_head_id = tl.program_id(0)
tile_m = tl.program_id(1)
tile_n = tl.program_id(2)
batch_id = batch_head_id // num_heads
head_id = batch_head_id % num_heads
row_start = tile_m * BLOCK_SIZE_M
col_start = tile_n * BLOCK_SIZE_N
row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_start in range(0, seq_len, BLOCK_SIZE_K):
k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offsets < seq_len
attn_ptrs = (
Attn_ptr
+ batch_id * attn_batch_stride
+ head_id * attn_head_stride
+ k_offsets[:, None] * attn_seq_stride
+ row_offsets[None, :] * attn_seq2_stride
)
attn_mask = k_mask[:, None] & (row_offsets[None, :] < seq_len)
attn_chunk = tl.load(attn_ptrs, mask=attn_mask, other=0.0)
grad_out_ptrs = (
grad_output_ptr
+ batch_id * grad_out_batch_stride
+ head_id * grad_out_head_stride
+ k_offsets[:, None] * grad_out_seq_stride
+ col_offsets[None, :] * grad_out_dim_stride
)
grad_out_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim)
grad_out_chunk = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0)
acc += tl.dot(tl.trans(attn_chunk), grad_out_chunk)
grad_v_ptrs = (
grad_V_ptr
+ batch_id * grad_v_batch_stride
+ head_id * grad_v_head_stride
+ row_offsets[:, None] * grad_v_seq_stride
+ col_offsets[None, :] * grad_v_dim_stride
)
valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim)
tl.store(grad_v_ptrs, acc, mask=valid_mask)
@triton.jit
def _fused_neighborhood_attention_grad_attn_kernel(
grad_output_ptr,
V_ptr,
grad_attn_ptr,
grad_out_batch_stride,
grad_out_head_stride,
grad_out_seq_stride,
grad_out_dim_stride,
v_batch_stride,
v_head_stride,
v_seq_stride,
v_dim_stride,
grad_attn_batch_stride,
grad_attn_head_stride,
grad_attn_seq_stride,
grad_attn_seq2_stride,
batch_size: tl.constexpr,
num_heads: tl.constexpr,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
num_stages: tl.constexpr,
num_warps: tl.constexpr,
):
"""
Compute gradient with respect to attention weights: grad_attn = grad_output @ V^T.
This kernel computes the gradient of the loss with respect to the attention
weights by multiplying the gradient of the output with the transpose of the
value tensor. This gradient will later be passed through the softmax backward
pass to compute gradients for the attention scores.
Args:
grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim]
V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim]
grad_attn_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, seq_len]
grad_out_*_stride: Strides for gradient output tensor
v_*_stride: Strides for value tensor
grad_attn_*_stride: Strides for gradient attention tensor
batch_size: Number of batches
num_heads: Number of attention heads
seq_len: Sequence length
head_dim: Dimension of each attention head
BLOCK_SIZE_M: Block size for sequence dimension (rows)
BLOCK_SIZE_N: Block size for sequence dimension (cols)
BLOCK_SIZE_K: Block size for head dimension (reduction)
num_stages: Number of pipeline stages
num_warps: Number of warps
Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N))
Each program computes a tile of the attention gradient matrix.
"""
batch_head_id = tl.program_id(0)
tile_m = tl.program_id(1)
tile_n = tl.program_id(2)
batch_id = batch_head_id // num_heads
head_id = batch_head_id % num_heads
row_start = tile_m * BLOCK_SIZE_M
col_start = tile_n * BLOCK_SIZE_N
row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k_start in range(0, head_dim, BLOCK_SIZE_K):
k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offsets < head_dim
grad_out_ptrs = (
grad_output_ptr
+ batch_id * grad_out_batch_stride
+ head_id * grad_out_head_stride
+ row_offsets[:, None] * grad_out_seq_stride
+ k_offsets[None, :] * grad_out_dim_stride
)
grad_out_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :]
grad_out_chunk = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0)
v_ptrs = (
V_ptr
+ batch_id * v_batch_stride
+ head_id * v_head_stride
+ col_offsets[None, :] * v_seq_stride
+ k_offsets[:, None] * v_dim_stride
)
v_mask = (col_offsets[None, :] < seq_len) & k_mask[:, None]
v_chunk = tl.load(v_ptrs, mask=v_mask, other=0.0)
acc += tl.dot(grad_out_chunk, v_chunk)
grad_attn_ptrs = (
grad_attn_ptr
+ batch_id * grad_attn_batch_stride
+ head_id * grad_attn_head_stride
+ row_offsets[:, None] * grad_attn_seq_stride
+ col_offsets[None, :] * grad_attn_seq2_stride
)
valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < seq_len)
tl.store(grad_attn_ptrs, acc, mask=valid_mask)
def fused_neighborhood_attention_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kernel_size: int = 7,
dilation: int = 1,
    scale: Optional[float] = None,
return_lse: bool = False,
) -> tuple:
"""
Fused neighborhood attention forward pass.
Args:
query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
kernel_size: Size of the neighborhood window
dilation: Dilation factor for the neighborhood
scale: Scaling factor for attention scores (default: rsqrt(head_dim))
return_lse: Whether to return log-sum-exp values
Returns:
        Tuple of (output tensor, attention weights, softmax parameters for backward)
"""
batch_size, num_heads, seq_len, head_dim = query.shape
if scale is None:
scale = 1.0 / math.sqrt(head_dim)
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
output = torch.empty_like(query)
qk_scores = torch.empty(batch_size, num_heads, seq_len, seq_len, device=query.device, dtype=query.dtype)
mask = torch.zeros(seq_len, seq_len, device=query.device, dtype=torch.float32)
BLOCK_SIZE, num_warps = calculate_settings(seq_len)
BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len))
BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len))
BLOCK_SIZE_K = max(16, triton.next_power_of_2(head_dim))
num_stages = 4 if seq_len >= 512 else 2
grid_mask = (seq_len,)
_neighborhood_mask_kernel[grid_mask](
mask,
seq_len,
kernel_size,
dilation,
BLOCK_SIZE,
num_stages,
num_warps,
)
grid_qk = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len, BLOCK_SIZE_N))
_fused_neighborhood_attention_qk_kernel[grid_qk](
query,
key,
qk_scores,
mask,
query.stride(0),
query.stride(1),
query.stride(2),
query.stride(3),
key.stride(0),
key.stride(1),
key.stride(2),
key.stride(3),
qk_scores.stride(0),
qk_scores.stride(1),
qk_scores.stride(2),
qk_scores.stride(3),
batch_size,
num_heads,
seq_len,
head_dim,
scale,
kernel_size,
dilation,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
num_stages,
num_warps,
)
qk_reshaped = qk_scores.view(batch_size * num_heads * seq_len, seq_len)
attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = _softmax_forward(qk_reshaped)
attn_weights = attn_reshaped.view(batch_size, num_heads, seq_len, seq_len)
grid_av = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
_fused_neighborhood_attention_av_kernel[grid_av](
attn_weights,
value,
output,
attn_weights.stride(0),
attn_weights.stride(1),
attn_weights.stride(2),
attn_weights.stride(3),
value.stride(0),
value.stride(1),
value.stride(2),
value.stride(3),
output.stride(0),
output.stride(1),
output.stride(2),
output.stride(3),
batch_size,
num_heads,
seq_len,
head_dim,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
num_stages,
num_warps,
)
if return_lse:
raise NotImplementedError("return_lse=True is not supported yet.")
softmax_params = (BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch)
return output, attn_weights, softmax_params
class LigerFusedNeighborhoodAttentionFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, query, key, value, kernel_size=7, dilation=1, scale=None):
output, attn_weights, softmax_params = fused_neighborhood_attention_forward(
query, key, value, kernel_size, dilation, scale
)
ctx.save_for_backward(query, key, value, attn_weights)
ctx.kernel_size = kernel_size
ctx.dilation = dilation
ctx.scale = scale
ctx.softmax_params = softmax_params
return output
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output):
query, key, value, attn_weights = ctx.saved_tensors
BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = ctx.softmax_params
batch_size, num_heads, seq_len, head_dim = query.shape
scale = ctx.scale if ctx.scale is not None else 1.0 / math.sqrt(head_dim)
grad_query = torch.zeros_like(query)
grad_key = torch.zeros_like(key)
grad_value = torch.zeros_like(value)
grad_attn_weights = torch.zeros_like(attn_weights)
BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len))
BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len))
BLOCK_SIZE_K = min(64, triton.next_power_of_2(head_dim))
num_stages = 4 if seq_len >= 512 else 2
_, num_warps = calculate_settings(seq_len)
grid_grad_attn = (
batch_size * num_heads,
triton.cdiv(seq_len, BLOCK_SIZE_M),
triton.cdiv(seq_len, BLOCK_SIZE_N),
)
_fused_neighborhood_attention_grad_attn_kernel[grid_grad_attn](
grad_output,
value,
grad_attn_weights,
grad_output.stride(0),
grad_output.stride(1),
grad_output.stride(2),
grad_output.stride(3),
value.stride(0),
value.stride(1),
value.stride(2),
value.stride(3),
grad_attn_weights.stride(0),
grad_attn_weights.stride(1),
grad_attn_weights.stride(2),
grad_attn_weights.stride(3),
batch_size,
num_heads,
seq_len,
head_dim,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
num_stages,
num_warps,
)
grad_attn_reshaped = grad_attn_weights.view(batch_size * num_heads * seq_len, seq_len)
attn_reshaped = attn_weights.view(batch_size * num_heads * seq_len, seq_len)
grad_qk_reshaped = _softmax_backward(
grad_attn_reshaped, attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch
)
grad_qk_scores = grad_qk_reshaped.view(batch_size, num_heads, seq_len, seq_len)
grid_grad_q = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
_fused_neighborhood_attention_grad_qk_kernel[grid_grad_q](
grad_qk_scores,
key,
grad_query,
grad_qk_scores.stride(0),
grad_qk_scores.stride(1),
grad_qk_scores.stride(2),
grad_qk_scores.stride(3),
key.stride(0),
key.stride(1),
key.stride(2),
key.stride(3),
grad_query.stride(0),
grad_query.stride(1),
grad_query.stride(2),
grad_query.stride(3),
batch_size,
num_heads,
seq_len,
head_dim,
scale,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
num_stages,
num_warps,
)
grid_grad_k = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
_fused_neighborhood_attention_grad_k_kernel[grid_grad_k](
grad_qk_scores,
query,
grad_key,
grad_qk_scores.stride(0),
grad_qk_scores.stride(1),
grad_qk_scores.stride(2),
grad_qk_scores.stride(3),
query.stride(0),
query.stride(1),
query.stride(2),
query.stride(3),
grad_key.stride(0),
grad_key.stride(1),
grad_key.stride(2),
grad_key.stride(3),
batch_size,
num_heads,
seq_len,
head_dim,
scale,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
num_stages,
num_warps,
)
grid_grad_v = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N))
_fused_neighborhood_attention_grad_v_kernel[grid_grad_v](
attn_weights,
grad_output,
grad_value,
attn_weights.stride(0),
attn_weights.stride(1),
attn_weights.stride(2),
attn_weights.stride(3),
grad_output.stride(0),
grad_output.stride(1),
grad_output.stride(2),
grad_output.stride(3),
grad_value.stride(0),
grad_value.stride(1),
grad_value.stride(2),
grad_value.stride(3),
batch_size,
num_heads,
seq_len,
head_dim,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
num_stages,
num_warps,
)
return grad_query, grad_key, grad_value, None, None, None
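# Illustrative usage sketch (not part of the original source): toy shapes, and
# "cuda" is an assumed Triton-capable device. Wrapped in a function so importing
# this module does not run it.
def _example_fused_neighborhood_attention():
    batch, heads, seq_len, head_dim = 2, 4, 64, 32  # toy sizes
    q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", requires_grad=True)
    k = torch.randn(batch, heads, seq_len, head_dim, device="cuda", requires_grad=True)
    v = torch.randn(batch, heads, seq_len, head_dim, device="cuda", requires_grad=True)
    out = LigerFusedNeighborhoodAttentionFunction.apply(q, k, v, 7, 1, None)  # kernel_size=7, dilation=1
    out.sum().backward()  # fills q.grad, k.grad, v.grad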
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import tanh
else:
from triton.language.math import tanh
@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
program_id = tl.program_id(0).to(tl.int64)
# locate start index
a += program_id * stride
b += program_id * stride
c += program_id * stride
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask=mask, other=0)
# tanh approximation form of GELU is computed with:
# 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
a_cubed = a_row * a_row * a_row
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
tanh_result = tanh(tanh_arg)
geglu_a = 0.5 * a_row * (1 + tanh_result)
c_row = geglu_a.cast(b_row.dtype) * b_row
tl.store(c + col_offsets, c_row, mask=mask)
@triton.jit
def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
program_id = tl.program_id(0).to(tl.int64)
# locate start index
dc += program_id * stride
a += program_id * stride
b += program_id * stride
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask=mask, other=0)
# recomputation to save memory
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
a_cubed = a_row * a_row * a_row
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
tanh_result = tanh(tanh_arg)
geglu_a = 0.5 * a_row * (1 + tanh_result)
geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
db_row = dc_row.cast(tl.float32) * geglu_a
# Gradient w.r.t. a can be computed with:
# b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
# where z = sqrt(2/pi) * (a + 0.044715 * a^3)
term1 = 0.5 * (1 + tanh_result)
tanh_sq = tanh_result * tanh_result
term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
da_row = dc_row * b_row * (term1 + term2)
tl.store(a + col_offsets, da_row, mask=mask)
tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
def geglu_forward(a, b):
ori_shape = a.shape
n_cols = ori_shape[-1]
a = a.view(-1, n_cols)
b = b.view(-1, n_cols)
c = torch.empty_like(a)
n_rows = a.shape[0]
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
_geglu_tanh_forward_kernel[(n_rows,)](
a,
b,
c,
c.stride(-2),
n_cols=n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return a, b, c.view(*ori_shape)
def geglu_backward(a, b, dc):
ori_shape = dc.shape
n_cols = ori_shape[-1]
dc = dc.view(-1, n_cols)
n_rows = dc.shape[0]
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
_geglu_tanh_backward_kernel[(n_rows,)](
dc,
a,
b,
dc.stride(-2),
n_cols=n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return a.view(*ori_shape), b.view(*ori_shape)
class LigerGELUMulFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, a, b):
a, b, c = geglu_forward(a, b)
ctx.save_for_backward(a, b)
return c
@staticmethod
@ensure_contiguous
def backward(ctx, dc):
a, b = ctx.saved_tensors
a, b = geglu_backward(a, b, dc)
return a, b
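# Illustrative reference check (not part of the original source): the fused kernel
# should match PyTorch's tanh-approximated GELU of `a` multiplied elementwise by
# `b`. "cuda" is an assumed Triton-capable device; the tolerances are guesses.
def _example_geglu():
    a = torch.randn(4, 128, device="cuda")
    b = torch.randn(4, 128, device="cuda")
    c = LigerGELUMulFunction.apply(a, b)
    expected = torch.nn.functional.gelu(a, approximate="tanh") * b
    torch.testing.assert_close(c, expected, rtol=1e-4, atol=1e-4)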
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.utils import infer_device
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
if infer_device() == "npu":
MAX_FUSED_SIZE = 16384 # 8192
else:
MAX_FUSED_SIZE = 65536
@triton.jit
def _group_norm_forward_kernel(
Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size)
Y_row_stride, # stride of each row in output
Y_col_stride, # stride of each column in output
X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size)
X_row_stride, # stride of each row in input
X_col_stride, # stride of each column in input
Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
Mean_row_stride, # stride of each row in mean
Mean_col_stride, # stride of each column in mean
RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
RSTD_row_stride, # stride of each row in rstd
RSTD_col_stride, # stride of each column in rstd
W_ptr, # pointer to W
B_ptr, # pointer to B
hidden_size, # hidden size of X
channels_per_group, # the number of channels per group
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
References:
https://nn.labml.ai/normalization/group_norm/index.html
"""
batch_idx = tl.program_id(0)
group_idx = tl.program_id(1)
X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
block_range = tl.arange(0, BLOCK_SIZE)
    # Compute mean and variance in a single pass using the sum and sum of squares
s = 0.0
squared_sum = 0.0
for i in tl.range(0, hidden_size, BLOCK_SIZE):
hidden_size_offsets = i + block_range
mask = hidden_size_offsets < hidden_size
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
s += tl.sum(X)
# X**2
squared_sum += tl.sum(X * X)
m = s / hidden_size
# variance = E[X**2] - E[X]**2
variance = (squared_sum / hidden_size) - (m * m)
# 1/std
rstd = rsqrt(variance + eps)
# Normalize — flat loop over full hidden_size (not per-channel)
# This avoids the nested channel × per_channel_hidden loop where
# BLOCK_SIZE >> hidden_size_per_channel causes massive padding waste.
hidden_size_per_channel = hidden_size // channels_per_group
for i in tl.range(0, hidden_size, BLOCK_SIZE):
hidden_size_offsets = i + block_range
mask = hidden_size_offsets < hidden_size
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
# Determine which channel each element belongs to, then load W/B
local_channel = hidden_size_offsets // hidden_size_per_channel
global_channel = group_idx * channels_per_group + local_channel
W = tl.load(W_ptr + global_channel, mask=mask)
B = tl.load(B_ptr + global_channel, mask=mask)
Y = (X - m) * rstd * W + B
tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
@triton.jit
def _group_norm_backward_kernel(
X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size)
X_row_stride, # stride of each row in input
X_col_stride, # stride of each column in input
W_ptr, # pointer to weights, shape (n_channels)
Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
    Mean_ptr_row_stride,  # stride of each row in mean
Mean_ptr_col_stride, # stride of each column in mean
RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size)
DW_ptr, # pointer to weights grad, shape (n_channels)
DB_ptr, # pointer to bias grad, shape (n_channels)
UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size)
hidden_size: tl.constexpr, # hidden size
    channels_per_group: tl.constexpr,  # number of channels per group
BLOCK_SIZE: tl.constexpr,
dtype: tl.constexpr,
):
"""
References:
https://nn.labml.ai/normalization/group_norm/index.html
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    The backprop equations are the same for group_norm and layer_norm;
    the only difference is that we load the Mean and Rstd corresponding to the
    group we're computing gradients for. The mean and rstd are computed over all
    channels in the group, so the total number of elements reduced over is
    channels_per_group * hidden_size.
    We also need to load the weight corresponding to the current channel to compute the gradients.
"""
batch_idx = tl.program_id(0)
group_idx = tl.program_id(1)
# Move the pointers to the correct batch
X_ptr += batch_idx * X_row_stride
DX_ptr += batch_idx * X_row_stride
UPSTREAM_ptr += batch_idx * X_row_stride
# Mean and rstd are the same shape so have the same strides
mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
c1 = 0.0
c2 = 0.0
block_range = tl.arange(0, BLOCK_SIZE)
# We need to compute the sum terms of the backprop equations across all channels in the group
for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
dW = 0.0
dB = 0.0
# Move the pointers to the correct channel
W = tl.load(W_ptr + channel_idx)
for i in tl.range(0, hidden_size, BLOCK_SIZE):
hidden_size_offsets = i + block_range
mask = hidden_size_offsets < hidden_size
X = tl.load(
X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
mask=mask,
other=0.0,
)
UPSTREAM_grad = tl.load(
UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
mask=mask,
other=0.0,
)
x_hat = (X - mean) * rstd
dW += tl.sum(UPSTREAM_grad * x_hat)
dB += tl.sum(UPSTREAM_grad)
wdy = W * UPSTREAM_grad
c1 += tl.sum(x_hat * wdy)
c2 += tl.sum(wdy)
# Need to ensure additions to the same channel are atomic
tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
N = hidden_size * channels_per_group
c1 = c1 / N
c2 = c2 / N
for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
# Move the pointers to the correct channel
W = tl.load(W_ptr + channel_idx)
for i in range(0, hidden_size, BLOCK_SIZE):
hidden_size_offsets = i + block_range
mask = hidden_size_offsets < hidden_size
X = tl.load(
X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
mask=mask,
other=0.0,
)
UPSTREAM_grad = tl.load(
UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
mask=mask,
other=0.0,
)
x_hat = (X - mean) * rstd
wdy = W * UPSTREAM_grad
dx = (wdy - (x_hat * c1 + c2)) * rstd
tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
def group_norm_forward(X, num_channels, num_groups, W, B, eps):
shape = X.shape
batch_size = shape[0]
channels_per_group = num_channels // num_groups
# Reshape X so that the mean and std are computed across the groups
X = X.view(batch_size, num_groups, -1).contiguous()
hidden_size = X.shape[-1]
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
_group_norm_forward_kernel[(batch_size, num_groups)](
Y,
Y.stride(0),
Y.stride(1),
X,
X.stride(0),
X.stride(1),
Mean,
Mean.stride(0),
Mean.stride(1),
RSTD,
RSTD.stride(0),
RSTD.stride(1),
W,
B,
hidden_size,
channels_per_group,
eps,
BLOCK_SIZE=BLOCK_SIZE,
)
# Return tensors in the original shape
return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
shape = dY.shape
batch_size = shape[0]
hidden_size = dY.shape[-1]
channels_per_group = num_channels // num_groups
dY = dY.view(batch_size, num_groups, -1)
DX = torch.empty(
(batch_size, num_groups, hidden_size * channels_per_group),
dtype=X.dtype,
device=X.device,
)
DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
_group_norm_backward_kernel[(batch_size, num_groups)](
X,
X.stride(0),
X.stride(1),
W,
Mean,
Mean.stride(0),
Mean.stride(1),
RSTD,
DX,
DW,
DB,
dY,
hidden_size,
channels_per_group,
BLOCK_SIZE=BLOCK_SIZE,
dtype=triton_dtype,
)
# Return tensors in the original shape
return DX.view(*shape), DW, DB
class LigerGroupNormFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(
ctx,
X,
affine_scaling_weight,
affine_shifting_bias,
num_channels,
num_groups,
eps,
):
Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
X,
num_channels,
num_groups,
affine_scaling_weight,
affine_shifting_bias,
eps,
)
ctx.num_channels = num_channels
ctx.num_groups = num_groups
ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, W, B, Mean, RSTD = ctx.saved_tensors
DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
return DX, DW, DB, None, None, None
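# Illustrative reference check (not part of the original source):
# LigerGroupNormFunction should agree with torch.nn.functional.group_norm.
# "cuda" is an assumed Triton-capable device; the tolerances are guesses.
def _example_group_norm():
    batch, channels, hidden = 2, 8, 64
    num_groups, eps = 4, 1e-6
    x = torch.randn(batch, channels, hidden, device="cuda")
    w = torch.randn(channels, device="cuda")
    b = torch.randn(channels, device="cuda")
    y = LigerGroupNormFunction.apply(x, w, b, channels, num_groups, eps)
    expected = torch.nn.functional.group_norm(x, num_groups, w, b, eps)
    torch.testing.assert_close(y, expected, rtol=1e-4, atol=1e-4)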
import torch
import triton
import triton.language as tl
# Loss type constants for Triton constexpr branching
# GRPO/DAPO/BNPO/DR_GRPO/LUSPO all use the same per-token loss computation (standard PPO clipping)
_LOSS_TYPE_GRPO: tl.constexpr = tl.constexpr(0)
_LOSS_TYPE_CISPO: tl.constexpr = tl.constexpr(1)
_LOSS_TYPE_SAPO: tl.constexpr = tl.constexpr(2)
_str_to_loss_type = {
"grpo": _LOSS_TYPE_GRPO.value,
"dapo": _LOSS_TYPE_GRPO.value,
"bnpo": _LOSS_TYPE_GRPO.value,
"dr_grpo": _LOSS_TYPE_GRPO.value,
"luspo": _LOSS_TYPE_GRPO.value,
"cispo": _LOSS_TYPE_CISPO.value,
"sapo": _LOSS_TYPE_SAPO.value,
}
@triton.jit
def _selective_log_softmax_kernel(
LOGITS,
INPUT_IDS,
LOG_P,
MASK,
TEMPERATURE,
stride_input_ids_b,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 4096,
):
off_b = tl.program_id(0).cast(tl.int64)
off_l = tl.program_id(1).cast(tl.int64)
LOGITS += off_b * (L + 1) * N + off_l * N
INPUT_IDS += off_b * stride_input_ids_b + off_l
LOG_P += off_b * L + off_l
if MASK is not None:
MASK += off_b * stride_input_ids_b + off_l
not_skip = tl.load(MASK)
if not_skip == 0:
return
m_i = float("-inf")
l_i = 0.0
for start in range(0, N, BLOCK_N):
cols = start + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
new_m_i = tl.maximum(m_i, tl.max(logits))
alpha = tl.exp(m_i - new_m_i)
l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
m_i = new_m_i
lse = m_i + tl.log(l_i)
ids = tl.load(INPUT_IDS)
x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
logp = x - lse
tl.store(LOG_P, logp)
# Compute old_logp and ref_logp without autograd; this saves roughly 10 GB of peak memory.
@torch.no_grad
def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
assert logits.is_contiguous()
B, L_ADD_1, N = logits.shape
L = L_ADD_1 - 1
input_ids = input_ids[:, -L:]
if mask is not None:
mask = mask[:, -L:]
log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
_selective_log_softmax_kernel[(B, L)](
logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
)
return log_p
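# Illustrative reference (not part of the original source): the same selective
# log-softmax in plain PyTorch. It materializes the full log-probability tensor,
# which is exactly the peak memory the fused kernel avoids; the optional mask is
# ignored in this sketch.
def _reference_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9):
    L = logits.shape[1] - 1
    log_p = torch.log_softmax(logits[:, :-1].float() / temperature, dim=-1)  # drop the last position
    return log_p.gather(-1, input_ids[:, -L:].unsqueeze(-1)).squeeze(-1)  # (B, L)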
# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
# for BLOCK_N in [2048, 4096, 8192]
# for ns in [1, 2, 4]
# for nw in [1, 2, 4, 8, 16]],
# key=['N'])
@triton.jit
def _grpo_loss_fwd_kernel(
LOGITS,
OLD_LOGP,
REF_LOGP,
INPUT_IDS,
COMPLETION_MASK,
ADVANTAGES,
VLLM_IS_RATIO,
VLLM_IS_RATIO_STRIDE,
LOSS,
LSE,
KL,
IS_CLIPPED,
TEMPERATURE,
BETA: tl.constexpr,
EPS_LOW,
EPS_HIGH,
LOSS_TYPE: tl.constexpr,
SAPO_TEMP_POS,
SAPO_TEMP_NEG,
DELTA,
USE_BIAS_CORRECTION_KL: tl.constexpr,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 4096,
):
off_b = tl.program_id(0).cast(tl.int64)
off_l = tl.program_id(1).cast(tl.int64)
if COMPLETION_MASK is not None:
COMPLETION_MASK += off_b * L + off_l
not_skip = tl.load(COMPLETION_MASK)
if not_skip == 0:
return
LOGITS += off_b * (L + 1) * N + off_l * N
INPUT_IDS += off_b * L + off_l
ADVANTAGES += off_b
LOSS += off_b * L + off_l
LSE += off_b * L + off_l
IS_CLIPPED += off_b * L + off_l
m_i = float("-inf")
l_i = 0.0
for start in range(0, N, BLOCK_N):
cols = start + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
new_m_i = tl.maximum(m_i, tl.max(logits))
alpha = tl.exp(m_i - new_m_i)
l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
m_i = new_m_i
lse = m_i + tl.log(l_i)
idx = tl.load(INPUT_IDS)
x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
logp = x - lse
if OLD_LOGP is None:
old_logp = logp
else:
OLD_LOGP += off_b * L + off_l
old_logp = tl.load(OLD_LOGP).to(tl.float32)
coef_1 = tl.exp(logp - old_logp)
advantage = tl.load(ADVANTAGES).to(tl.float32)
# Branch based on loss type
if LOSS_TYPE == 0: # GRPO/DAPO/BNPO/DR_GRPO: standard PPO clipping
coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
is_clipped = is_low_clipped | is_high_clipped
# Apply delta (two-sided clipping from INTELLECT-2) to coef_1
if DELTA != 0.0:
coef_1 = tl.minimum(coef_1, DELTA)
per_token_loss1 = coef_1 * advantage
per_token_loss2 = coef_2 * advantage
per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
elif LOSS_TYPE == 1: # CISPO: upper-bound only clipping, detached, multiply by logp
# Reference: MiniMax-M1 technical report
# https://github.com/huggingface/trl/blob/035c3ff151b953ca72cdfe0ee966bc1469a26fde/trl/trainer/grpo_trainer.py#L2030
coef_2 = tl.minimum(coef_1, EPS_HIGH) # upper-bound only (EPS_HIGH is the raw bound for CISPO)
per_token_loss = -coef_2 * advantage * logp # includes logp term
is_clipped = (coef_1 > EPS_HIGH) & (advantage > 0)
elif LOSS_TYPE == 2: # SAPO: soft adaptive policy optimization with sigmoid gating
# Reference: https://huggingface.co/papers/2511.20347
# Formula: sigmoid(τ * (ρ - 1)) * 4 / τ
temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG)
sigmoid_input = temperature * (coef_1 - 1.0)
sapo_coef = tl.sigmoid(sigmoid_input) * 4.0 / temperature
per_token_loss = -sapo_coef * advantage
is_clipped = 0.0 # SAPO has no clipping concept
# Apply vLLM importance sampling correction BEFORE adding KL penalty
if VLLM_IS_RATIO is not None:
# Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes
vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
tl.float32
)
per_token_loss = per_token_loss * vllm_is_ratio
if BETA != 0.0:
REF_LOGP += off_b * L + off_l
KL += off_b * L + off_l
ref_logp = tl.load(REF_LOGP).to(tl.float32)
kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
if USE_BIAS_CORRECTION_KL:
# Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= coef_1
kl = kl * tl.exp(logp - old_logp)
per_token_loss += BETA * kl
tl.store(KL, kl)
tl.store(LOSS, per_token_loss)
tl.store(LSE, lse)
tl.store(IS_CLIPPED, is_clipped)
# Sequence-level forward kernel: uses pre-computed coef_1 per sequence
@triton.jit
def _grpo_loss_fwd_kernel_seq(
LOGITS,
OLD_LOGP,
REF_LOGP,
INPUT_IDS,
COMPLETION_MASK,
ADVANTAGES,
COEF_1, # Pre-computed sequence-level importance weight (B,)
COEF_2, # Pre-computed clipped coef (B,)
IS_CLIPPED_SEQ, # Pre-computed clipping indicator (B,)
VLLM_IS_RATIO, # vLLM importance sampling ratio (B, L) or (B, 1) or None
VLLM_IS_RATIO_STRIDE, # stride for VLLM_IS_RATIO (L for per-token, 1 for per-sequence)
LOSS,
LSE,
KL,
IS_CLIPPED,
TEMPERATURE,
BETA: tl.constexpr,
USE_BIAS_CORRECTION_KL: tl.constexpr,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 4096,
):
off_b = tl.program_id(0).cast(tl.int64)
off_l = tl.program_id(1).cast(tl.int64)
if COMPLETION_MASK is not None:
COMPLETION_MASK += off_b * L + off_l
not_skip = tl.load(COMPLETION_MASK)
if not_skip == 0:
return
LOGITS += off_b * (L + 1) * N + off_l * N
INPUT_IDS += off_b * L + off_l
ADVANTAGES += off_b
COEF_1 += off_b
COEF_2 += off_b
IS_CLIPPED_SEQ += off_b
LOSS += off_b * L + off_l
LSE += off_b * L + off_l
IS_CLIPPED += off_b * L + off_l
# Compute log softmax
m_i = float("-inf")
l_i = 0.0
for start in range(0, N, BLOCK_N):
cols = start + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
new_m_i = tl.maximum(m_i, tl.max(logits))
alpha = tl.exp(m_i - new_m_i)
l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
m_i = new_m_i
lse = m_i + tl.log(l_i)
idx = tl.load(INPUT_IDS)
x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
logp = x - lse
# Load pre-computed sequence-level coefficients
coef_1 = tl.load(COEF_1).to(tl.float32)
coef_2 = tl.load(COEF_2).to(tl.float32)
is_clipped_seq = tl.load(IS_CLIPPED_SEQ)
advantage = tl.load(ADVANTAGES).to(tl.float32)
per_token_loss1 = coef_1 * advantage
per_token_loss2 = coef_2 * advantage
per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
# Apply vLLM importance sampling correction BEFORE adding KL
if VLLM_IS_RATIO is not None:
vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
tl.float32
)
per_token_loss = per_token_loss * vllm_is_ratio
if BETA != 0.0:
REF_LOGP += off_b * L + off_l
KL += off_b * L + off_l
ref_logp = tl.load(REF_LOGP).to(tl.float32)
kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
if USE_BIAS_CORRECTION_KL:
# Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1
if OLD_LOGP is None:
old_logp = logp
else:
old_logp = tl.load(OLD_LOGP + off_b * L + off_l).to(tl.float32)
kl = kl * tl.exp(logp - old_logp)
per_token_loss += BETA * kl
tl.store(KL, kl)
tl.store(LOSS, per_token_loss)
tl.store(LSE, lse)
tl.store(IS_CLIPPED, is_clipped_seq) # Same for all tokens in sequence
# Sequence-level backward kernel
@triton.jit
def _grpo_loss_bwd_kernel_seq(
DLOSS,
DLOSS_SUM,
DLOGITS,
LOGITS,
OLD_LOGP,
REF_LOGP,
INPUT_IDS,
ADVANTAGES,
COMPLETION_MASK,
LSE,
COEF_1, # Pre-computed sequence-level importance weight (B,)
SEQ_LEN, # Number of valid tokens per sequence (B,)
TEMPERATURE,
BETA: tl.constexpr,
USE_BIAS_CORRECTION_KL: tl.constexpr,
EPS_LOW,
EPS_HIGH,
DELTA,
loss_stride0,
loss_stride1,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 4096,
):
off_b = tl.program_id(0).cast(tl.int64)
off_l = tl.program_id(1).cast(tl.int64)
DLOGITS += off_b * (L + 1) * N + off_l * N
if COMPLETION_MASK is not None:
COMPLETION_MASK += off_b * L + off_l
not_skip = tl.load(COMPLETION_MASK)
if not_skip == 0:
for start in range(0, N, BLOCK_N):
cols = tl.arange(0, BLOCK_N) + start
tl.store(DLOGITS + cols, 0.0, mask=cols < N)
return
LOGITS += off_b * (L + 1) * N + off_l * N
DLOSS += off_b * loss_stride0 + off_l * loss_stride1
DLOSS_SUM += off_b
INPUT_IDS += off_b * L + off_l
ADVANTAGES += off_b
LSE += off_b * L + off_l
COEF_1 += off_b
SEQ_LEN += off_b
dloss = tl.load(DLOSS).to(tl.float32)
dloss_sum = tl.load(DLOSS_SUM).to(tl.float32)
lse = tl.load(LSE).to(tl.float32)
coef_1 = tl.load(COEF_1).to(tl.float32)
seq_len = tl.load(SEQ_LEN).to(tl.float32)
idx = tl.load(INPUT_IDS)
x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
logp = x - lse
advantage = tl.load(ADVANTAGES).to(tl.float32)
coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
if DELTA != 0.0:
coef_1_for_loss = tl.minimum(coef_1, DELTA)
else:
coef_1_for_loss = coef_1
per_token_loss1 = coef_1_for_loss * advantage
per_token_loss2 = coef_2 * advantage
is_unclipped = per_token_loss2 >= per_token_loss1
# For sequence-level: gradient flows through mean, so scale by coef_1/seq_len
# d(loss)/d(logp) = -advantage * coef_1 / seq_len (when unclipped and not delta-clamped)
dlogp = -coef_1 * advantage / seq_len * is_unclipped * dloss_sum
if DELTA != 0.0:
dlogp = dlogp * (coef_1 <= DELTA)
if BETA != 0.0:
REF_LOGP += off_b * L + off_l
ref_logp = tl.load(REF_LOGP).to(tl.float32)
if USE_BIAS_CORRECTION_KL:
# d(kl * coef_1)/d(logp) = coef_1 * (logp - ref_logp), where coef_1 = exp(logp - old_logp)
if OLD_LOGP is None:
old_logp = logp
else:
old_logp = tl.load(OLD_LOGP + off_b * L + off_l).to(tl.float32)
token_coef_1 = tl.exp(logp - old_logp)
dlogp += BETA * token_coef_1 * (logp - ref_logp) * dloss
else:
dlogp += BETA * (1 - tl.exp(ref_logp - logp)) * dloss
dlogp = dlogp / TEMPERATURE
tl.debug_barrier()
for start_n in tl.range(0, N, BLOCK_N):
cols = start_n + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
probs = tl.exp(logits - lse)
dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
tl.store(DLOGITS + cols, dlogits, mask=cols < N)
@triton.jit
def _grpo_loss_bwd_kernel(
DLOSS,
DLOGITS,
LOGITS,
OLD_LOGP,
REF_LOGP,
INPUT_IDS,
ADVANTAGES,
COMPLETION_MASK,
LSE,
VLLM_IS_RATIO,
VLLM_IS_RATIO_STRIDE,
TEMPERATURE,
BETA: tl.constexpr,
EPS_LOW,
EPS_HIGH,
LOSS_TYPE: tl.constexpr,
SAPO_TEMP_POS,
SAPO_TEMP_NEG,
DELTA,
USE_BIAS_CORRECTION_KL: tl.constexpr,
loss_stride0,
loss_stride1,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 4096,
):
off_b = tl.program_id(0).cast(tl.int64)
off_l = tl.program_id(1).cast(tl.int64)
DLOGITS += off_b * (L + 1) * N + off_l * N
if COMPLETION_MASK is not None:
COMPLETION_MASK += off_b * L + off_l
not_skip = tl.load(COMPLETION_MASK)
if not_skip == 0:
for start in range(0, N, BLOCK_N):
cols = tl.arange(0, BLOCK_N) + start
tl.store(DLOGITS + cols, 0.0, mask=cols < N)
return
LOGITS += off_b * (L + 1) * N + off_l * N
DLOSS += off_b * loss_stride0 + off_l * loss_stride1
INPUT_IDS += off_b * L + off_l
ADVANTAGES += off_b
LSE += off_b * L + off_l
dloss = tl.load(DLOSS).to(tl.float32)
lse = tl.load(LSE).to(tl.float32)
idx = tl.load(INPUT_IDS)
x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
logp = x - lse
if OLD_LOGP is None:
old_logp = logp
else:
OLD_LOGP += off_b * L + off_l
old_logp = tl.load(OLD_LOGP).to(tl.float32)
coef_1 = tl.exp(logp - old_logp)
advantage = tl.load(ADVANTAGES).to(tl.float32)
# Branch based on loss type for gradient computation
if LOSS_TYPE == 0: # GRPO/DAPO/BNPO/DR_GRPO: standard PPO clipping
coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
if DELTA != 0.0:
coef_1_for_loss = tl.minimum(coef_1, DELTA)
else:
coef_1_for_loss = coef_1
per_token_loss1 = coef_1_for_loss * advantage
per_token_loss2 = coef_2 * advantage
mask = per_token_loss2 >= per_token_loss1
# Gradient uses original coef_1; zero when delta-clamped (constant → no gradient)
dlogp = -coef_1 * advantage * mask
if DELTA != 0.0:
dlogp = dlogp * (coef_1 <= DELTA)
elif LOSS_TYPE == 1: # CISPO: coef_2 is DETACHED, so gradient only flows through logp
# loss = -coef_2 * advantage * logp, where coef_2 = clamp(coef_1, max=eps_high).detach()
# d(loss)/d(logp) = -coef_2 * advantage (coef_2 treated as constant due to detach)
coef_2 = tl.minimum(coef_1, EPS_HIGH)
dlogp = -coef_2 * advantage
elif LOSS_TYPE == 2: # SAPO: gradient through sigmoid gating
# loss = -sapo_coef * advantage, where sapo_coef = sigmoid(τ*(ρ-1)) * 4/τ
# d(loss)/d(logp) = -advantage * d(sapo_coef)/d(coef_1) * d(coef_1)/d(logp)
# d(coef_1)/d(logp) = coef_1 (since coef_1 = exp(logp - old_logp))
# d(sapo_coef)/d(coef_1) = d/d(coef_1)[sigmoid(τ*(coef_1-1)) * 4/τ]
# = τ * sigmoid' * 4/τ = 4 * sigmoid * (1 - sigmoid)
# (the τ factors cancel out in the derivative)
temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG)
sigmoid_input = temperature * (coef_1 - 1.0)
sigmoid_val = tl.sigmoid(sigmoid_input)
d_sapo_d_coef1 = 4.0 * sigmoid_val * (1.0 - sigmoid_val)
dlogp = -advantage * d_sapo_d_coef1 * coef_1
# Apply vLLM IS ratio to PPO gradient (before KL gradient)
if VLLM_IS_RATIO is not None:
# Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes
vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
tl.float32
)
dlogp = dlogp * vllm_is_ratio
if BETA != 0.0:
REF_LOGP += off_b * L + off_l
ref_logp = tl.load(REF_LOGP).to(tl.float32)
if USE_BIAS_CORRECTION_KL:
# d(kl * coef_1)/d(logp) = coef_1 * (logp - ref_logp), where coef_1 = exp(logp - old_logp)
dlogp += BETA * coef_1 * (logp - ref_logp)
else:
dlogp += BETA * (1 - tl.exp(ref_logp - logp))
dlogp = dlogp * dloss / TEMPERATURE
tl.debug_barrier()
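    # d(logp)/d(logit_j) = 1{j == idx} - softmax_j; chain rule with dlogp,
    # with the 1/TEMPERATURE factor already folded into dlogp above.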
for start_n in tl.range(0, N, BLOCK_N):
cols = start_n + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
probs = tl.exp(logits - lse)
dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
tl.store(DLOGITS + cols, dlogits, mask=cols < N)
def _compute_dapo_normalizer(completion_mask):
"""Global active tokens averaged per process (for distributed DAPO loss)."""
normalizer = completion_mask.to(torch.float32).sum()
world_size = 1
if torch.distributed.is_available() and torch.distributed.is_initialized():
normalizer = normalizer.clone()
torch.distributed.all_reduce(normalizer, op=torch.distributed.ReduceOp.SUM)
world_size = torch.distributed.get_world_size()
normalizer = normalizer / world_size
return torch.clamp(normalizer, min=1.0)
def _reduce_loss(per_token_loss, mask, loss_type, max_completion_length, B, L):
"""Apply loss reduction based on loss_type."""
if loss_type == "grpo" or loss_type == "sapo":
return ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
elif loss_type == "bnpo":
return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
elif loss_type == "dr_grpo":
max_len = max_completion_length if max_completion_length is not None else L
return (per_token_loss * mask).sum() / (B * max_len)
elif loss_type == "dapo" or loss_type == "cispo":
return (per_token_loss * mask).sum() / _compute_dapo_normalizer(mask)
elif loss_type == "luspo":
return (per_token_loss * mask.sum(-1, keepdim=True)).mean()
raise ValueError(f"Unknown loss_type: {loss_type}. Expected one of: grpo, bnpo, dr_grpo, dapo, cispo, sapo, luspo")
class GrpoLossFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
temperature,
beta,
eps_low,
eps_high,
inplace,
loss_type="grpo",
max_completion_length=None,
reduce=True,
importance_sampling_level="token",
sapo_temperature_pos=1.0,
sapo_temperature_neg=1.05,
vllm_is_ratio=None,
delta=None,
use_bias_correction_kl=False,
):
assert logits.is_contiguous() and completion_ids.is_contiguous()
assert old_logp is None or old_logp.is_contiguous()
assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
assert importance_sampling_level in ("token", "sequence"), (
f"importance_sampling_level must be 'token' or 'sequence', got {importance_sampling_level}"
)
# Validate loss_type
if loss_type not in _str_to_loss_type:
raise ValueError(f"Unknown loss_type '{loss_type}'. Supported types: {list(_str_to_loss_type.keys())}")
# Validate delta + loss_type combinations
if delta is not None and loss_type in ("cispo", "sapo"):
raise ValueError(f"delta (two-sided clipping) is not supported for loss_type='{loss_type}'.")
# Map delta to float for Triton (Triton can't handle None)
delta_val = 0.0 if delta is None else float(delta)
# Validate sequence-level + loss_type combinations
if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"):
raise ValueError(
f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. "
f"Use importance_sampling_level='token' instead."
)
# Validate SAPO temperatures to prevent division by zero or numerical instability
if loss_type == "sapo":
if sapo_temperature_pos <= 0:
raise ValueError(f"sapo_temperature_pos must be positive, got {sapo_temperature_pos}")
if sapo_temperature_neg <= 0:
raise ValueError(f"sapo_temperature_neg must be positive, got {sapo_temperature_neg}")
# Convert loss_type string to integer for Triton constexpr
loss_type_int = _str_to_loss_type[loss_type]
B, L_ADD_1, N = logits.shape
L = L_ADD_1 - 1
if completion_mask is not None:
assert completion_mask.is_contiguous()
mask = completion_mask.float() if completion_mask is not None else torch.ones(B, L, device=logits.device)
# Handle vLLM IS ratio
vllm_is_ratio_ptr = None
vllm_is_ratio_stride = L # default to per-token (unused when ptr is None)
if vllm_is_ratio is not None:
assert vllm_is_ratio.dim() in (1, 2), (
f"vllm_is_ratio must be 1D (B,) or 2D (B, L) / (B, 1), got {vllm_is_ratio.dim()}D"
)
if vllm_is_ratio.dim() == 2:
assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, L), (
f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {L}), got {tuple(vllm_is_ratio.shape)}"
)
else:
assert vllm_is_ratio.shape[0] == B, (
f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}"
)
vllm_is_ratio = vllm_is_ratio.contiguous()
vllm_is_ratio_ptr = vllm_is_ratio
vllm_is_ratio_stride = vllm_is_ratio.shape[1] if vllm_is_ratio.dim() > 1 else 1
# Allocate outputs
loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
lse = torch.zeros_like(loss)
is_clipped = torch.zeros_like(loss)
kl = torch.zeros_like(loss) if beta != 0.0 else None
if importance_sampling_level == "sequence":
# Sequence-level: pre-compute sequence importance weights, then use Triton kernel
# Step 1: Get per-token log probs using existing Triton kernel
per_token_logps = fused_selective_log_softmax(logits, completion_ids, temperature, completion_mask)
# Step 2: Compute sequence-level importance weights
if old_logp is None:
log_ratio = torch.zeros_like(per_token_logps)
else:
log_ratio = per_token_logps - old_logp
seq_lens = mask.sum(-1).clamp(min=1.0) # (B,)
seq_log_importance = (log_ratio * mask).sum(-1) / seq_lens # (B,)
coef_1 = torch.exp(seq_log_importance) # (B,)
coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) # (B,)
# Compute is_clipped at sequence level (using original coef_1)
is_clipped_seq = ((coef_1 < 1 - eps_low) & (advantages < 0)) | ((coef_1 > 1 + eps_high) & (advantages > 0))
is_clipped_seq = is_clipped_seq.float() # (B,)
# Apply delta clamp for loss computation (keep original coef_1 for backward)
if delta is not None:
coef_1_for_loss = torch.clamp(coef_1, max=delta)
else:
coef_1_for_loss = coef_1
# Step 3: Run Triton kernel with pre-computed coefficients
kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
_grpo_loss_fwd_kernel_seq[(B, L)](
logits,
old_logp,
ref_logp,
completion_ids,
completion_mask,
advantages,
coef_1_for_loss.contiguous(),
coef_2.contiguous(),
is_clipped_seq.contiguous(),
vllm_is_ratio_ptr,
vllm_is_ratio_stride,
loss,
lse,
kl,
is_clipped,
temperature,
beta,
use_bias_correction_kl,
L,
N,
**kwargs,
)
# Save extra tensors for backward
ctx.save_for_backward(
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
lse,
mask,
coef_1,
seq_lens,
vllm_is_ratio_ptr,
)
else:
# Token-level: use optimized Triton kernel with LOSS_TYPE branching
kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
_grpo_loss_fwd_kernel[(B, L)](
logits,
old_logp,
ref_logp,
completion_ids,
completion_mask,
advantages,
vllm_is_ratio_ptr,
vllm_is_ratio_stride,
loss,
lse,
kl,
is_clipped,
temperature,
beta,
eps_low,
eps_high,
loss_type_int,
sapo_temperature_pos,
sapo_temperature_neg,
delta_val,
use_bias_correction_kl,
L,
N,
**kwargs,
)
ctx.save_for_backward(
logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio_ptr
)
ctx.infos = (
temperature,
beta,
eps_low,
eps_high,
inplace,
loss_type,
loss_type_int,
sapo_temperature_pos,
sapo_temperature_neg,
max_completion_length,
B,
L,
importance_sampling_level,
vllm_is_ratio_stride,
reduce,
delta_val,
use_bias_correction_kl,
)
# Compute metrics before reduction
mask_sum = mask.sum().clamp(min=1.0)
kl_mean = (kl * mask).sum() / mask_sum if kl is not None else None
clip_ratio = (is_clipped.float() * mask).sum() / mask_sum
if not reduce:
loss_out = loss * mask
kl_out = kl * mask if kl is not None else None
is_clipped_out = is_clipped * mask
return loss_out, kl_out, is_clipped_out
reduced_loss = _reduce_loss(loss, mask, loss_type, max_completion_length, B, L)
return reduced_loss, kl_mean, clip_ratio
@staticmethod
def backward(ctx, *args):
dloss_input = args[0]
saved_tensors = ctx.saved_tensors
(
temperature,
beta,
eps_low,
eps_high,
inplace,
loss_type,
loss_type_int,
sapo_temperature_pos,
sapo_temperature_neg,
max_completion_length,
B,
L,
importance_sampling_level,
vllm_is_ratio_stride,
reduce,
delta_val,
use_bias_correction_kl,
) = ctx.infos
if importance_sampling_level == "sequence":
(
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
lse,
mask,
coef_1,
seq_lens,
vllm_is_ratio,
) = saved_tensors
else:
(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio) = (
saved_tensors
)
_, L_ADD_1, N = logits.shape
# Compute per-token gradient scaling based on loss_type
if not reduce:
dloss = dloss_input
elif loss_type == "grpo" or loss_type == "sapo":
seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0)
dloss = dloss_input * mask / (seq_lens_bwd * B)
elif loss_type == "bnpo":
dloss = dloss_input * mask / mask.sum().clamp(min=1.0)
elif loss_type == "dr_grpo":
max_len = max_completion_length if max_completion_length is not None else L
dloss = dloss_input * mask / (B * max_len)
elif loss_type == "dapo" or loss_type == "cispo":
dloss = dloss_input * mask / _compute_dapo_normalizer(mask)
elif loss_type == "luspo":
# loss = mean(per_token_loss * seq_lens), mean divides by B*L
seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0)
dloss = dloss_input * seq_lens_bwd / (B * L)
else:
raise ValueError(f"Unknown loss_type: {loss_type}")
dlogits = logits.data if inplace else torch.empty_like(logits)
kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
if importance_sampling_level == "sequence":
if vllm_is_ratio is None:
dloss_sum = dloss.sum(-1).contiguous()
else:
if vllm_is_ratio.dim() == 1:
ratio = vllm_is_ratio.unsqueeze(-1)
else:
ratio = vllm_is_ratio
dloss_sum = (dloss * ratio).sum(-1).contiguous()
# Sequence-level backward kernel
_grpo_loss_bwd_kernel_seq[(B, L)](
dloss,
dloss_sum,
dlogits,
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
lse,
coef_1,
seq_lens,
temperature,
beta,
use_bias_correction_kl,
eps_low,
eps_high,
delta_val,
*dloss.stride(),
L,
N,
**kwargs,
)
else:
# Token-level backward kernel with LOSS_TYPE branching
_grpo_loss_bwd_kernel[(B, L)](
dloss,
dlogits,
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
lse,
vllm_is_ratio,
vllm_is_ratio_stride,
temperature,
beta,
eps_low,
eps_high,
loss_type_int,
sapo_temperature_pos,
sapo_temperature_neg,
delta_val,
use_bias_correction_kl,
*dloss.stride(),
L,
N,
**kwargs,
)
dlogits[:, -1, :] = 0
# Return gradients for all forward inputs: dlogits + 19 None for non-differentiable params
return (
dlogits,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
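# Minimal usage sketch of GrpoLossFunction (editor's addition). Shapes and
# hyper-parameters are illustrative assumptions, and a Triton-capable device
# (e.g. CUDA) is required; this is not the repo's test harness.
def _grpo_loss_example(device="cuda"):
    B, L, V = 2, 8, 128
    logits = torch.randn(B, L + 1, V, device=device, requires_grad=True)
    completion_ids = torch.randint(0, V, (B, L), device=device)
    advantages = torch.randn(B, device=device)
    completion_mask = torch.ones(B, L, dtype=torch.int32, device=device)
    loss, kl, clip_ratio = GrpoLossFunction.apply(
        logits,  # (B, L+1, V) policy logits
        None,  # old_logp: None -> on-policy, ratio coef_1 == 1
        None,  # ref_logp: unused because beta == 0 below
        completion_ids,  # (B, L) sampled token ids
        advantages,  # (B,) per-sequence advantages
        completion_mask,  # (B, L) 1 where the completion token is valid
        1.0,  # temperature
        0.0,  # beta: no KL penalty
        0.2,  # eps_low
        0.2,  # eps_high
        False,  # inplace
    )
    loss.backward()
    return loss, kl, clip_ratio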
from typing import Optional
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.utils import infer_device
@triton.jit
def _jsd_kernel(
X_ptr, # input in logspace, X = log Q
X_stride,
Y_ptr, # ground truth in logspace, Y = log P
Y_stride,
loss_ptr,
loss_stride,
dX_ptr,
dX_stride,
label_ptr,
beta: tl.constexpr,
n_non_ignore: int,
ignore_index: tl.constexpr,
n_cols,
BLOCK_SIZE: tl.constexpr,
HAS_LABEL: tl.constexpr,
):
    # Generalized JSD with mixing coefficient beta, where X = log Q and Y = log P:
    #   M = beta * P + (1 - beta) * Q
    #   loss = beta * P * Y + (1 - beta) * Q * X - M * log M
    #   dX = (1 - beta) * Q * (X - log M)
    # beta = 0.5 recovers the classic JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2 with M = (P + Q) / 2.
pid = tl.program_id(0).to(tl.int64)
X_ptr += pid * X_stride
dX_ptr += pid * dX_stride
Y_ptr += pid * Y_stride
loss_ptr += pid * loss_stride
label_ptr += pid
if HAS_LABEL:
label = tl.load(label_ptr)
if label == ignore_index:
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
return
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_cols
X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
if beta == 0.0: # forward KL
Y_max = tl.max(Y, axis=0)
Y_shifted = Y - Y_max
Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift
loss = Y_prob * (Y - X)
dX = -Y_prob
elif beta == 1.0: # reverse KL
X_max = tl.max(X, axis=0)
X_shifted = X - X_max
X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift
loss = X_prob * (X - Y)
dX = loss + X_prob
else:
max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
X_shifted = X - max_val
Y_shifted = Y - max_val
# Pre-compute exp(max_val) since it's used twice
exp_max = tl.exp(max_val)
# Compute exp terms with compensation
Q = tl.exp(X_shifted) * exp_max # = exp(X)
P = tl.exp(Y_shifted) * exp_max # = exp(Y)
# Pre-compute common terms
beta_P = beta * P
one_minus_beta_Q = (1 - beta) * Q
M = beta_P + one_minus_beta_Q
log_M = tl.log(M) # No need to compensate as M is already in original scale
loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
dX = one_minus_beta_Q * (X - log_M)
# Pre-compute scaling factor
scale = 1.0 / n_non_ignore
loss = loss * scale
dX = dX * scale
tl.store(loss_ptr + offsets, loss, mask=mask)
tl.store(dX_ptr + offsets, dX, mask=mask)
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # unreduced (per-element) loss buffer
loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
dX = torch.empty_like(_input)
if has_label:
n_non_ignore = (shift_labels != ignore_index).sum().item()
else:
n_non_ignore = BT
_jsd_kernel[(n_rows,)](
X_ptr=_input, # input in logspace, X = log Q
X_stride=_input.stride(-2),
Y_ptr=target, # ground truth in logspace, Y = log P
Y_stride=target.stride(-2),
loss_ptr=loss,
loss_stride=loss.stride(-2),
dX_ptr=dX,
dX_stride=dX.stride(-2),
label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label
beta=beta,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
n_cols=V,
BLOCK_SIZE=BLOCK_SIZE,
HAS_LABEL=has_label,
)
loss = torch.sum(loss)
return loss.to(_input.dtype), dX
def jsd_backward(dX, grad_output):
# If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return dX
else:
return grad_output * dX
class LigerJSDFunction(torch.autograd.Function):
r"""
This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
.. math::
JSD(\beta)(P || Q)
= \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
    .. note::
        As with all other losses in PyTorch, this function expects the first argument,
        :attr:`_input`, to be the predictions (the output of the student model) in log-space,
        and the second, :attr:`target`, to be the observations (the output of the teacher model) in log-space.
        This differs from the standard mathematical notation :math:`JSD(P || Q)`, where
        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
"""
@staticmethod
@ensure_contiguous
def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
shift_labels: Optional[torch.Tensor] = None,
beta: float = 0.5,
ignore_index: int = -100,
) -> torch.Tensor:
"""
Args:
_input (torch.Tensor): predict values with shape (BT, V) in logspace
target (torch.Tensor): ground truth values with shape (BT, V) in logspace
shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
ignore_index (int): the index to ignore. Default: -100
Returns:
loss (torch.Tensor): generalized JSD
"""
has_label = False
if shift_labels is not None:
assert shift_labels.shape == (_input.shape[0],), (
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
)
shift_labels = shift_labels.contiguous()
has_label = True
loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
ctx.save_for_backward(dX)
return loss
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
(dX,) = ctx.saved_tensors
dX = jsd_backward(dX, grad_output)
return (
dX,
None,
None,
None,
None,
)
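# Minimal usage sketch of LigerJSDFunction (editor's addition; shapes are
# illustrative assumptions and a Triton-capable device is required). Both
# arguments are expected in log-space.
def _jsd_example(device="cuda"):
    BT, V = 4, 64
    student_logp = torch.randn(BT, V, device=device).log_softmax(-1).requires_grad_(True)
    teacher_logp = torch.randn(BT, V, device=device).log_softmax(-1)
    loss = LigerJSDFunction.apply(student_logp, teacher_logp)  # beta defaults to 0.5
    loss.backward()
    return loss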
from typing import Literal
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import is_hip
from liger_kernel.utils import infer_device
def get_num_warps(BLOCK_SIZE):
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32 if not is_hip() else 16
elif BLOCK_SIZE >= 8192:
num_warps = 16
elif BLOCK_SIZE >= 2048:
num_warps = 8
return num_warps
if infer_device() == "xpu":
MAX_FUSED_SIZE = 8192
elif infer_device() == "npu":
MAX_FUSED_SIZE = 8192
else:
MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
_str_to_reduction_mode = {
"none": _REDUCTION_MODE_NONE.value,
"sum": _REDUCTION_MODE_SUM.value,
"mean": _REDUCTION_MODE_MEAN.value,
"batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}
@triton.jit
def _kldiv_kernel_forward(
y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space
y_stride, # int, prediction stride
gt_ptr, # [B, S], ground truth ptr
gt_stride, # int, ground truth stride
loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
loss_stride, # int, output stride
n_cols, # int, number of columns in the input tensor
eps,
BLOCK_SIZE: tl.constexpr,
log_target: tl.constexpr = False,
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
pid = tl.program_id(0).to(tl.int64)
y_ptr += pid * y_stride
gt_ptr += pid * gt_stride
loss_ptr += pid * loss_stride
base_offsets = tl.arange(0, BLOCK_SIZE)
loss_sum = 0.0
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + base_offsets
mask = offsets < n_cols
y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)
# KL(y_true || y) = y_true * (log(y_true) - log(y))
# We compute KL(y_true || y) with y in the log-space
if not log_target:
loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
else:
loss = tl.exp(y_true) * (y_true - y)
if reduction == _REDUCTION_MODE_NONE:
tl.store(loss_ptr + offsets, loss, mask=mask)
else:
loss_sum += tl.sum(loss, axis=0)
if reduction != _REDUCTION_MODE_NONE:
tl.store(loss_ptr, loss_sum)
@triton.jit
def _kldiv_kernel_backward(
target_ptr,
target_stride,
new_grads_ptr,
new_grads_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
log_target: tl.constexpr = False,
):
pid = tl.program_id(0).to(tl.int64)
target_ptr += pid * target_stride
new_grads_ptr += pid * new_grads_stride
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_cols
target = tl.load(target_ptr + offsets, mask=mask, other=0.0)
if not log_target:
res = target * -1
else:
res = -tl.exp(target)
tl.store(new_grads_ptr + offsets, res, mask=mask)
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
BT, V = y_pred.shape
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
grid = (BT,)
reduction = _str_to_reduction_mode[reduction]
out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
_kldiv_kernel_forward[grid](
y_pred,
y_pred.stride(0),
y_true,
y_true.stride(0),
output_tensor,
output_tensor.stride(0),
V,
eps=eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
log_target=log_target,
reduction=reduction,
)
    # Reduce according to the reduction mode, matching PyTorch's semantics. In later
    # PyTorch versions, `mean` is slated to change to match `batchmean` behavior.
# https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
# https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
return output_tensor.sum() / BT
elif reduction == _REDUCTION_MODE_SUM.value:
return output_tensor.sum(dim=0)
elif reduction == _REDUCTION_MODE_MEAN.value:
return output_tensor.sum() / (BT * V)
else:
return output_tensor
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
BT, V = target.shape
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
grid = (BT,)
# We store the gradients in-place in the input tensor
_kldiv_kernel_backward[grid](
target,
target.stride(0),
new_grads,
new_grads.stride(0),
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
log_target=log_target,
)
    # If the KL divergence is the last layer, grad_output is 1.0. Skip the mul then.
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return new_grads
return new_grads * grad_output
class LigerKLDivLossFunction(torch.autograd.Function):
"""
Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
```python
if log_target:
loss = target.exp() * (target - input)
else:
loss = target * (target.log() - input)
    ```
    The loss is then reduced according to the `reduction` parameter,
    as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
"""
@staticmethod
@ensure_contiguous
def forward(
ctx,
y_pred: torch.Tensor,
y_true: torch.Tensor,
reduction: REDUCTION_LITERAL = "batchmean",
log_target: bool = False,
eps: float = 1e-10,
) -> torch.Tensor:
"""A forward pass for the KL Divergence Loss.
Args:
ctx: Torch autograd context
y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
Returns:
torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
"""
ctx.save_for_backward(y_true)
ctx.reduction = reduction
ctx.log_target = log_target
return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
"""A backward pass for the KL Divergence Loss.
Args:
ctx: Torch autograd context
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
Returns:
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
"""
(y_true,) = ctx.saved_tensors
new_grads = torch.empty_like(y_true)
derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
if ctx.reduction == "batchmean":
derivative = derivative / y_true.shape[0]
elif ctx.reduction == "sum" or ctx.reduction == "none":
pass
elif ctx.reduction == "mean":
derivative = derivative / (y_true.shape[0] * y_true.shape[1])
return (
derivative,
None,
None,
None,
None,
)
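# Minimal usage sketch (editor's addition): LigerKLDivLossFunction mirrors
# torch.nn.KLDivLoss, so the two results should agree up to float tolerance.
# Shapes are illustrative assumptions; a Triton-capable device is required.
def _kldiv_example(device="cuda"):
    BT, V = 4, 64
    y_pred = torch.randn(BT, V, device=device).log_softmax(-1)
    y_true = torch.randn(BT, V, device=device).softmax(-1)
    liger = LigerKLDivLossFunction.apply(y_pred, y_true, "batchmean", False)
    ref = torch.nn.functional.kl_div(y_pred, y_true, reduction="batchmean")
    return liger, ref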
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import set_large_grf_mode
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _layer_norm_forward_kernel(
Y_ptr, # pointer to output, shape (n_rows, n_cols)
Y_row_stride, # stride of each row in output
X_ptr, # pointer to input, shape (n_rows, n_cols)
X_row_stride, # stride of each row in input
W_ptr, # pointer to weights, shape (n_cols,)
W_row_stride, # stride of each row in weights
B_ptr, # pointer to bias, shape (n_cols,)
B_row_stride, # stride of each row in bias
Mean_ptr, # pointer to mean, shape (n_rows,)
Mean_row_stride, # stride of each row in mean
RSTD_ptr, # pointer to rstd, shape (n_rows,)
RSTD_row_stride, # stride of each row in rstd
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
References:
https://arxiv.org/abs/1607.06450
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# Pre-load weights and bias in fp32 to avoid repeated conversions
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
W_f32 = W_row.to(tl.float32)
B_f32 = B_row.to(tl.float32)
# Calculate pointers for this row
row_X_ptr = X_ptr + row_idx * X_row_stride
row_Y_ptr = Y_ptr + row_idx * Y_row_stride
row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
# Load input data and convert to fp32 for numerical stability
X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
X_f32 = X_row.to(tl.float32)
# Compute statistics in fp32 for numerical stability
mean = tl.sum(X_f32, axis=0) / n_cols
X_centered = X_f32 - mean
# Apply mask to variance calculation to exclude contributions from masked elements
X_centered_masked = tl.where(mask, X_centered, 0.0)
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
rstd = rsqrt(var + eps)
# Store statistics (convert back to original dtype only once)
tl.store(row_Mean_ptr, mean.to(X_row.dtype))
tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
# Fused normalization and affine transformation
# Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
Y_f32 = X_centered * rstd * W_f32 + B_f32
# Store output (single conversion back to original dtype)
tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
@triton.jit
def _layer_norm_backward_kernel(
X_ptr, # pointer to input, shape (n_rows, n_cols)
stride_x, # stride of each row in input
W_ptr, # pointer to weights, shape (n_cols,)
Mean_ptr, # pointer to mean, shape (n_rows,)
stride_mean, # stride of each row in mean
RSTD_ptr, # pointer to rstd, shape (n_rows,)
stride_rstd, # stride of each row in rstd
DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
stride_dx, # stride of each row in input grad
DW_ptr, # pointer to weights grad, shape (n_cols,)
stride_dw, # stride of each row in weights grad
DB_ptr, # pointer to bias grad, shape (n_cols,)
stride_db, # stride of each row in bias grad
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
stride_dy, # stride of each row in output grad
n_rows,
n_cols,
rows_per_program: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
References:
https://arxiv.org/abs/1607.06450
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < n_cols
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
# Pre-load weights once (same optimization as forward pass)
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
w_f32 = w.to(tl.float32)
for row_idx in range(row_start, row_end):
# Calculate pointers for this specific row
row_X_ptr = X_ptr + row_idx * stride_x
row_DX_ptr = DX_ptr + row_idx * stride_dx
row_DY_ptr = DY_ptr + row_idx * stride_dy
row_Mean_ptr = Mean_ptr + row_idx * stride_mean
row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd
# Load data for this row
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
mean = tl.load(row_Mean_ptr)
rstd = tl.load(row_RSTD_ptr)
# Convert to fp32 for numerical stability
x_f32 = x.to(tl.float32)
dy_f32 = dy.to(tl.float32)
mean_f32 = mean.to(tl.float32)
rstd_f32 = rstd.to(tl.float32)
# Compute backward pass for this row
x_hat = (x_f32 - mean_f32) * rstd_f32
wdy = w_f32 * dy_f32
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
c2 = tl.sum(wdy, axis=0) / n_cols
dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
# Store input gradient
tl.store(row_DX_ptr + cols, dx, mask=mask)
# Accumulate weight and bias gradients for this thread block's assigned rows
dw = dy_f32 * x_hat
db = dy_f32
dW_row += dw
db_row += db
tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
def layer_norm_forward(X, W, B, eps):
"""
Args:
X: Input tensor of shape (..., hidden_size)
W: Weight tensor of shape (hidden_size,)
B: Bias tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
Tuple of (output, input, mean, rstd, block_size, num_warps)
"""
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
n_rows, n_cols = X.shape
# Calculate optimal block size and warp configuration
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
# Allocate output tensors
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
# Validate input dimensions
if X.shape[1] != W.shape[0]:
raise ValueError(
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
f"must match weight size (W.shape[0]={W.shape[0]})"
)
# XPU-specific optimization
kernel_args = {}
if X.device.type == "xpu":
set_large_grf_mode(kernel_args)
# Launch kernel with one thread block per row for optimal performance
grid = (n_rows,)
_layer_norm_forward_kernel[grid](
Y,
Y.stride(0),
X,
X.stride(0),
W,
W.stride(0),
B,
B.stride(0),
Mean,
Mean.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args,
)
return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
"""
Args:
dY: Gradient of output
X: Input tensor
W: Weight tensor
B: Bias tensor
Mean: Pre-computed mean
RSTD: Pre-computed reciprocal standard deviation
Returns:
Tuple of (input_grad, weight_grad, bias_grad)
"""
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
sm_count = 1
if X.device.type == "cuda":
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
elif X.device.type == "xpu":
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
elif X.device.type == "npu":
sm_count = get_npu_core_count()
    # Accumulate partial weight/bias gradients in fp32 for numerical stability.
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
_DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
# Calculate optimal block size and warp configuration
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
if n_cols > BLOCK_SIZE:
raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
# Allocate gradient tensors
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
kernel_args = {"num_warps": num_warps}
# XPU-specific optimization
if X.device.type == "xpu":
kernel_args.update({"num_warps": 32, "num_stages": 4})
set_large_grf_mode(kernel_args)
    # Launch one program per SM; each program accumulates gradients over its assigned rows
_layer_norm_backward_kernel[grid](
X,
X.stride(0),
W,
Mean,
Mean.stride(0),
RSTD,
RSTD.stride(0),
DX,
DX.stride(0),
_DW,
_DW.stride(0),
_DB,
_DB.stride(0),
dY,
dY.stride(0),
n_rows,
n_cols,
rows_per_program=rows_per_program,
BLOCK_SIZE=BLOCK_SIZE,
**kernel_args,
)
DX = DX.view(*shape)
DW = _DW.sum(dim=0).to(W.dtype)
DB = _DB.sum(dim=0).to(B.dtype)
return DX, DW, DB
class LigerLayerNormFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, B, eps):
Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
ctx.save_for_backward(X, W, B, Mean, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, W, B, Mean, RSTD = ctx.saved_tensors
DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
return DX, DW, DB, None
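# Minimal usage sketch (editor's addition): the fused op should match
# torch.nn.functional.layer_norm up to float tolerance. Shapes are
# illustrative assumptions; a Triton-capable device is required.
def _layer_norm_example(device="cuda"):
    hidden = 1024
    x = torch.randn(8, hidden, device=device, requires_grad=True)
    w = torch.ones(hidden, device=device, requires_grad=True)
    b = torch.zeros(hidden, device=device, requires_grad=True)
    y = LigerLayerNormFunction.apply(x, w, b, 1e-6)
    y_ref = torch.nn.functional.layer_norm(x, (hidden,), w, b, 1e-6)
    y.sum().backward()
    return y, y_ref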
from typing import Optional
import torch
import triton
import triton.language as tl
def _cast_and_contiguous(q, k, freqs_complex):
# Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
if k.dtype != q.dtype:
k = k.to(q.dtype)
q = q.to(compute_dtype).contiguous()
k = k.to(compute_dtype).contiguous()
freqs_complex = freqs_complex.contiguous()
return q, k, freqs_complex
@triton.jit
def _llama4_rope_kernel(
q_ptr,
k_ptr,
freqs_complex_ptr,
q_row_stride,
k_row_stride,
q_head_stride,
k_head_stride,
freqs_row_stride,
seq_len,
batch_size,
imag_sign,
head_dim_half: tl.constexpr,
n_q_heads: tl.constexpr,
n_k_heads: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
H100-optimized RoPE kernel with improved parallelization across heads and dimensions.
Grid: (batch*seq, head)
"""
# 2D grid
pid_bs = tl.program_id(0) # over batch*seq
pid_h = tl.program_id(1) # over heads
batch_idx = pid_bs // seq_len
seq_idx = pid_bs % seq_len
# Bounds check
if batch_idx >= batch_size or seq_idx >= seq_len:
return
# Base pointers for this (batch, seq) position
base_offset = batch_idx * seq_len + seq_idx
q_base = q_ptr + base_offset * q_row_stride
k_base = k_ptr + base_offset * k_row_stride
freq_base = seq_idx * freqs_row_stride
# Tiling over dim/2
for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
d_indices = d_start + tl.arange(0, BLOCK_SIZE)
mask_d = d_indices < head_dim_half
# Compute offsets for the block
freq_offsets = d_indices[:, None] * 2 + tl.arange(0, 2)[None, :]
# Load the block
freqs_complex = tl.load(freqs_complex_ptr + freq_base + freq_offsets, mask=mask_d[:, None], other=0.0)
freqs_real, freqs_imag = tl.split(freqs_complex)
freqs_imag = freqs_imag * imag_sign
# Process one query head per program in pid_h
if pid_h < n_q_heads:
q_head_ptr = q_base + pid_h * q_head_stride
q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
            # Complex multiply with FMAs: (a+ib)*(c+id) = (a*c - b*d) + i*(a*d + b*c)
new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)
# Process one key head per program in pid_h
if pid_h < n_k_heads:
k_head_ptr = k_base + pid_h * k_head_stride
k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
def _select_kernel_meta(head_dim_half: int):
# Heuristic tuning for block size and num_warps
if head_dim_half >= 256:
return 128, 8
if head_dim_half >= 96:
return 128, 4
if head_dim_half >= 48:
return 64, 4
if head_dim_half >= 24:
return 32, 2
return 16, 2
def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: Optional[int] = None, imag_sign: float = 1.0):
# Save original dtype for casting back
original_dtype = q.dtype
batch_size, seq_len, n_q_heads, head_dim = q.shape
_, _, n_k_heads, _ = k.shape
head_dim_half = head_dim // 2
if freqs_cis.is_complex():
freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])
if freqs_cis.shape[0] > seq_len:
freqs_cis = freqs_cis[:seq_len]
freqs_cis = torch.view_as_real(freqs_cis)
# Cast to appropriate dtype and make contiguous only when needed
q, k, freqs_cis = _cast_and_contiguous(q, k, freqs_cis)
# H100-optimized meta-params
if BLOCK_SIZE is None:
BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half)
else:
# Provide a default num_warps if caller pins BLOCK_SIZE
_, num_warps = _select_kernel_meta(head_dim_half)
# 2D grid: one program per (batch, seq, head)
n_heads_max = max(n_q_heads, n_k_heads)
grid = (batch_size * seq_len, n_heads_max)
# Launch kernel
_llama4_rope_kernel[grid](
q,
k,
freqs_cis,
q.stride(1),
k.stride(1),
q.stride(2),
k.stride(2),
freqs_cis.stride(0),
seq_len,
batch_size,
imag_sign,
head_dim_half,
n_q_heads,
n_k_heads,
BLOCK_SIZE,
num_warps=num_warps,
num_stages=2,
)
# Cast back to original dtype only if it differs from compute dtype
if q.dtype != original_dtype:
q = q.to(original_dtype)
if k.dtype != original_dtype:
k = k.to(original_dtype)
return q, k
class LigerLlama4RopeFunction(torch.autograd.Function):
@staticmethod
    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: Optional[int] = None):
q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
ctx.BLOCK_SIZE = BLOCK_SIZE
return q_out, k_out
@staticmethod
def backward(ctx, dq, dk):
(freqs_cis,) = ctx.saved_tensors
BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
# Use imag_sign=-1.0 for conjugate without materializing a new tensor
dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
return dq_out, dk_out, None
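# Minimal usage sketch (editor's addition). The freqs_cis construction below
# follows the common Llama-style recipe (an assumption, not a helper from
# this repo): complex exponentials of shape (seq_len, head_dim // 2).
def _llama4_rope_example(device="cuda"):
    B, S, n_q_heads, n_k_heads, D = 2, 16, 8, 2, 64
    q = torch.randn(B, S, n_q_heads, D, device=device, dtype=torch.bfloat16)
    k = torch.randn(B, S, n_k_heads, D, device=device, dtype=torch.bfloat16)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, device=device).float() / D))
    angles = torch.outer(torch.arange(S, device=device).float(), inv_freq)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64, (S, D//2)
    return LigerLlama4RopeFunction.apply(q, k, freqs_cis)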
import math
from typing import Any
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import ensure_contiguous
def _post_res_default_meta(c: int) -> Tuple[int, int, int, int]:
"""
Returns default (block_n, block_c, num_warps, num_stages) for post_res kernels.
Tuned for different hidden dimensions on NVIDIA GPUs.
"""
if c >= 4096:
return 32, 128, 8, 3 # (block_n, block_c, num_warps, num_stages)
if c >= 2048:
return 32, 128, 4, 2
if c >= 1024:
return 32, 64, 4, 2
return 32, 64, 2, 2
def _post_res_meta(
c: int,
block_n: Optional[int],
block_c: Optional[int],
num_warps: Optional[int],
num_stages: Optional[int],
) -> Tuple[int, int, int, int]:
bn, bc, nw, ns = _post_res_default_meta(c)
return (
bn if block_n is None else int(block_n),
bc if block_c is None else int(block_c),
nw if num_warps is None else int(num_warps),
ns if num_stages is None else int(num_stages),
)
# -------------------------------------------------------------------------------------------------
# (1) Coefficients: fused matmul + RMS scalar (Eq. 14–15)
# mix = (x @ phi) * rsqrt(mean(x^2) + eps)
#
# We provide two paths:
# - TC path: x BF16/FP16 and phi BF16/FP16 (Tensor Cores)
# - TF32-ish path: x cast to FP32 and phi FP32 (relies on Triton/arch for TF32)
# -------------------------------------------------------------------------------------------------
@triton.jit
def _mhc_mm_norm_fwd_kernel(
x_ptr,
phi_ptr,
mix_ptr,
invr_ptr,
N: tl.constexpr,
K: tl.constexpr,
M: tl.constexpr,
stride_xn: tl.constexpr,
stride_xk: tl.constexpr,
stride_phik: tl.constexpr,
stride_phim: tl.constexpr,
stride_mn: tl.constexpr,
stride_mm: tl.constexpr,
eps: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_M: tl.constexpr,
CAST_FP32: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_m = tl.program_id(1)
n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
acc = tl.zeros((BLOCK_N, BLOCK_M), tl.float32)
sumsq = tl.zeros((BLOCK_N,), tl.float32)
for k0 in tl.static_range(0, K, BLOCK_K):
k_offs = k0 + tl.arange(0, BLOCK_K)
x = tl.load(
x_ptr + n_offs[:, None] * stride_xn + k_offs[None, :] * stride_xk,
mask=(n_offs[:, None] < N) & (k_offs[None, :] < K),
other=0.0,
)
if CAST_FP32:
x = x.to(tl.float32)
sumsq += tl.sum(x * x, axis=1)
else:
x_f = x.to(tl.float32)
sumsq += tl.sum(x_f * x_f, axis=1)
phi = tl.load(
phi_ptr + k_offs[:, None] * stride_phik + m_offs[None, :] * stride_phim,
mask=(k_offs[:, None] < K) & (m_offs[None, :] < M),
other=0.0,
)
if CAST_FP32:
phi = phi.to(tl.float32)
acc += tl.dot(x, phi)
invr = tl.rsqrt(sumsq / K + eps)
out = acc * invr[:, None]
tl.store(
mix_ptr + n_offs[:, None] * stride_mn + m_offs[None, :] * stride_mm,
out,
mask=(n_offs[:, None] < N) & (m_offs[None, :] < M),
)
if pid_m == 0:
tl.store(invr_ptr + n_offs, invr, mask=n_offs < N)
def mhc_mm_norm_fwd(
x: torch.Tensor,
phi: torch.Tensor,
eps: float,
*,
out_mix: Optional[torch.Tensor] = None,
out_invr: Optional[torch.Tensor] = None,
block_n: int = 32,
block_k: int = 256,
block_m: int = 32,
num_warps: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Fused (x @ phi) + invr = rsqrt(mean(x^2)+eps) and returns mix=(x@phi)*invr.
Args:
x: [N, K] contiguous
phi: [K, M] contiguous
eps: float
Returns:
mix: [N, M] float32
invr: [N] float32
"""
assert x.is_contiguous(), "x must be contiguous"
assert phi.is_contiguous(), "phi must be contiguous"
N, K = x.shape
K2, M = phi.shape
assert K2 == K, f"phi.shape[0] must match K: got {K2} vs {K}"
if out_mix is None:
out_mix = torch.empty((N, M), device=x.device, dtype=torch.float32)
if out_invr is None:
out_invr = torch.empty((N,), device=x.device, dtype=torch.float32)
grid = (triton.cdiv(N, block_n), triton.cdiv(M, block_m))
use_tc = (x.dtype == phi.dtype) and (x.dtype in (torch.float16, torch.bfloat16))
_mhc_mm_norm_fwd_kernel[grid](
x,
phi,
out_mix,
out_invr,
N=N,
K=K,
M=M,
stride_xn=x.stride(0),
stride_xk=x.stride(1),
stride_phik=phi.stride(0),
stride_phim=phi.stride(1),
stride_mn=out_mix.stride(0),
stride_mm=out_mix.stride(1),
eps=eps,
BLOCK_N=block_n,
BLOCK_K=block_k,
BLOCK_M=block_m,
CAST_FP32=not use_tc,
num_warps=num_warps,
)
return out_mix, out_invr
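# Plain-PyTorch reference for the fused kernel above (editor's addition),
# useful as a sketch for equivalence checks:
#   mix = (x @ phi) * rsqrt(mean(x^2) + eps)
def _mhc_mm_norm_fwd_reference(x: torch.Tensor, phi: torch.Tensor, eps: float):
    invr = torch.rsqrt(x.float().pow(2).mean(-1) + eps)  # [N]
    mix = (x.float() @ phi.float()) * invr[:, None]  # [N, M]
    return mix, invr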
# -------------------------------------------------------------------------------------------------
# Backward for fused (x @ phi) + RMS scalar
#
# mix = (x @ phi) * invr
# invr = rsqrt(mean(x^2) + eps)
#
# Given grad_mix, compute:
# grad_z = grad_mix * invr
# g = sum(grad_mix * (mix / invr)) = sum(grad_mix * mix) / invr
# factor = -(g / K) * invr^3
# grad_x = grad_z @ phi^T + factor * x
# grad_phi = x^T @ grad_z
#
# grad_phi is accumulated into FP32 with atomic adds (split over N-chunks).
# -------------------------------------------------------------------------------------------------
@triton.jit
def _mhc_mm_norm_bwd_fused_kernel(
x_ptr,
phi_ptr,
mix_ptr,
invr_ptr,
grad_mix_ptr,
grad_x_ptr,
grad_phi_ptr,
N: tl.constexpr,
K: tl.constexpr,
M: tl.constexpr,
stride_xn: tl.constexpr,
stride_xk: tl.constexpr,
stride_phik: tl.constexpr,
stride_phim: tl.constexpr,
stride_mn: tl.constexpr,
stride_mm: tl.constexpr,
stride_invr: tl.constexpr,
stride_gmn: tl.constexpr,
stride_gmm: tl.constexpr,
stride_gxn: tl.constexpr,
stride_gxk: tl.constexpr,
stride_gpk: tl.constexpr,
stride_gpm: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_M: tl.constexpr,
CAST_FP32: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_k = tl.program_id(1)
n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
k_offs = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
invr = tl.load(invr_ptr + n_offs * stride_invr, mask=n_offs < N, other=0.0).to(tl.float32)
x = tl.load(
x_ptr + n_offs[:, None] * stride_xn + k_offs[None, :] * stride_xk,
mask=(n_offs[:, None] < N) & (k_offs[None, :] < K),
other=0.0,
)
if CAST_FP32:
x = x.to(tl.float32)
x_f = x
else:
x_f = x.to(tl.float32)
acc = tl.zeros((BLOCK_N, BLOCK_K), tl.float32)
g_acc = tl.zeros((BLOCK_N,), tl.float32)
for m0 in tl.static_range(0, M, BLOCK_M):
m_offs = m0 + tl.arange(0, BLOCK_M)
grad_mix = tl.load(
grad_mix_ptr + n_offs[:, None] * stride_gmn + m_offs[None, :] * stride_gmm,
mask=(n_offs[:, None] < N) & (m_offs[None, :] < M),
other=0.0,
).to(tl.float32)
mix = tl.load(
mix_ptr + n_offs[:, None] * stride_mn + m_offs[None, :] * stride_mm,
mask=(n_offs[:, None] < N) & (m_offs[None, :] < M),
other=0.0,
).to(tl.float32)
g_acc += tl.sum(grad_mix * mix, axis=1)
phi = tl.load(
phi_ptr + k_offs[:, None] * stride_phik + m_offs[None, :] * stride_phim,
mask=(k_offs[:, None] < K) & (m_offs[None, :] < M),
other=0.0,
)
if CAST_FP32:
phi = phi.to(tl.float32)
grad_z = grad_mix * invr[:, None]
else:
grad_z = (grad_mix * invr[:, None]).to(phi.dtype)
acc += tl.dot(grad_z, tl.trans(phi))
dphi = tl.dot(tl.trans(x), grad_z)
tl.atomic_add(
grad_phi_ptr + k_offs[:, None] * stride_gpk + m_offs[None, :] * stride_gpm,
dphi,
mask=(k_offs[:, None] < K) & (m_offs[None, :] < M),
)
g = g_acc / invr
invr3 = invr * invr * invr
factor = (-g * invr3) / K
gx = acc + x_f * factor[:, None]
if CAST_FP32:
tl.store(
grad_x_ptr + n_offs[:, None] * stride_gxn + k_offs[None, :] * stride_gxk,
gx,
mask=(n_offs[:, None] < N) & (k_offs[None, :] < K),
)
else:
tl.store(
grad_x_ptr + n_offs[:, None] * stride_gxn + k_offs[None, :] * stride_gxk,
gx.to(x.dtype),
mask=(n_offs[:, None] < N) & (k_offs[None, :] < K),
)
def mhc_mm_norm_bwd(
x: torch.Tensor,
phi: torch.Tensor,
mix: torch.Tensor,
invr: torch.Tensor,
grad_mix: torch.Tensor,
*,
out_grad_x: Optional[torch.Tensor] = None,
out_grad_phi: Optional[torch.Tensor] = None,
block_n: int = 32,
block_k: int = 256,
block_m: int = 32,
num_warps: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Triton backward for `mhc_mm_norm_fwd`.
Returns:
grad_x: [N, K] same dtype as x
grad_phi: [K, M] FP32 (safe for atomic adds; cast on return if needed)
Note:
grad_phi is accumulated via atomic_add in FP32. For very large N
(batch * sequence length > 1M), accumulated rounding errors may
become noticeable. This is typically not an issue for standard
training configurations.
"""
assert (
x.is_contiguous()
and phi.is_contiguous()
and mix.is_contiguous()
and invr.is_contiguous()
and grad_mix.is_contiguous()
)
N, K = x.shape
K2, M = phi.shape
assert K2 == K
assert mix.shape == (N, M)
assert grad_mix.shape == (N, M)
assert invr.shape == (N,)
if out_grad_x is None:
out_grad_x = torch.empty_like(x)
if out_grad_phi is None:
out_grad_phi = torch.zeros((K, M), device=x.device, dtype=torch.float32)
use_tc = (x.dtype == phi.dtype) and (x.dtype in (torch.float16, torch.bfloat16))
grid = (triton.cdiv(N, block_n), triton.cdiv(K, block_k))
_mhc_mm_norm_bwd_fused_kernel[grid](
x,
phi,
mix,
invr,
grad_mix,
out_grad_x,
out_grad_phi,
N=N,
K=K,
M=M,
stride_xn=x.stride(0),
stride_xk=x.stride(1),
stride_phik=phi.stride(0),
stride_phim=phi.stride(1),
stride_mn=mix.stride(0),
stride_mm=mix.stride(1),
stride_invr=invr.stride(0),
stride_gmn=grad_mix.stride(0),
stride_gmm=grad_mix.stride(1),
stride_gxn=out_grad_x.stride(0),
stride_gxk=out_grad_x.stride(1),
stride_gpk=out_grad_phi.stride(0),
stride_gpm=out_grad_phi.stride(1),
BLOCK_N=block_n,
BLOCK_K=block_k,
BLOCK_M=block_m,
CAST_FP32=not use_tc,
num_warps=num_warps,
)
if out_grad_phi.dtype != phi.dtype:
out_grad_phi = out_grad_phi.to(phi.dtype)
return out_grad_x, out_grad_phi
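# Gradient sanity sketch (editor's addition): differentiate the plain-PyTorch
# reference with autograd and compare against mhc_mm_norm_bwd. Sizes, dtypes,
# and the max-abs-difference check are illustrative assumptions.
def _mhc_mm_norm_bwd_check(device="cuda"):
    N, K, M = 64, 256, 32
    x = torch.randn(N, K, device=device, requires_grad=True)
    phi = torch.randn(K, M, device=device, requires_grad=True)
    eps = 1e-6
    mix_ref, _ = _mhc_mm_norm_fwd_reference(x, phi, eps)
    grad_mix = torch.randn_like(mix_ref)
    mix_ref.backward(grad_mix)
    mix, invr = mhc_mm_norm_fwd(x.detach(), phi.detach(), eps)
    grad_x, grad_phi = mhc_mm_norm_bwd(x.detach(), phi.detach(), mix, invr, grad_mix)
    return (grad_x - x.grad).abs().max(), (grad_phi - phi.grad).abs().max()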
# -------------------------------------------------------------------------------------------------
# Sinkhorn-Knopp forward/backward for H_res (Eq. 19)
# -------------------------------------------------------------------------------------------------
@triton.jit
def _mhc_split_sinkhorn_fwd_kernel(
mix_ptr,
b_ptr,
hpre_ptr,
hpost_ptr,
hres_ptr,
hist_ptr,
N: tl.constexpr,
HC: tl.constexpr,
M: tl.constexpr,
stride_mn: tl.constexpr,
stride_mm: tl.constexpr,
stride_hp_n: tl.constexpr,
stride_hp_h: tl.constexpr,
stride_hq_n: tl.constexpr,
stride_hq_h: tl.constexpr,
stride_hr_n: tl.constexpr,
stride_hr_i: tl.constexpr,
stride_hr_j: tl.constexpr,
stride_hn: tl.constexpr,
stride_ht: tl.constexpr,
stride_hi: tl.constexpr,
stride_hj: tl.constexpr,
alpha_pre_ptr,
alpha_post_ptr,
alpha_res_ptr,
pre_eps: tl.constexpr,
sinkhorn_eps: tl.constexpr,
post_mult: tl.constexpr,
TMAX: tl.constexpr,
STORE_HIST: tl.constexpr,
):
pid = tl.program_id(0)
if pid >= N:
return
# Load scalar alpha parameters from GPU memory (avoids CPU sync)
alpha_pre = tl.load(alpha_pre_ptr).to(tl.float32)
alpha_post = tl.load(alpha_post_ptr).to(tl.float32)
alpha_res = tl.load(alpha_res_ptr).to(tl.float32)
# --- Pre/post logits
j = tl.arange(0, HC)
mix_pre = tl.load(mix_ptr + pid * stride_mn + j * stride_mm).to(tl.float32)
mix_post = tl.load(mix_ptr + pid * stride_mn + (HC + j) * stride_mm).to(tl.float32)
b_pre = tl.load(b_ptr + j).to(tl.float32)
b_post = tl.load(b_ptr + (HC + j)).to(tl.float32)
pre_logits = mix_pre * alpha_pre + b_pre
post_logits = mix_post * alpha_post + b_post
pre = tl.sigmoid(pre_logits) + pre_eps
post = tl.sigmoid(post_logits) * post_mult
tl.store(hpre_ptr + pid * stride_hp_n + j * stride_hp_h, pre)
tl.store(hpost_ptr + pid * stride_hq_n + j * stride_hq_h, post)
# --- Residual logits matrix [HC, HC]
rows = tl.arange(0, HC)[:, None]
cols = tl.arange(0, HC)[None, :]
flat = rows * HC + cols # [HC,HC]
mix_res = tl.load(mix_ptr + pid * stride_mn + (2 * HC + flat) * stride_mm).to(tl.float32)
b_res = tl.load(b_ptr + (2 * HC + flat)).to(tl.float32)
logits = mix_res * alpha_res + b_res
# Sinkhorn: initial row-softmax (stable) then alternating row/col norms.
row_max = tl.max(logits, axis=1)
e = tl.exp(logits - row_max[:, None])
row_sum = tl.sum(e, axis=1)
mat = e / row_sum[:, None] + sinkhorn_eps
col_sum = tl.sum(mat, axis=0)
mat = mat / (col_sum[None, :] + sinkhorn_eps)
if STORE_HIST:
tl.store(
hist_ptr + pid * stride_hn + 0 * stride_ht + rows * stride_hi + cols * stride_hj,
mat,
)
for t in tl.static_range(0, TMAX - 1):
row_sum = tl.sum(mat, axis=1)
mat = mat / (row_sum[:, None] + sinkhorn_eps)
col_sum = tl.sum(mat, axis=0)
mat = mat / (col_sum[None, :] + sinkhorn_eps)
if STORE_HIST:
tl.store(
hist_ptr + pid * stride_hn + (t + 1) * stride_ht + rows * stride_hi + cols * stride_hj,
mat,
)
# Store h_res [N, HC, HC] (row-major: out, in)
tl.store(hres_ptr + pid * stride_hr_n + rows * stride_hr_i + cols * stride_hr_j, mat)
@triton.jit
def _mhc_sinkhorn_bwd_kernel(
mix_ptr,
b_ptr,
grad_out_ptr,
grad_logits_ptr,
N: tl.constexpr,
HC: tl.constexpr,
stride_mn: tl.constexpr,
stride_mm: tl.constexpr,
stride_go_n: tl.constexpr,
stride_go_i: tl.constexpr,
stride_go_j: tl.constexpr,
stride_gl_n: tl.constexpr,
stride_gl_i: tl.constexpr,
stride_gl_j: tl.constexpr,
alpha_res_ptr,
sinkhorn_eps: tl.constexpr,
TMAX: tl.constexpr,
):
pid = tl.program_id(0)
if pid >= N:
return
alpha_res = tl.load(alpha_res_ptr).to(tl.float32)
rows = tl.arange(0, HC)[:, None]
cols = tl.arange(0, HC)[None, :]
flat = rows * HC + cols
# Rebuild logits
mix_res = tl.load(mix_ptr + pid * stride_mn + (2 * HC + flat) * stride_mm).to(tl.float32)
b_res = tl.load(b_ptr + (2 * HC + flat)).to(tl.float32)
logits = mix_res * alpha_res + b_res
    # Recompute the forward (no stored history); each reverse step below
    # recomputes mat_t from scratch, costing O(TMAX^2) inner iterations in total.
row_max = tl.max(logits, axis=1)
e = tl.exp(logits - row_max[:, None])
row_sum0 = tl.sum(e, axis=1)
p = e / row_sum0[:, None] # softmax, row-wise
p_eps = p + sinkhorn_eps
col_sum0 = tl.sum(p_eps, axis=0)
mat0 = p_eps / (col_sum0[None, :] + sinkhorn_eps)
# Start backward from grad_out
g = tl.load(
grad_out_ptr + pid * stride_go_n + rows * stride_go_i + cols * stride_go_j,
).to(tl.float32)
# Reverse iterations (TMAX-1 .. 1), recomputing mat_t, rs_t, cs_t
for t in tl.static_range(TMAX - 1, 0, -1):
mat = mat0
rs_t = row_sum0
cs_t = col_sum0
mat_t = mat0
for s in tl.static_range(1, TMAX):
rs = tl.sum(mat, axis=1)
mat = mat / (rs[:, None] + sinkhorn_eps)
cs = tl.sum(mat, axis=0)
mat = mat / (cs[None, :] + sinkhorn_eps)
if s == t:
mat_t = mat
rs_t = rs
cs_t = cs
denom_col = cs_t + sinkhorn_eps # [HC]
dot_col = tl.sum(g * mat_t, axis=0) # [HC]
g_row = (g - dot_col[None, :]) / denom_col[None, :]
m_row = mat_t * denom_col[None, :] # invert col norm: m_row = m_out * denom
denom_row = rs_t + sinkhorn_eps
dot_row = tl.sum(g_row * m_row, axis=1)
g = (g_row - dot_row[:, None]) / denom_row[:, None]
# Undo initial col norm (t=0)
denom_col0 = col_sum0 + sinkhorn_eps
dot_col0 = tl.sum(g * mat0, axis=0)
g_p = (g - dot_col0[None, :]) / denom_col0[None, :]
# Softmax backward on rows: p * (g_p - sum(g_p * p))
dot_soft = tl.sum(g_p * p, axis=1)
grad_logits = p * (g_p - dot_soft[:, None])
tl.store(grad_logits_ptr + pid * stride_gl_n + rows * stride_gl_i + cols * stride_gl_j, grad_logits)
@triton.jit
def _mhc_sinkhorn_bwd_hist_kernel(
mix_ptr,
b_ptr,
hist_ptr,
grad_out_ptr,
grad_logits_ptr,
N: tl.constexpr,
HC: tl.constexpr,
stride_mn: tl.constexpr,
stride_mm: tl.constexpr,
stride_hn: tl.constexpr,
stride_ht: tl.constexpr,
stride_hi: tl.constexpr,
stride_hj: tl.constexpr,
stride_go_n: tl.constexpr,
stride_go_i: tl.constexpr,
stride_go_j: tl.constexpr,
stride_gl_n: tl.constexpr,
stride_gl_i: tl.constexpr,
stride_gl_j: tl.constexpr,
alpha_res_ptr,
sinkhorn_eps: tl.constexpr,
TMAX: tl.constexpr,
):
pid = tl.program_id(0)
if pid >= N:
return
alpha_res = tl.load(alpha_res_ptr).to(tl.float32)
rows = tl.arange(0, HC)[:, None]
cols = tl.arange(0, HC)[None, :]
flat = rows * HC + cols
# Rebuild logits
mix_res = tl.load(mix_ptr + pid * stride_mn + (2 * HC + flat) * stride_mm).to(tl.float32)
b_res = tl.load(b_ptr + (2 * HC + flat)).to(tl.float32)
logits = mix_res * alpha_res + b_res
# Initial row-softmax
row_max = tl.max(logits, axis=1)
e = tl.exp(logits - row_max[:, None])
row_sum0 = tl.sum(e, axis=1)
p = e / row_sum0[:, None]
p_eps = p + sinkhorn_eps
col_sum0 = tl.sum(p_eps, axis=0)
mat0 = p_eps / (col_sum0[None, :] + sinkhorn_eps)
# Start backward from grad_out
g = tl.load(
grad_out_ptr + pid * stride_go_n + rows * stride_go_i + cols * stride_go_j,
).to(tl.float32)
# Reverse iterations (TMAX-1 .. 1) using stored mats
for t in tl.static_range(TMAX - 1, 0, -1):
mat_t = tl.load(hist_ptr + pid * stride_hn + t * stride_ht + rows * stride_hi + cols * stride_hj).to(tl.float32)
mat_prev = tl.load(hist_ptr + pid * stride_hn + (t - 1) * stride_ht + rows * stride_hi + cols * stride_hj).to(
tl.float32
)
row_sum = tl.sum(mat_prev, axis=1)
mat_row = mat_prev / (row_sum[:, None] + sinkhorn_eps)
col_sum = tl.sum(mat_row, axis=0)
denom_col = col_sum + sinkhorn_eps
dot_col = tl.sum(g * mat_t, axis=0)
g_row = (g - dot_col[None, :]) / denom_col[None, :]
m_row = mat_t * denom_col[None, :]
denom_row = row_sum + sinkhorn_eps
dot_row = tl.sum(g_row * m_row, axis=1)
g = (g_row - dot_row[:, None]) / denom_row[:, None]
# Undo initial col norm (t=0)
denom_col0 = col_sum0 + sinkhorn_eps
dot_col0 = tl.sum(g * mat0, axis=0)
g_p = (g - dot_col0[None, :]) / denom_col0[None, :]
# Softmax backward on rows: p * (g_p - sum(g_p * p))
dot_soft = tl.sum(g_p * p, axis=1)
grad_logits = p * (g_p - dot_soft[:, None])
tl.store(grad_logits_ptr + pid * stride_gl_n + rows * stride_gl_i + cols * stride_gl_j, grad_logits)
def mhc_split_sinkhorn_fwd(
mix: torch.Tensor,
b: torch.Tensor,
alpha_pre: torch.Tensor,
alpha_post: torch.Tensor,
alpha_res: torch.Tensor,
*,
tmax: int,
pre_eps: float,
sinkhorn_eps: float,
post_mult: float,
out_hpre: Optional[torch.Tensor] = None,
out_hpost: Optional[torch.Tensor] = None,
out_hres: Optional[torch.Tensor] = None,
out_hist: Optional[torch.Tensor] = None,
return_hist: bool = False,
num_warps: int = 1,
) -> Union[
Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
]:
"""
Compute h_pre, h_post, h_res from `mix` (already normalized by RMS scalar).
mix: [N, M] float32 where M = HC*HC + 2*HC
b: [M] float32
"""
assert mix.is_contiguous() and b.is_contiguous()
N, M = mix.shape
assert M == b.numel()
# infer HC from M = HC*HC + 2*HC
# Solve HC^2 + 2HC - M = 0
HC = int((math.isqrt(4 + 4 * M) - 2) // 2)
assert HC * HC + 2 * HC == M, f"Invalid M for mHC: M={M}"
if out_hpre is None:
out_hpre = torch.empty((N, HC), device=mix.device, dtype=torch.float32)
if out_hpost is None:
out_hpost = torch.empty((N, HC), device=mix.device, dtype=torch.float32)
if out_hres is None:
out_hres = torch.empty((N, HC, HC), device=mix.device, dtype=torch.float32)
if return_hist:
if out_hist is None:
out_hist = torch.empty((N, tmax, HC, HC), device=mix.device, dtype=torch.float32)
else:
if out_hist is None:
out_hist = torch.empty((1,), device=mix.device, dtype=torch.float32)
grid = (N,)
_mhc_split_sinkhorn_fwd_kernel[grid](
mix,
b,
out_hpre,
out_hpost,
out_hres,
out_hist,
N=N,
HC=HC,
M=M,
stride_mn=mix.stride(0),
stride_mm=mix.stride(1),
stride_hp_n=out_hpre.stride(0),
stride_hp_h=out_hpre.stride(1),
stride_hq_n=out_hpost.stride(0),
stride_hq_h=out_hpost.stride(1),
stride_hr_n=out_hres.stride(0),
stride_hr_i=out_hres.stride(1),
stride_hr_j=out_hres.stride(2),
stride_hn=out_hist.stride(0) if out_hist.ndim > 1 else 0,
stride_ht=out_hist.stride(1) if out_hist.ndim > 1 else 0,
stride_hi=out_hist.stride(2) if out_hist.ndim > 1 else 0,
stride_hj=out_hist.stride(3) if out_hist.ndim > 1 else 0,
alpha_pre_ptr=alpha_pre.contiguous(),
alpha_post_ptr=alpha_post.contiguous(),
alpha_res_ptr=alpha_res.contiguous(),
pre_eps=pre_eps,
sinkhorn_eps=sinkhorn_eps,
post_mult=post_mult,
TMAX=tmax,
STORE_HIST=return_hist,
num_warps=num_warps,
)
if return_hist:
return out_hpre, out_hpost, out_hres, out_hist
return out_hpre, out_hpost, out_hres
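# Hedged eager-mode sketch of the Sinkhorn normalization in the forward kernel
# above: a stable row-softmax, then alternating column/row normalizations with
# `sinkhorn_eps` stabilization (the name `_sinkhorn_reference` is ours):
def _sinkhorn_reference(logits: torch.Tensor, tmax: int, eps: float) -> torch.Tensor:
    mat = torch.softmax(logits, dim=-1) + eps  # initial row-softmax
    mat = mat / (mat.sum(dim=-2, keepdim=True) + eps)  # initial col norm
    for _ in range(tmax - 1):
        mat = mat / (mat.sum(dim=-1, keepdim=True) + eps)  # row norm
        mat = mat / (mat.sum(dim=-2, keepdim=True) + eps)  # col norm
    return mat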
def mhc_sinkhorn_bwd(
mix: torch.Tensor,
b: torch.Tensor,
alpha_res: torch.Tensor,
grad_hres: torch.Tensor,
*,
tmax: int,
sinkhorn_eps: float,
hist: Optional[torch.Tensor] = None,
out_grad_logits: Optional[torch.Tensor] = None,
num_warps: int = 1,
) -> torch.Tensor:
"""
Backward for Sinkhorn: returns grad_logits (same shape as h_res).
mix: [N, M] float32
b: [M] float32
grad_hres: [N, HC, HC] float32
"""
assert mix.is_contiguous() and b.is_contiguous() and grad_hres.is_contiguous()
N, M = mix.shape
HC = grad_hres.shape[1]
assert grad_hres.shape == (N, HC, HC)
assert M == HC * HC + 2 * HC
if out_grad_logits is None:
out_grad_logits = torch.empty((N, HC, HC), device=mix.device, dtype=torch.float32)
grid = (N,)
alpha_res_c = alpha_res.contiguous()
if hist is not None:
assert hist.is_contiguous()
assert hist.shape == (N, tmax, HC, HC)
_mhc_sinkhorn_bwd_hist_kernel[grid](
mix,
b,
hist,
grad_hres,
out_grad_logits,
N=N,
HC=HC,
stride_mn=mix.stride(0),
stride_mm=mix.stride(1),
stride_hn=hist.stride(0),
stride_ht=hist.stride(1),
stride_hi=hist.stride(2),
stride_hj=hist.stride(3),
stride_go_n=grad_hres.stride(0),
stride_go_i=grad_hres.stride(1),
stride_go_j=grad_hres.stride(2),
stride_gl_n=out_grad_logits.stride(0),
stride_gl_i=out_grad_logits.stride(1),
stride_gl_j=out_grad_logits.stride(2),
alpha_res_ptr=alpha_res_c,
sinkhorn_eps=sinkhorn_eps,
TMAX=tmax,
num_warps=num_warps,
)
else:
_mhc_sinkhorn_bwd_kernel[grid](
mix,
b,
grad_hres,
out_grad_logits,
N=N,
HC=HC,
stride_mn=mix.stride(0),
stride_mm=mix.stride(1),
stride_go_n=grad_hres.stride(0),
stride_go_i=grad_hres.stride(1),
stride_go_j=grad_hres.stride(2),
stride_gl_n=out_grad_logits.stride(0),
stride_gl_i=out_grad_logits.stride(1),
stride_gl_j=out_grad_logits.stride(2),
alpha_res_ptr=alpha_res_c,
sinkhorn_eps=sinkhorn_eps,
TMAX=tmax,
num_warps=num_warps,
)
return out_grad_logits
# -------------------------------------------------------------------------------------------------
# Apply kernels: mhc_pre and mhc_post_res (forward + backward)
# -------------------------------------------------------------------------------------------------
@triton.jit
def _mhc_pre_fwd_kernel(
x_ptr,
hpre_ptr,
out_ptr,
N: tl.constexpr,
HC: tl.constexpr,
C: tl.constexpr,
stride_xn: tl.constexpr,
stride_xh: tl.constexpr,
stride_xc: tl.constexpr,
stride_hn: tl.constexpr,
stride_hh: tl.constexpr,
stride_on: tl.constexpr,
stride_oc: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_C: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_c = tl.program_id(1)
n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
acc = tl.zeros((BLOCK_N, BLOCK_C), tl.float32)
for s in tl.static_range(0, HC):
h_s = tl.load(
hpre_ptr + n_offs * stride_hn + s * stride_hh,
mask=(n_offs < N),
other=0.0,
).to(tl.float32)
xs = tl.load(
x_ptr + n_offs[:, None] * stride_xn + s * stride_xh + c_offs[None, :] * stride_xc,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
other=0.0,
).to(tl.float32)
acc += xs * h_s[:, None]
tl.store(
out_ptr + n_offs[:, None] * stride_on + c_offs[None, :] * stride_oc,
acc,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
)
@triton.jit
def _mhc_pre_bwd_kernel(
x_ptr,
hpre_ptr,
grad_out_ptr,
grad_x_ptr,
grad_h_ptr,
N: tl.constexpr,
HC: tl.constexpr,
C: tl.constexpr,
stride_xn: tl.constexpr,
stride_xh: tl.constexpr,
stride_xc: tl.constexpr,
stride_hn: tl.constexpr,
stride_hh: tl.constexpr,
stride_gon: tl.constexpr,
stride_goc: tl.constexpr,
stride_gxn: tl.constexpr,
stride_gxh: tl.constexpr,
stride_gxc: tl.constexpr,
stride_ghn: tl.constexpr,
stride_ghh: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_C: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_c = tl.program_id(1)
n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
go = tl.load(
grad_out_ptr + n_offs[:, None] * stride_gon + c_offs[None, :] * stride_goc,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
other=0.0,
).to(tl.float32)
# grad_x = grad_out * hpre
for s in tl.static_range(0, HC):
h_s = tl.load(
hpre_ptr + n_offs * stride_hn + s * stride_hh,
mask=(n_offs < N),
other=0.0,
).to(tl.float32)
gx = go * h_s[:, None]
tl.store(
grad_x_ptr + n_offs[:, None] * stride_gxn + s * stride_gxh + c_offs[None, :] * stride_gxc,
gx,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
)
# grad_hpre: dot(go, x_s) over C -> atomic add
xs = tl.load(
x_ptr + n_offs[:, None] * stride_xn + s * stride_xh + c_offs[None, :] * stride_xc,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
other=0.0,
).to(tl.float32)
part = tl.sum(go * xs, axis=1)
tl.atomic_add(
grad_h_ptr + n_offs * stride_ghn + s * stride_ghh,
part,
mask=n_offs < N,
)
def mhc_pre_fwd(
x: torch.Tensor,
h_pre: torch.Tensor,
*,
out: Optional[torch.Tensor] = None,
block_n: int = 32,
block_c: int = 128,
num_warps: int = 4,
) -> torch.Tensor:
assert x.is_contiguous() and h_pre.is_contiguous()
N, HC, C = x.shape
assert h_pre.shape == (N, HC)
if out is None:
out = torch.empty((N, C), device=x.device, dtype=torch.float32)
grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c))
_mhc_pre_fwd_kernel[grid](
x,
h_pre,
out,
N=N,
HC=HC,
C=C,
stride_xn=x.stride(0),
stride_xh=x.stride(1),
stride_xc=x.stride(2),
stride_hn=h_pre.stride(0),
stride_hh=h_pre.stride(1),
stride_on=out.stride(0),
stride_oc=out.stride(1),
BLOCK_N=block_n,
BLOCK_C=block_c,
num_warps=num_warps,
)
return out
def mhc_pre_bwd(
x: torch.Tensor,
h_pre: torch.Tensor,
grad_out: torch.Tensor,
*,
out_grad_x: Optional[torch.Tensor] = None,
out_grad_h: Optional[torch.Tensor] = None,
block_n: int = 32,
block_c: int = 128,
num_warps: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.is_contiguous() and h_pre.is_contiguous() and grad_out.is_contiguous()
N, HC, C = x.shape
assert grad_out.shape == (N, C)
if out_grad_x is None:
out_grad_x = torch.empty_like(x, dtype=torch.float32)
if out_grad_h is None:
out_grad_h = torch.zeros((N, HC), device=x.device, dtype=torch.float32)
grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c))
_mhc_pre_bwd_kernel[grid](
x,
h_pre,
grad_out,
out_grad_x,
out_grad_h,
N=N,
HC=HC,
C=C,
stride_xn=x.stride(0),
stride_xh=x.stride(1),
stride_xc=x.stride(2),
stride_hn=h_pre.stride(0),
stride_hh=h_pre.stride(1),
stride_gon=grad_out.stride(0),
stride_goc=grad_out.stride(1),
stride_gxn=out_grad_x.stride(0),
stride_gxh=out_grad_x.stride(1),
stride_gxc=out_grad_x.stride(2),
stride_ghn=out_grad_h.stride(0),
stride_ghh=out_grad_h.stride(1),
BLOCK_N=block_n,
BLOCK_C=block_c,
num_warps=num_warps,
)
return out_grad_x, out_grad_h
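# Hedged eager-mode sketch of the pre-mix computed by the fwd/bwd pair above:
# the HC input streams are contracted with per-token weights
# (the name `_mhc_pre_reference` is ours):
def _mhc_pre_reference(x: torch.Tensor, h_pre: torch.Tensor) -> torch.Tensor:
    # out[n, c] = sum_h h_pre[n, h] * x[n, h, c], accumulated in fp32
    return torch.einsum("nhc,nh->nc", x.to(torch.float32), h_pre.to(torch.float32))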
@triton.jit
def _mhc_post_res_fwd_kernel(
x_ptr,
f_ptr,
hpost_ptr,
hres_ptr,
out_ptr,
N: tl.constexpr,
HC: tl.constexpr,
C: tl.constexpr,
stride_xn: tl.constexpr,
stride_xh: tl.constexpr,
stride_xc: tl.constexpr,
stride_fn: tl.constexpr,
stride_fc: tl.constexpr,
stride_hpn: tl.constexpr,
stride_hph: tl.constexpr,
stride_hrn: tl.constexpr,
stride_hri: tl.constexpr,
stride_hrj: tl.constexpr,
stride_on: tl.constexpr,
stride_oh: tl.constexpr,
stride_oc: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_C: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_c = tl.program_id(1)
n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
f = tl.load(
f_ptr + n_offs[:, None] * stride_fn + c_offs[None, :] * stride_fc,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
other=0.0,
).to(tl.float32)
o2 = tl.arange(0, HC)[:, None] # [HC,1]
hpost = tl.load(
hpost_ptr + n_offs[None, :] * stride_hpn + o2 * stride_hph,
mask=(n_offs[None, :] < N),
other=0.0,
).to(tl.float32) # [HC, BN]
acc = f[None, :, :] * hpost[:, :, None] # [HC, BN, BC]
# residual mixing: sum_i hres[o,i] * x_i
for i in tl.static_range(0, HC):
xs = tl.load(
x_ptr + n_offs[:, None] * stride_xn + i * stride_xh + c_offs[None, :] * stride_xc,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
other=0.0,
).to(tl.float32) # [BN, BC]
w = tl.load(
hres_ptr + n_offs[None, :] * stride_hrn + o2 * stride_hri + i * stride_hrj,
mask=(n_offs[None, :] < N),
other=0.0,
).to(tl.float32) # [HC, BN]
acc += xs[None, :, :] * w[:, :, None]
o3 = tl.arange(0, HC)[:, None, None]
n3 = n_offs[None, :, None]
c3 = c_offs[None, None, :]
tl.store(
out_ptr + n3 * stride_on + o3 * stride_oh + c3 * stride_oc,
acc,
mask=(n3 < N) & (c3 < C),
)
@triton.jit
def _mhc_post_res_bwd_kernel(
x_ptr,
f_ptr,
hpost_ptr,
hres_ptr,
grad_out_ptr,
grad_x_ptr,
grad_f_ptr,
grad_hpost_ptr,
grad_hres_ptr,
N: tl.constexpr,
HC: tl.constexpr,
C: tl.constexpr,
stride_xn: tl.constexpr,
stride_xh: tl.constexpr,
stride_xc: tl.constexpr,
stride_fn: tl.constexpr,
stride_fc: tl.constexpr,
stride_hpn: tl.constexpr,
stride_hph: tl.constexpr,
stride_hrn: tl.constexpr,
stride_hri: tl.constexpr,
stride_hrj: tl.constexpr,
stride_gon: tl.constexpr,
stride_goh: tl.constexpr,
stride_goc: tl.constexpr,
stride_gxn: tl.constexpr,
stride_gxh: tl.constexpr,
stride_gxc: tl.constexpr,
stride_gfn: tl.constexpr,
stride_gfc: tl.constexpr,
stride_ghpn: tl.constexpr,
stride_ghph: tl.constexpr,
stride_ghrn: tl.constexpr,
stride_ghri: tl.constexpr,
stride_ghrj: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_C: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_c = tl.program_id(1)
n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
f = tl.load(
f_ptr + n_offs[:, None] * stride_fn + c_offs[None, :] * stride_fc,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
other=0.0,
).to(tl.float32)
o2 = tl.arange(0, HC)[:, None] # [HC,1]
hpost = tl.load(
hpost_ptr + n_offs[None, :] * stride_hpn + o2 * stride_hph,
mask=(n_offs[None, :] < N),
other=0.0,
).to(tl.float32) # [HC, BN]
o3 = tl.arange(0, HC)[:, None, None]
n3 = n_offs[None, :, None]
c3 = c_offs[None, None, :]
go = tl.load(
grad_out_ptr + n3 * stride_gon + o3 * stride_goh + c3 * stride_goc,
mask=(n3 < N) & (c3 < C),
other=0.0,
).to(tl.float32) # [HC, BN, BC]
# grad_f: sum_o go[o] * hpost[o]
gf = tl.sum(go * hpost[:, :, None], axis=0)
tl.store(
grad_f_ptr + n_offs[:, None] * stride_gfn + c_offs[None, :] * stride_gfc,
gf,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
)
# grad_hpost: dot(go[o], f) over C (atomic over C blocks)
part_hpost = tl.sum(go * f[None, :, :], axis=2) # [HC, BN]
tl.atomic_add(
grad_hpost_ptr + n_offs[None, :] * stride_ghpn + o2 * stride_ghph,
part_hpost,
mask=(n_offs[None, :] < N),
)
# grad_x: hres^T @ go (in-stream i gets sum_o hres[o,i] * go[o])
for i in tl.static_range(0, HC):
w = tl.load(
hres_ptr + n_offs[None, :] * stride_hrn + o2 * stride_hri + i * stride_hrj,
mask=(n_offs[None, :] < N),
other=0.0,
).to(tl.float32) # [HC, BN]
gx = tl.sum(go * w[:, :, None], axis=0) # [BN, BC]
tl.store(
grad_x_ptr + n_offs[:, None] * stride_gxn + i * stride_gxh + c_offs[None, :] * stride_gxc,
gx,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
)
# grad_hres[o,i]: dot(go[o], x[i]) over C (atomic)
for i in tl.static_range(0, HC):
xi = tl.load(
x_ptr + n_offs[:, None] * stride_xn + i * stride_xh + c_offs[None, :] * stride_xc,
mask=(n_offs[:, None] < N) & (c_offs[None, :] < C),
other=0.0,
).to(tl.float32)
part_hres = tl.sum(go * xi[None, :, :], axis=2) # [HC, BN]
tl.atomic_add(
grad_hres_ptr + n_offs[None, :] * stride_ghrn + o2 * stride_ghri + i * stride_ghrj,
part_hres,
mask=(n_offs[None, :] < N),
)
def mhc_post_res_fwd(
x: torch.Tensor,
f_out: torch.Tensor,
h_post: torch.Tensor,
h_res: torch.Tensor,
*,
out: Optional[torch.Tensor] = None,
block_n: Optional[int] = None,
block_c: Optional[int] = None,
num_warps: Optional[int] = None,
num_stages: Optional[int] = None,
) -> torch.Tensor:
assert x.is_contiguous() and f_out.is_contiguous() and h_post.is_contiguous() and h_res.is_contiguous()
N, HC, C = x.shape
assert f_out.shape == (N, C)
assert h_post.shape == (N, HC)
assert h_res.shape == (N, HC, HC)
if out is None:
out = torch.empty((N, HC, C), device=x.device, dtype=torch.float32)
block_n, block_c, num_warps, num_stages = _post_res_meta(C, block_n, block_c, num_warps, num_stages)
grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c))
_mhc_post_res_fwd_kernel[grid](
x,
f_out,
h_post,
h_res,
out,
N=N,
HC=HC,
C=C,
stride_xn=x.stride(0),
stride_xh=x.stride(1),
stride_xc=x.stride(2),
stride_fn=f_out.stride(0),
stride_fc=f_out.stride(1),
stride_hpn=h_post.stride(0),
stride_hph=h_post.stride(1),
stride_hrn=h_res.stride(0),
stride_hri=h_res.stride(1),
stride_hrj=h_res.stride(2),
stride_on=out.stride(0),
stride_oh=out.stride(1),
stride_oc=out.stride(2),
BLOCK_N=block_n,
BLOCK_C=block_c,
num_warps=num_warps,
num_stages=num_stages,
)
return out
def mhc_post_res_bwd(
x: torch.Tensor,
f_out: torch.Tensor,
h_post: torch.Tensor,
h_res: torch.Tensor,
grad_out: torch.Tensor,
*,
out_grad_x: Optional[torch.Tensor] = None,
out_grad_f: Optional[torch.Tensor] = None,
out_grad_hpost: Optional[torch.Tensor] = None,
out_grad_hres: Optional[torch.Tensor] = None,
block_n: Optional[int] = None,
block_c: Optional[int] = None,
num_warps: Optional[int] = None,
num_stages: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
assert (
x.is_contiguous()
and f_out.is_contiguous()
and h_post.is_contiguous()
and h_res.is_contiguous()
and grad_out.is_contiguous()
)
N, HC, C = x.shape
assert grad_out.shape == (N, HC, C)
if out_grad_x is None:
out_grad_x = torch.empty_like(x, dtype=torch.float32)
if out_grad_f is None:
out_grad_f = torch.empty_like(f_out, dtype=torch.float32)
if out_grad_hpost is None:
out_grad_hpost = torch.zeros((N, HC), device=x.device, dtype=torch.float32)
if out_grad_hres is None:
out_grad_hres = torch.zeros((N, HC, HC), device=x.device, dtype=torch.float32)
block_n, block_c, num_warps, num_stages = _post_res_meta(C, block_n, block_c, num_warps, num_stages)
grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c))
_mhc_post_res_bwd_kernel[grid](
x,
f_out,
h_post,
h_res,
grad_out,
out_grad_x,
out_grad_f,
out_grad_hpost,
out_grad_hres,
N=N,
HC=HC,
C=C,
stride_xn=x.stride(0),
stride_xh=x.stride(1),
stride_xc=x.stride(2),
stride_fn=f_out.stride(0),
stride_fc=f_out.stride(1),
stride_hpn=h_post.stride(0),
stride_hph=h_post.stride(1),
stride_hrn=h_res.stride(0),
stride_hri=h_res.stride(1),
stride_hrj=h_res.stride(2),
stride_gon=grad_out.stride(0),
stride_goh=grad_out.stride(1),
stride_goc=grad_out.stride(2),
stride_gxn=out_grad_x.stride(0),
stride_gxh=out_grad_x.stride(1),
stride_gxc=out_grad_x.stride(2),
stride_gfn=out_grad_f.stride(0),
stride_gfc=out_grad_f.stride(1),
stride_ghpn=out_grad_hpost.stride(0),
stride_ghph=out_grad_hpost.stride(1),
stride_ghrn=out_grad_hres.stride(0),
stride_ghri=out_grad_hres.stride(1),
stride_ghrj=out_grad_hres.stride(2),
BLOCK_N=block_n,
BLOCK_C=block_c,
num_warps=num_warps,
num_stages=num_stages,
)
return out_grad_x, out_grad_f, out_grad_hpost, out_grad_hres
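# Hedged eager-mode sketch of the post/residual apply above (the name
# `_mhc_post_res_reference` is ours): each output stream o scales the branch
# output f_out by h_post[:, o] and adds the h_res-mixed residual streams.
def _mhc_post_res_reference(
    x: torch.Tensor, f_out: torch.Tensor, h_post: torch.Tensor, h_res: torch.Tensor
) -> torch.Tensor:
    # out[n, o, c] = h_post[n, o] * f_out[n, c] + sum_i h_res[n, o, i] * x[n, i, c]
    res = torch.einsum("noi,nic->noc", h_res.to(torch.float32), x.to(torch.float32))
    return h_post.to(torch.float32)[:, :, None] * f_out.to(torch.float32)[:, None, :] + res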
def _flatten_tokens(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Size]:
"""
Flattens leading dimensions so x becomes [N, HC, C].
Returns (x_flat, x_shape) where x_shape is the original shape.
"""
assert x.dim() >= 3, "x must be [..., HC, C]"
return x.contiguous().view(-1, x.shape[-2], x.shape[-1]), x.shape
class LigerMHCCoeffsFunction(torch.autograd.Function):
"""
Autograd function for mHC coefficient computation.
Memory/Compute Trade-off:
When gradients are needed, Sinkhorn iteration history (hist) is saved
during forward to avoid recomputation in backward. This increases
memory usage by O(N * tmax * HC^2) but reduces backward compute.
"""
@staticmethod
@ensure_contiguous
def forward( # type: ignore[override]
ctx: Any,
x: torch.Tensor, # [..., HC, C] bf16/fp16 (or fp32 if allow_fp32)
phi: torch.Tensor, # [HC*C, M]
b: torch.Tensor, # [M]
alpha_pre: torch.Tensor, # scalar
alpha_post: torch.Tensor, # scalar
alpha_res: torch.Tensor, # scalar
allow_fp32: bool,
tmax: int,
rms_eps: float,
pre_eps: float,
sinkhorn_eps: float,
post_mult: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if allow_fp32:
assert x.dtype in (
torch.bfloat16,
torch.float16,
torch.float32,
), "x should be BF16/FP16/FP32 when allow_fp32=True"
else:
assert x.dtype in (torch.bfloat16, torch.float16), "x should be BF16/FP16 (set allow_fp32=True for FP32)"
# Store original shape for restoring at the end
x_shape = x.shape
x_flat, _ = _flatten_tokens(x)
N, HC, C = x_flat.shape
K = HC * C
x_mat = x_flat.view(-1, K)
assert phi.dim() == 2 and phi.shape[0] == K, f"phi must be [HC*C, M], got {tuple(phi.shape)}"
M = int(phi.shape[1])
assert b.shape == (M,), f"b must be [M], got {tuple(b.shape)}"
# (1) fused coeff matmul + norm
mix, invr = mhc_mm_norm_fwd(x_mat, phi, eps=float(rms_eps))
# (2) split + sigmoid + sinkhorn
need_hist = any(ctx.needs_input_grad)
if need_hist:
h_pre, h_post, h_res, hist = mhc_split_sinkhorn_fwd(
mix,
b,
alpha_pre,
alpha_post,
alpha_res,
tmax=int(tmax),
pre_eps=float(pre_eps),
sinkhorn_eps=float(sinkhorn_eps),
post_mult=float(post_mult),
return_hist=True,
)
else:
h_pre, h_post, h_res = mhc_split_sinkhorn_fwd(
mix,
b,
alpha_pre,
alpha_post,
alpha_res,
tmax=int(tmax),
pre_eps=float(pre_eps),
sinkhorn_eps=float(sinkhorn_eps),
post_mult=float(post_mult),
)
hist = None
# Save for backward
if hist is not None:
ctx.save_for_backward(x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res, hist)
else:
ctx.save_for_backward(x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res)
ctx.meta = (
x_shape,
HC,
C,
int(tmax),
float(sinkhorn_eps),
float(post_mult),
hist is not None,
)
# Reshape to original leading dims
outer = x_shape[:-2]
return (
h_pre.view(*outer, HC),
h_post.view(*outer, HC),
h_res.view(*outer, HC, HC),
)
@staticmethod
@ensure_contiguous
def backward(
ctx: Any,
grad_h_pre: torch.Tensor | None,
grad_h_post: torch.Tensor | None,
grad_h_res: torch.Tensor | None,
):
saved = ctx.saved_tensors
x_shape, HC, C, tmax, sinkhorn_eps, post_mult, has_hist = ctx.meta
if has_hist:
x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res, hist = saved
else:
x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res = saved
hist = None
N = x_mat.shape[0]
M = mix.shape[1]
assert M == HC * HC + 2 * HC
need_pre = grad_h_pre is not None
need_post = grad_h_post is not None
need_res = grad_h_res is not None
# flatten grads (None -> zeros)
if need_pre:
gh_pre = grad_h_pre.view(-1, HC).to(torch.float32)
else:
gh_pre = torch.zeros((N, HC), device=mix.device, dtype=torch.float32)
if need_post:
gh_post = grad_h_post.view(-1, HC).to(torch.float32)
else:
gh_post = torch.zeros((N, HC), device=mix.device, dtype=torch.float32)
if need_res:
gh_res = grad_h_res.view(-1, HC, HC).to(torch.float32)
else:
gh_res = torch.zeros((N, HC, HC), device=mix.device, dtype=torch.float32)
# --- Sinkhorn backward -> grad logits for residual matrix
if need_res:
grad_res_logits = mhc_sinkhorn_bwd(
mix,
b,
alpha_res,
gh_res,
tmax=tmax,
sinkhorn_eps=sinkhorn_eps,
hist=hist,
) # [N, HC, HC] fp32
else:
grad_res_logits = gh_res
# --- Pre/post derivatives (sigmoid)
mix_pre = mix[:, :HC]
mix_post = mix[:, HC : 2 * HC]
mix_res = mix[:, 2 * HC :]
b_pre = b[:HC]
b_post = b[HC : 2 * HC]
if need_pre:
pre_logits = mix_pre * alpha_pre + b_pre
pre_sig = torch.sigmoid(pre_logits)
grad_pre_logits = gh_pre * (pre_sig * (1.0 - pre_sig)) # [N,HC]
else:
grad_pre_logits = gh_pre
if need_post:
post_logits = mix_post * alpha_post + b_post
post_sig = torch.sigmoid(post_logits)
grad_post_logits = gh_post * (post_mult * post_sig * (1.0 - post_sig)) # [N,HC]
else:
grad_post_logits = gh_post
grad_res_logits_flat = grad_res_logits.reshape(N, HC * HC)
# --- Grad w.r.t mix
grad_mix = torch.empty_like(mix)
grad_mix[:, :HC] = grad_pre_logits * alpha_pre
grad_mix[:, HC : 2 * HC] = grad_post_logits * alpha_post
grad_mix[:, 2 * HC :] = grad_res_logits_flat * alpha_res
# --- Grad w.r.t b
grad_b = torch.zeros_like(b, dtype=torch.float32)
if need_pre:
grad_b[:HC] = grad_pre_logits.sum(dim=0)
if need_post:
grad_b[HC : 2 * HC] = grad_post_logits.sum(dim=0)
if need_res:
grad_b[2 * HC :] = grad_res_logits_flat.sum(dim=0)
# --- Grad w.r.t alphas
if need_pre:
grad_alpha_pre = (grad_pre_logits * mix_pre).sum()
else:
grad_alpha_pre = torch.zeros((), device=mix.device, dtype=torch.float32)
if need_post:
grad_alpha_post = (grad_post_logits * mix_post).sum()
else:
grad_alpha_post = torch.zeros((), device=mix.device, dtype=torch.float32)
if need_res:
grad_alpha_res = (grad_res_logits_flat * mix_res).sum()
else:
grad_alpha_res = torch.zeros((), device=mix.device, dtype=torch.float32)
# --- Grad w.r.t x and phi via fused mm+norm backward
grad_x_mat, grad_phi = mhc_mm_norm_bwd(
x_mat,
phi,
mix,
invr,
grad_mix,
)
# Reshape to original shape
grad_x = grad_x_mat.view(x_shape)
# Return grads for each forward input
return (
grad_x, # x
grad_phi, # phi
grad_b, # b
grad_alpha_pre, # alpha_pre
grad_alpha_post, # alpha_post
grad_alpha_res, # alpha_res
None, # allow_fp32
None, # tmax
None, # rms_eps
None, # pre_eps
None, # sinkhorn_eps
None, # post_mult
)
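# Hedged usage sketch (argument order follows forward() above; the numeric
# hyper-parameter values here are purely illustrative, not prescribed):
# h_pre, h_post, h_res = LigerMHCCoeffsFunction.apply(
#     x, phi, b, alpha_pre, alpha_post, alpha_res,
#     False,  # allow_fp32
#     4,      # tmax
#     1e-6,   # rms_eps
#     1e-3,   # pre_eps
#     1e-6,   # sinkhorn_eps
#     2.0,    # post_mult
# )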
class LigerMHCPreFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx: Any, x: torch.Tensor, h_pre: torch.Tensor) -> torch.Tensor:
x_shape = x.shape
x_flat, _ = _flatten_tokens(x)
h_pre_flat = h_pre.view(-1, x_flat.shape[1]).to(torch.float32)
out = mhc_pre_fwd(x_flat, h_pre_flat) # [N,C] fp32
ctx.save_for_backward(x_flat, h_pre_flat)
ctx.x_shape = x_shape
out = out.to(x_flat.dtype)
return out.view(*x_shape[:-2], out.shape[-1])
@staticmethod
@ensure_contiguous
def backward(ctx: Any, grad_out: torch.Tensor):
x_flat, h_pre_flat = ctx.saved_tensors
x_shape = ctx.x_shape
N, HC, C = x_flat.shape
go = grad_out.view(-1, C).to(torch.float32)
grad_x, grad_h = mhc_pre_bwd(x_flat, h_pre_flat, go)
grad_x = grad_x.to(x_flat.dtype)
return grad_x.view(*x_shape), grad_h.view(*x_shape[:-1])
class LigerMHCPostResFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(
ctx: Any, x: torch.Tensor, f_out: torch.Tensor, h_post: torch.Tensor, h_res: torch.Tensor
) -> torch.Tensor:
x_shape = x.shape
x_flat, _ = _flatten_tokens(x)
N, HC, C = x_flat.shape
f_flat = f_out.view(-1, C)
h_post_flat = h_post.view(-1, HC).to(torch.float32)
h_res_flat = h_res.view(-1, HC, HC).to(torch.float32)
out = mhc_post_res_fwd(x_flat, f_flat, h_post_flat, h_res_flat) # [N,HC,C] fp32
ctx.save_for_backward(x_flat, f_flat, h_post_flat, h_res_flat)
ctx.x_shape = x_shape
out = out.to(x_flat.dtype)
return out.view(*x_shape)
@staticmethod
@ensure_contiguous
def backward(ctx: Any, grad_out: torch.Tensor):
x_flat, f_flat, h_post_flat, h_res_flat = ctx.saved_tensors
x_shape = ctx.x_shape
N, HC, C = x_flat.shape
go = grad_out.view(-1, HC, C).to(torch.float32)
grad_x, grad_f, grad_hpost, grad_hres = mhc_post_res_bwd(x_flat, f_flat, h_post_flat, h_res_flat, go)
outer = x_shape[:-2]
return (
grad_x.to(x_flat.dtype).view(*x_shape),
grad_f.to(f_flat.dtype).view(*outer, C),
grad_hpost.view(*outer, HC),
grad_hres.view(*outer, HC, HC),
)
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from torch.nn.modules.utils import _pair
from liger_kernel.ops.softmax import _softmax_forward
from liger_kernel.ops.sparsemax import _sparsemax_backward
from liger_kernel.ops.sparsemax import _sparsemax_forward
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous
@triton.jit
def _mask_fwd_kernel(
scores_ptr,
out_ptr,
stride_b,
stride_m,
stride_n,
L,
mask_val: tl.constexpr,
BLOCK: tl.constexpr,
num_warps: tl.constexpr,
):
row_block = tl.program_id(0)
col_block = tl.program_id(1)
batch_id = tl.program_id(2)
row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
base = scores_ptr + batch_id * stride_b
offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
future = col_idx[None, :] > row_idx[:, None]
mask_load = in_bounds & ~future
out = tl.load(base + offs, mask=mask_load, other=mask_val, cache_modifier=".ca")
tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".cs")
@triton.jit
def _mask_bwd_kernel(
grad_in_ptr, out_ptr, stride_b, stride_m, stride_n, L, BLOCK: tl.constexpr, num_warps: tl.constexpr
):
row_block = tl.program_id(0)
col_block = tl.program_id(1)
batch_id = tl.program_id(2)
row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
base = grad_in_ptr + batch_id * stride_b
offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
grad_vals = tl.load(base + offs, mask=in_bounds, other=0.0, cache_modifier=".ca")
future = col_idx[None, :] > row_idx[:, None]
zero = tl.zeros(grad_vals.shape, dtype=grad_vals.dtype)
out = tl.where(future, zero, grad_vals)
tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".wb")
def _mask_inf_forward(scores: torch.Tensor) -> torch.Tensor:
*batch, L, _ = scores.shape
N = int(torch.prod(torch.tensor(batch))) if batch else 1
scores_f = scores.view(N, L, L)
out = torch.empty_like(scores_f)
sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
BLOCK_SIZE, num_warps = calculate_settings(L)
grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
_mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=-1e9, BLOCK=BLOCK_SIZE, num_warps=num_warps)
return out.view(*batch, L, L)
def _mask_inf_backward(grad: torch.Tensor) -> torch.Tensor:
*batch, L, _ = grad.shape
N = int(torch.prod(torch.tensor(batch))) if batch else 1
grad_f = grad.view(N, L, L)
out = torch.empty_like(grad_f)
sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
BLOCK_SIZE, num_warps = calculate_settings(L)
grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
_mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
return out.view(*batch, L, L)
def _mask_zero_forward(scores: torch.Tensor) -> torch.Tensor:
*batch, L, _ = scores.shape
N = int(torch.prod(torch.tensor(batch))) if batch else 1
scores_f = scores.view(N, L, L)
out = torch.empty_like(scores_f)
sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
BLOCK_SIZE, num_warps = calculate_settings(L)
grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
_mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=0.0, BLOCK=BLOCK_SIZE, num_warps=num_warps)
return out.view(*batch, L, L)
def _mask_zero_backward(grad: torch.Tensor) -> torch.Tensor:
*batch, L, _ = grad.shape
N = int(torch.prod(torch.tensor(batch))) if batch else 1
grad_f = grad.view(N, L, L)
out = torch.empty_like(grad_f)
sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
BLOCK_SIZE, num_warps = calculate_settings(L)
grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
_mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
return out.view(*batch, L, L)
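# Hedged eager-mode equivalent of the causal-mask kernels above (the name
# `_mask_inf_reference` is ours): strictly-future positions (col > row) are
# filled with the mask value, matching `_mask_fwd_kernel` with mask_val=-1e9.
def _mask_inf_reference(scores: torch.Tensor) -> torch.Tensor:
    L = scores.shape[-1]
    future = torch.triu(torch.ones(L, L, dtype=torch.bool, device=scores.device), diagonal=1)
    return scores.masked_fill(future, -1e9)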
class LigerMultiTokenAttentionFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False):
scores_inf = _mask_inf_forward(scores)
out_flat_sparse = None
activation_output = None
ctx.sparse = sparse
if sparse:
if scores_inf.dtype != torch.float32:
raise RuntimeError("Liger sparse multi-token attention currently only supports fp32 input scores")
probs_sparse, out_flat_sparse = _sparsemax_forward(scores_inf, dim=-1)
activation_output = probs_sparse
ctx.save_for_backward(scores_inf, activation_output, out_flat_sparse, weight, bias)
ctx.out_flat_sparse_saved = True
else:
probs_softmax, _, _, _ = _softmax_forward(scores_inf)
activation_output = probs_softmax
ctx.save_for_backward(scores_inf, activation_output, weight, bias)
ctx.out_flat_sparse_saved = False
out_conv = F.conv2d(
activation_output,
weight,
bias,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out = _mask_zero_forward(out_conv)
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.dim = -1
return out
@staticmethod
@ensure_contiguous
def backward(ctx, grad_out):
if ctx.out_flat_sparse_saved:
scores_inf, activation_output, out_flat_sparse, weight, bias = ctx.saved_tensors
else:
scores_inf, activation_output, weight, bias = ctx.saved_tensors
out_flat_sparse = None
use_sparsemax = ctx.sparse
dim = ctx.dim
stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
grad_conv = _mask_zero_backward(grad_out)
grad_probs = F.conv_transpose2d(
grad_conv, weight, None, stride=stride, padding=padding, dilation=dilation, groups=groups
)
grad_weight = torch.nn.grad.conv2d_weight(
input=activation_output,
weight_size=weight.shape,
grad_output=grad_conv,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
grad_bias = None
if bias is not None:
grad_bias = grad_conv.sum(dim=(0, 2, 3))
grad_scores_inf = None
if use_sparsemax:
if not ctx.out_flat_sparse_saved or out_flat_sparse is None:
raise RuntimeError("Internal error: Sparse flag is set but sparse tensor was not saved.")
grad_scores_inf = _sparsemax_backward(grad_probs, out_flat_sparse, dim=dim)
else:
grad_probs_cont = grad_probs
probs_cont = activation_output
dot = (grad_probs_cont * probs_cont).sum(dim=-1, keepdim=True)
grad_scores_inf = probs_cont * (grad_probs_cont - dot)
grad_scores = _mask_inf_backward(grad_scores_inf)
return (grad_scores, grad_weight, grad_bias, None, None, None, None, None)
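# Hedged usage sketch for the Function above (argument order follows forward();
# `kernel_size` and the padding choice are illustrative assumptions):
# out = LigerMultiTokenAttentionFunction.apply(
#     scores,                  # [B, n_heads, L, L]
#     weight,                  # conv2d weight over the score map
#     bias,                    # [n_heads] or None
#     1,                       # stride
#     (kernel_size - 1) // 2,  # padding (keeps the L x L output)
#     1,                       # dilation
#     1,                       # groups
#     False,                   # sparse=False selects the softmax path
# )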
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import set_large_grf_mode
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _poly_norm_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr, # weight: [3] for [w0, w1, w2]
B_ptr, # bias: scalar
RSTD_ptr, # cache rstd for backward: shape (n_rows, 3)
RSTD_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
PolyNorm formula:
y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
where norm(u) = u / sqrt(mean(u²) + ε)
Reference:
1. https://github.com/BryceZhuo/PolyCom/
2. https://arxiv.org/pdf/2411.03884
Cache rstd values for backward pass
"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# Load pointers
Y_ptr += row_idx * Y_row_stride
X_ptr += row_idx * X_row_stride
RSTD_ptr += row_idx * RSTD_row_stride
# Load input row
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
# Load weights and bias
w0 = tl.load(W_ptr + 0)
w1 = tl.load(W_ptr + 1)
w2 = tl.load(W_ptr + 2)
b = tl.load(B_ptr)
# Compute x³, x², x
X_pow3 = X_row * X_row * X_row
X_pow2 = X_row * X_row
X_pow1 = X_row
# Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
rstd_3 = rsqrt(mean_square_3 + eps)
norm_x3 = X_pow3 * rstd_3
# Compute norm(x²)
mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
rstd_2 = rsqrt(mean_square_2 + eps)
norm_x2 = X_pow2 * rstd_2
# Compute norm(x)
mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
rstd_1 = rsqrt(mean_square_1 + eps)
norm_x1 = X_pow1 * rstd_1
# Cache rstd values for backward
tl.store(RSTD_ptr + 0, rstd_3)
tl.store(RSTD_ptr + 1, rstd_2)
tl.store(RSTD_ptr + 2, rstd_1)
# Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
# Store output
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
@triton.jit
def _poly_norm_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
dW_ptr, # shape: (n_programs, 3)
dW_row_stride,
dB_ptr, # shape: (n_programs,)
n_rows,
n_cols,
rows_per_program: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
PolyNorm Backward Kernel Gradient:
∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
where:
- D_p = RMS(x^p) = 1/rstd_p
- S_p = sum(grad * x^p) over the row
- d = n_cols
- p ∈ {3, 2, 1}
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# Initialize accumulators for weight and bias gradients (scalars)
dW0_acc = 0.0
dW1_acc = 0.0
dW2_acc = 0.0
dB_acc = 0.0
# Load weights
w0 = tl.load(W_ptr + 0).to(tl.float32)
w1 = tl.load(W_ptr + 1).to(tl.float32)
w2 = tl.load(W_ptr + 2).to(tl.float32)
for row_idx in range(row_start, row_end):
dy_base = dY_ptr + row_idx * dY_row_stride
x_base = X_ptr + row_idx * X_row_stride
dx_base = dX_ptr + row_idx * dX_row_stride
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
# Load cached rstd values
rstd_3 = tl.load(rstd_base + 0).to(tl.float32)
rstd_2 = tl.load(rstd_base + 1).to(tl.float32)
rstd_1 = tl.load(rstd_base + 2).to(tl.float32)
# Compute powers
X_pow3 = X_row * X_row * X_row
X_pow2 = X_row * X_row
X_pow1 = X_row
# Accumulate bias gradient: dB = sum(dY)
dB_acc += tl.sum(dY_row, axis=0)
# Compute gradient w.r.t. input using closed-form formula
# For p=3: ∂L/∂x from w0 * norm(x³)
S_3 = tl.sum(dY_row * X_pow3, axis=0) # scalar
        grad_x_3 = w0 * (
            3.0 * X_pow2 * rstd_3 * dY_row
            - (3.0 / n_cols) * X_pow2 * X_pow3 * (rstd_3 * rstd_3 * rstd_3) * S_3
        )
# For p=2: ∂L/∂x from w1 * norm(x²)
S_2 = tl.sum(dY_row * X_pow2, axis=0) # scalar
        grad_x_2 = w1 * (
            2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_pow3 * (rstd_2 * rstd_2 * rstd_2) * S_2
        )
# For p=1: ∂L/∂x from w2 * norm(x)
S_1 = tl.sum(dY_row * X_pow1, axis=0) # scalar
grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)
# Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
dW0_acc += rstd_3 * S_3
dW1_acc += rstd_2 * S_2
dW2_acc += rstd_1 * S_1
# Total gradient
dX_row = grad_x_3 + grad_x_2 + grad_x_1
# Store gradient
tl.store(dx_base + col_offsets, dX_row, mask=mask)
# Store accumulated gradients (scalars)
tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
tl.store(dB_ptr + row_block_id, dB_acc)
def poly_norm_forward(X, W, B, eps=1e-6):
"""
PolyNorm Forward Pass
Args:
X: input tensor of shape (*, H) where H is hidden dimension
W: weight tensor of shape (3,) for [w0, w1, w2]
B: bias scalar tensor
eps: epsilon for numerical stability
Returns:
Y: output tensor of same shape as X
X: reshaped input (for backward)
RSTD: cached rstd values (for backward)
BLOCK_SIZE: block size used
num_warps: number of warps used
"""
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
# RSTD is to cache rstd for each row
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)
# Check constraints
assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
assert B.numel() == 1, "Bias must be a scalar"
# XPU-specific optimization
kernel_args = {}
if X.device.type == "xpu":
set_large_grf_mode(kernel_args)
# Launch kernel
_poly_norm_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
W,
B,
RSTD,
RSTD.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args,
)
return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
"""
PolyNorm Backward Pass
Args:
dY: gradient of output
X: input tensor (already reshaped to 2D)
W: weight tensor
RSTD: cached rstd values from forward
BLOCK_SIZE: block size from forward
num_warps: number of warps from forward
in_place: whether to in-place modify dY to store dX (saves memory)
Returns:
dX: gradient w.r.t. input
dW: gradient w.r.t. weight
dB: gradient w.r.t. bias
"""
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
    # Get number of SMs for parallelization
    sm_count = 1
if X.device.type == "cuda":
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
elif X.device.type == "xpu":
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
elif X.device.type == "npu":
sm_count = get_npu_core_count()
# Allocate or reuse gradients
if in_place is True:
dX = dY
else:
dX = torch.zeros_like(dY)
_dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
_dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
# XPU-specific optimization
kernel_args = {}
if X.device.type == "xpu":
set_large_grf_mode(kernel_args)
# Launch backward kernel
_poly_norm_backward_kernel[grid](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
W,
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
_dB,
n_rows,
n_cols,
rows_per_program,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args,
)
# Reduce gradients across SMs
dX = dX.view(*shape)
dW = _dW.sum(dim=0).to(W.dtype)
dB = _dB.sum().to(W.dtype)
return dX, dW, dB
class LigerPolyNormFunction(torch.autograd.Function):
"""
PolyNorm Function with forward and backward pass
PolyNorm formula:
y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
where norm(u) = u / sqrt(mean(u²) + ε)
Backward uses closed-form gradient:
∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
"""
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, B, eps=1e-6, in_place=True):
"""
Args:
X: input tensor of shape (B, T, H) or (BxT, H)
W: weight tensor of shape (3,) for [w0, w1, w2]
B: bias scalar
eps: epsilon for numerical stability
in_place: whether to in-place modify grad_output in backward (saves memory)
Returns:
Y: output tensor of same shape as X
"""
Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.in_place = in_place
ctx.save_for_backward(X, W, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output):
"""
Args:
grad_output: gradient of output
Returns:
dX, dW, dB: gradients w.r.t. X, W, B
"""
X, W, RSTD = ctx.saved_tensors
dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
return dX, dW, dB, None, None
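# Hedged eager reference of the PolyNorm forward above (the name
# `_poly_norm_reference` is ours; useful for numerical comparison):
def _poly_norm_reference(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    def _norm(u: torch.Tensor) -> torch.Tensor:
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)
    return w[0] * _norm(x**3) + w[1] * _norm(x**2) + w[2] * _norm(x) + b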
import torch
import triton
import triton.language as tl
@triton.jit
def _triton_qwen2vl_mrope(
q_ptr,
k_ptr,
cos,
sin,
sl,
bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
pad_n_qh: tl.constexpr,
pad_n_kh: tl.constexpr,
pad_hd: tl.constexpr,
mrope_section_t: tl.constexpr,
mrope_section_h: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BACKWARD_PASS: tl.constexpr = False,
):
pid = tl.program_id(0)
# locate start address
q_ptr = q_ptr + pid * (n_qh * hd)
k_ptr = k_ptr + pid * (n_kh * hd)
# ####################################################################
# get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
# m of this program instance
# ####################################################################
# 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
# effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
# being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
# and pid % sl to get the sequence index.
# 2. We only need the left half of cos and sin matrix because the right half is just
# a clone of the left half.
t_end = mrope_section_t
h_end = t_end + mrope_section_h
t_cos = cos + pid * hd
h_cos = t_cos + bs * sl * hd
w_cos = h_cos + bs * sl * hd
t_sin = sin + pid * hd
h_sin = t_sin + bs * sl * hd
w_sin = h_sin + bs * sl * hd
cos_offsets = tl.arange(0, pad_hd // 2)
t_mask = cos_offsets < t_end
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
cos_row = t_cos_row + h_cos_row + w_cos_row
sin_row = t_sin_row + h_sin_row + w_sin_row
# ####################################################################
# Load the left and right half of q and k for the current
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
# right half of the head
second_half_q_offsets = first_half_q_offsets + (hd // 2)
second_half_k_offsets = first_half_k_offsets + (hd // 2)
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
if not BACKWARD_PASS:
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
else:
# with some math, we can get:
# dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
# transpose it back to the physical shape because Triton looks at the physical storage
    # note: q and k are non-contiguous before the transformation and will become contiguous after transpose
q = q.transpose(1, 2)
k = k.transpose(1, 2)
batch_size, seq_len, n_q_head, head_dim = q.shape
n_kv_head = k.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
n_row = batch_size * seq_len
# ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
_triton_qwen2vl_mrope[(n_row,)](
q,
k,
cos,
sin,
seq_len,
batch_size,
n_q_head,
n_kv_head,
head_dim,
pad_n_q_head,
pad_n_kv_head,
pad_hd,
mrope_section[0],
mrope_section[1],
BLOCK_SIZE=BLOCK_SIZE,
BACKWARD_PASS=False,
)
return q.transpose(1, 2), k.transpose(1, 2), cos, sin
def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
dq = dq.transpose(1, 2)
dk = dk.transpose(1, 2)
batch_size, seq_len, n_q_head, head_dim = dq.shape
n_kv_head = dk.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
n_row = batch_size * seq_len
# ensure dq and dk are contiguous
dq = dq.contiguous()
dk = dk.contiguous()
    # backward is similar to forward except for swapping a few ops
_triton_qwen2vl_mrope[(n_row,)](
dq,
dk,
cos,
sin,
seq_len,
batch_size,
n_q_head,
n_kv_head,
head_dim,
pad_n_q_head,
pad_n_kv_head,
pad_hd,
mrope_section[0],
mrope_section[1],
BLOCK_SIZE=BLOCK_SIZE,
BACKWARD_PASS=True,
)
return dq.transpose(1, 2), dk.transpose(1, 2)
class LigerQwen2VLMRopeFunction(torch.autograd.Function):
"""
Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.
Please find the corresponding HuggingFace implementation here:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
"""
@staticmethod
def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""
q size: (bsz, n_q_head, seq_len, head_dim)
k size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (3, bsz, seq_len, head_dim)
sin size: (3, bsz, seq_len, head_dim)
"""
q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
ctx.save_for_backward(cos, sin)
ctx.mrope_section = mrope_section
return q, k
    @staticmethod
    def backward(ctx, dq, dk):
"""
dq size: (bsz, n_q_head, seq_len, head_dim)
dk size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (3, bsz, seq_len, head_dim)
sin size: (3, bsz, seq_len, head_dim)
"""
cos, sin = ctx.saved_tensors
mrope_section = ctx.mrope_section
dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
return dq, dk, None, None, None, None
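# Hedged eager-mode sketch of the per-head rotation the kernel applies
# (standard rotate-half RoPE; `cos`/`sin` stand for the section-merged rows
# built in the kernel, and the name `_rotate_half_reference` is ours):
def _rotate_half_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)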
"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.
The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
Modifications made by Yanning Chen, 2024.
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import set_large_grf_mode
from liger_kernel.ops.utils import torch_to_triton_dtype
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
@triton.jit
def _rms_norm_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
offset,
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
Reference:
1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
3. https://arxiv.org/pdf/1910.07467
"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
y_base = Y_ptr + row_idx * Y_row_stride
x_base = X_ptr + row_idx * X_row_stride
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
X_row_dtype = X_row.dtype
if elementwise_affine:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
# For Llama, only rstd is computed in fp32
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(tl.float32)
# Gemma computes everything in fp32, then casts the output back to the original dtype
if casting_mode == _CASTING_MODE_GEMMA:
if elementwise_affine:
W_row = W_row.to(tl.float32)
X_row = X_row.to(tl.float32)
if casting_mode == _CASTING_MODE_NONE:
eps = eps.to(X_row_dtype)
offset = offset.to(X_row_dtype)
mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
rstd = rsqrt(mean_square + eps)
# Caching rstd adds minimal memory overhead (one scalar per row, much
# smaller than X_row) and saves the backward pass 4 operations
# (*, sum, /, sqrt).
tl.store(rstd_base, rstd)
X_row = X_row * rstd
# For Llama, the multiplication with the weight is done in the original dtype
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(X_row_dtype)
if elementwise_affine:
Y_row = X_row * (offset + W_row)
else:
Y_row = X_row
if casting_mode == _CASTING_MODE_GEMMA:
Y_row = Y_row.to(X_row_dtype)
tl.store(y_base + col_offsets, Y_row, mask=mask)
@triton.jit
def _rms_norm_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
rows_per_program,
casting_mode: tl.constexpr,
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
dw = sum(dy * (x / RMS)). summation over the BxT dimension
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
if elementwise_affine:
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
if elementwise_affine:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_row = W_row + offset
for row_idx in range(row_start, row_end):
dy_base = dY_ptr + row_idx * dY_row_stride
dx_base = dX_ptr + row_idx * dX_row_stride
x_base = X_ptr + row_idx * X_row_stride
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
# Get cached rms
rstd_row = tl.load(rstd_base)
X_row = X_row.to(tl.float32)
# Different backward graphs for different casting modes
if casting_mode == _CASTING_MODE_LLAMA:
if elementwise_affine:
m = (dY_row * W_row).to(tl.float32)
else:
m = dY_row.to(tl.float32)
elif casting_mode == _CASTING_MODE_GEMMA:
dY_row = dY_row.to(tl.float32)
if elementwise_affine:
m = dY_row * W_row
else:
m = dY_row
else:
if elementwise_affine:
m = dY_row * W_row
else:
m = dY_row
dX_row = rstd_row * m
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
if elementwise_affine:
# calculate the gradient of W
if casting_mode == _CASTING_MODE_LLAMA:
dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
else:
# here X_row is already in fp32 (see previous if block)
dW_row += dY_row * (X_row * rstd_row)
tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
if elementwise_affine:
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
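# --- Hedged numerical reference (illustrative, not part of the library) ---
# A plain PyTorch restatement of the forward math in the kernels above,
# handy for eyeballing the formulas; the llama/gemma dtype juggling is
# intentionally omitted here.
def _rms_norm_reference(x: torch.Tensor, w: torch.Tensor, eps: float, offset: float) -> torch.Tensor:
    # rstd = 1 / sqrt(mean(x^2) + eps); y = x * rstd * (offset + w)
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rstd * (offset + w.float())).to(x.dtype)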
@triton.jit
def _block_rms_norm_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_rows,
n_cols,
eps,
offset,
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_ROW: tl.constexpr,
):
"""
y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
Reference:
1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
3. https://arxiv.org/pdf/1910.07467
"""
row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
col_offsets = tl.arange(0, BLOCK_SIZE)
row_mask = row_idx < n_rows
col_mask = col_offsets < n_cols
X_row = tl.load(
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
mask=row_mask[:, None] & col_mask[None, :],
other=0,
)
X_row_dtype = X_row.dtype
if elementwise_affine:
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
# For Llama, only rstd is computed in fp32
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(tl.float32)
# Gemma computes everything in fp32, then casts the output back to the original dtype
if casting_mode == _CASTING_MODE_GEMMA:
if elementwise_affine:
W_row = W_row.to(tl.float32)
X_row = X_row.to(tl.float32)
if casting_mode == _CASTING_MODE_NONE:
eps = eps.to(X_row_dtype)
offset = offset.to(X_row_dtype)
mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
rstd = rsqrt(mean_square + eps)
# Caching rstd adds minimal memory overhead (one scalar per row, much
# smaller than X_row) and saves the backward pass 4 operations
# (*, sum, /, sqrt).
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
X_row = X_row * rstd[:, None]
# For Llama, the multiplication with the weight is done in the original dtype
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(X_row_dtype)
if elementwise_affine:
Y_row = X_row * (offset + W_row)[None, :]
else:
Y_row = X_row
if casting_mode == _CASTING_MODE_GEMMA:
Y_row = Y_row.to(X_row_dtype)
tl.store(
Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
Y_row,
mask=row_mask[:, None] & col_mask[None, :],
)
@triton.jit
def _block_rms_norm_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
casting_mode: tl.constexpr,
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_ROW: tl.constexpr,
):
"""
dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
dw = sum(dy * (x / RMS)). summation over the BxT dimension
"""
pid = tl.program_id(0).cast(tl.int64)
NUM_SMS = tl.num_programs(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
col_mask = col_offsets < n_cols
if elementwise_affine:
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
W_row = W_row + offset
for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
row_idx = start + tl.arange(0, BLOCK_ROW)
row_mask = row_idx < n_rows
dY_row = tl.load(
dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
mask=row_mask[:, None] & col_mask[None, :],
other=0.0,
)
X_row = tl.load(
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
mask=row_mask[:, None] & col_mask[None, :],
other=0.0,
)
# Get cached rms
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
X_row = X_row.to(tl.float32)
# Different backward graphs for different casting modes
if casting_mode == _CASTING_MODE_LLAMA:
if elementwise_affine:
m = (dY_row * W_row[None, :]).to(tl.float32)
else:
m = dY_row.to(tl.float32)
elif casting_mode == _CASTING_MODE_GEMMA:
dY_row = dY_row.to(tl.float32)
if elementwise_affine:
m = dY_row * W_row[None, :]
else:
m = dY_row
else:
if elementwise_affine:
m = dY_row * W_row[None, :]
else:
m = dY_row
dX_row = rstd_row[:, None] * m
dX_row += (rstd_row[:, None]) * (
-(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
)
if elementwise_affine:
if casting_mode == _CASTING_MODE_LLAMA:
# TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
else:
# here X_row is already in fp32 (see previous if block)
dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
tl.store(
dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
dX_row,
mask=row_mask[:, None] & col_mask[None, :],
)
if elementwise_affine:
tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
_str_to_casting_mode = {
"llama": _CASTING_MODE_LLAMA.value,
"gemma": _CASTING_MODE_GEMMA.value,
"none": _CASTING_MODE_NONE.value,
}
def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
if not isinstance(casting_mode, int):
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
casting_mode = _str_to_casting_mode[casting_mode]
else:
assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
# RSTD caches the rstd value for each row
# RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
if W is not None:
# Check constraints.
assert X.shape[1] == W.shape[0], (
"Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
)
elementwise_affine = True
else:
elementwise_affine = False
# XPU-specific optimization
kernel_args = {}
if X.device.type == "xpu":
set_large_grf_mode(kernel_args)
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
_rms_norm_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
W,
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
n_cols,
eps,
offset,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
)
else:
BLOCK_ROW = 16
kernel_args["BLOCK_ROW"] = BLOCK_ROW
_block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
Y,
Y.stride(0),
X,
X.stride(0),
W,
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
offset,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
)
return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
sm_count = 1
if X.device.type == "cuda":
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
elif X.device.type == "xpu":
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
elif X.device.type == "npu":
sm_count = get_npu_core_count()
if W is not None:
# accumulate dW in fp32 for numerical stability
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
elementwise_affine = True
else:
_dW = None
elementwise_affine = False
if n_cols > BLOCK_SIZE:
raise RuntimeError("This RMS norm doesn't support feature dim >= 64KB.")
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
if in_place is True:
dX = dY
else:
dX = torch.zeros_like(dY)
# XPU-specific optimization
kernel_args = {}
if X.device.type == "xpu":
set_large_grf_mode(kernel_args)
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
_rms_norm_backward_kernel[grid](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
W,
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0) if elementwise_affine else 0,
n_rows,
n_cols,
offset,
rows_per_program,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
)
else:
BLOCK_ROW = 16
kernel_args["BLOCK_ROW"] = BLOCK_ROW
_block_rms_norm_backward_kernel[grid](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
W,
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0) if elementwise_affine else 0,
n_rows,
n_cols,
offset,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
)
dX = dX.view(*shape)
if elementwise_affine:
dW = _dW.sum(dim=0).to(W.dtype)
else:
dW = None
return dX, dW
class LigerRMSNormFunction(torch.autograd.Function):
"""
Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
weight tensor `W`, with an optional offset and casting mode.
Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
`(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
In addition, different models cast their inputs at different places during RMSNorm computation. For
example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
support the following casting modes (they match HuggingFace Transformers' implementations):
- 'llama': matches the Llama implementation, where only the inverse RMS is computed in fp32.
- 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
- 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has a larger error w.r.t. the original implementation.
The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save memory. However, in certain cases it can produce incorrect gradients.
For example, Gemma2 uses two RMSNorms sequentially with a residual in between. The residual branch still needs dY, so dY cannot be modified in place.
Therefore, for the patching of RMSNorm in Gemma2, we set `in_place` to `False`.
"""
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
"""
X: (B, T, H) or (BxT, H)
W: (H,)
"""
if isinstance(X, torch.distributed.tensor.DTensor):
# Input tensor is the output of a tensor parallel module and
# needs to be gathered to a local tensor to compute
# RMSNorm on each TP worker.
# TODO: support CP.
X = X.full_tensor()
Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
ctx.offset = offset
ctx.casting_mode = casting_mode
ctx.in_place = in_place
ctx.row_mode = row_mode
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.elementwise_affine = W is not None
if W is not None:
ctx.save_for_backward(X, W, RSTD)
else:
ctx.save_for_backward(X, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
"""
Y: (B, T, H) or (BxT, H)
"""
if ctx.elementwise_affine:
X, W, RSTD = ctx.saved_tensors
else:
X, RSTD = ctx.saved_tensors
W = None
if isinstance(dY, torch.distributed.tensor.DTensor):
# Gradients are the output of a tensor parallel module and
# need to be gathered to a local tensor to compute the RMSNorm backward.
# TODO: support CP.
dY = dY.full_tensor()
dX, dW = rms_norm_backward(
dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
)
return dX, dW, None, None, None, None, None
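# --- Hedged usage sketch (illustrative, not part of the library) ---
# Applies the autograd function directly and cross-checks it against a naive
# fp32 computation. Shapes, eps, and tolerances are illustrative.
def _rms_norm_usage_example():
    x = torch.randn(2, 8, 64, device="cuda", requires_grad=True)
    w = torch.randn(64, device="cuda", requires_grad=True)
    y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False)
    y_ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w
    torch.testing.assert_close(y, y_ref, rtol=1e-4, atol=1e-4)
    y.sum().backward()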
import torch
import triton
import triton.language as tl
@triton.jit
def _triton_rope(
q_ptr,
q_row_stride,
k_ptr,
k_row_stride,
cos,
cos_row_stride,
sin,
sin_row_stride,
sl,
bs: tl.constexpr,
cos_bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
pad_n_qh: tl.constexpr,
pad_n_kh: tl.constexpr,
pad_hd: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BACKWARD_PASS: tl.constexpr = False,
):
# q size: (bsz, seq_len, num_q_heads, head_dim)
# q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
# k size: (bsz, seq_len, num_kv_heads, head_dim)
# k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
# stride: (seq_len * head_dim, head_dim, 1)
pid = tl.program_id(0).to(tl.int64)
# locate start address
q_ptr = q_ptr + pid * q_row_stride
k_ptr = k_ptr + pid * k_row_stride
# ####################################################################
# get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
# m of this program instance
# ####################################################################
# 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
# effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
# being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
# and pid % sl to get the sequence index.
# 2. We only need the left half of cos and sin matrix because the right half is just
# a clone of the left half.
batch_idx = pid // sl
cos_row_idx = pid % sl
cos = cos + tl.where(
cos_bs == 1,
cos_row_idx * cos_row_stride,
batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
)
sin = sin + tl.where(
cos_bs == 1,
cos_row_idx * sin_row_stride,
batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
)
cos_offsets = tl.arange(0, pad_hd // 2)
cos_mask = cos_offsets < hd // 2
cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)
# ####################################################################
# Load the left and right half of q and k for the current
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
# right half of the head
second_half_q_offsets = first_half_q_offsets + (hd // 2)
second_half_k_offsets = first_half_k_offsets + (hd // 2)
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
if not BACKWARD_PASS:
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
else:
# with some math, we can get:
# dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
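# --- Hedged reference (illustrative, not part of the library) ---
# The per-token rotation performed above, written out in PyTorch with the same
# half-split ("rotate_half") convention; cos/sin here carry only the left half.
def _rope_rotate_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., head_dim); cos/sin: (..., head_dim // 2)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)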
def rope_forward(q, k, cos, sin):
# transpose it back to the physical shape because Triton looks at the physical storage
# note: q and k are non-contiguous before the transformation and become contiguous after the transpose
q = q.transpose(1, 2)
k = k.transpose(1, 2)
batch_size, seq_len, n_q_head, head_dim = q.shape
n_kv_head = k.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
n_row = batch_size * seq_len
# ensure tensors passed into the kernel are contiguous; this is a no-op if they are already contiguous
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
cos_batch_size = cos.shape[0]
_triton_rope[(n_row,)](
q,
q.stride(1),
k,
k.stride(1),
cos,
cos.stride(-2),
sin,
sin.stride(-2),
seq_len,
batch_size,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
pad_n_q_head,
pad_n_kv_head,
pad_hd,
BLOCK_SIZE=BLOCK_SIZE,
BACKWARD_PASS=False,
)
return q.transpose(1, 2), k.transpose(1, 2), cos, sin
def rope_backward(dq, dk, cos, sin):
dq = dq.transpose(1, 2)
dk = dk.transpose(1, 2)
batch_size, seq_len, n_q_head, head_dim = dq.shape
cos_batch_size = cos.shape[0]
n_kv_head = dk.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
n_row = batch_size * seq_len
# ensure dq and dk are contiguous
dq = dq.contiguous()
dk = dk.contiguous()
# backward is similar to forward except for swapping a few ops
_triton_rope[(n_row,)](
dq,
dq.stride(1),
dk,
dk.stride(1),
cos,
cos.stride(-2),
sin,
sin.stride(-2),
seq_len,
batch_size,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
pad_n_q_head,
pad_n_kv_head,
pad_hd,
BLOCK_SIZE=BLOCK_SIZE,
BACKWARD_PASS=True,
)
return dq.transpose(1, 2), dk.transpose(1, 2)
class LigerRopeFunction(torch.autograd.Function):
"""
Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly
different from the one in the original RoPE paper.
Please find the corresponding HuggingFace implementation here:
https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184
For more details about the rotation matrix used here, please refer to:
https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
"""
@staticmethod
def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
q size: (bsz, n_q_head, seq_len, head_dim)
k size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""
q, k, cos, sin = rope_forward(q, k, cos, sin)
ctx.save_for_backward(cos, sin)
return q, k
@staticmethod
def backward(ctx, dq, dk):
"""
dq size: (bsz, n_q_head, seq_len, head_dim)
dk size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""
cos, sin = ctx.saved_tensors
dq, dk = rope_backward(dq, dk, cos, sin)
return dq, dk, None, None, None, None
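# --- Hedged usage sketch (illustrative, not part of the library) ---
def _rope_usage_example():
    bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 16, 64
    q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", requires_grad=True)
    k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", requires_grad=True)
    cos = torch.randn(1, seq_len, head_dim, device="cuda")
    sin = torch.randn(1, seq_len, head_dim, device="cuda")
    q_out, k_out = LigerRopeFunction.apply(q, k, cos, sin)
    (q_out.sum() + k_out.sum()).backward()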
from typing import Tuple
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous
@triton.jit
def _softmax_single_block_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
row_id = tl.program_id(0)
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < n_cols
x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
m = tl.max(x, axis=0)
e = tl.exp(x - m)
d = tl.sum(e, axis=0)
y = e / d
tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
@triton.jit
def _softmax_multi_block_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
row_id = tl.program_id(0)
offs = tl.arange(0, BLOCK_SIZE)
m = tl.float32(-float("inf"))
d = tl.float32(0.0)
for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + offs
mask = idx < n_cols
xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
blk_max = tl.max(xblk, axis=0)
new_m = tl.max(m, blk_max)
d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
m = new_m
for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + offs
mask = idx < n_cols
xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
yblk = tl.exp(xblk - m) / d
tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
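# --- Hedged reference (illustrative, not part of the library) ---
# The two-pass online-softmax recurrence used by the multi-block kernel above,
# restated over plain Python chunks: whenever a new block raises the running
# max m, the accumulated denominator d is rescaled by exp(m_old - m_new).
def _online_softmax_reference(row, block_size):
    import math
    m, d = -math.inf, 0.0
    for start in range(0, len(row), block_size):
        blk = row[start : start + block_size]
        new_m = max(m, max(blk))
        d = d * math.exp(m - new_m) + sum(math.exp(v - new_m) for v in blk)
        m = new_m
    return [math.exp(v - m) / d for v in row]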
@triton.jit
def _softmax_single_block_backward_kernel(
dy_ptr,
dy_stride,
y_ptr,
y_stride,
dx_ptr,
dx_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
row_id = tl.program_id(0)
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < n_cols
dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
dot = tl.sum(dy * y, axis=0)
dx = y * (dy - dot)
tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
@triton.jit
def _softmax_multi_block_backward_kernel(
dy_ptr,
dy_stride,
y_ptr,
y_stride,
dx_ptr,
dx_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
row_id = tl.program_id(0)
offs = tl.arange(0, BLOCK_SIZE)
acc = tl.float32(0.0)
for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + offs
mask = idx < n_cols
dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
acc += tl.sum(dy_blk * y_blk, axis=0)
for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + offs
mask = idx < n_cols
dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
dx_blk = y_blk * (dy_blk - acc)
tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
*batch, n_cols = x.shape
x2d = x.contiguous().view(-1, n_cols)
n_rows = x2d.shape[0]
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
y2d = torch.empty_like(x2d)
if n_cols <= BLOCK_SIZE:
_softmax_single_block_forward_kernel[(n_rows,)](
y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
multi_block_launch = False
else:
_softmax_multi_block_forward_kernel[(n_rows,)](
y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
multi_block_launch = True
return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
def _softmax_backward(
dy: torch.Tensor,
y: torch.Tensor,
BLOCK_SIZE: int,
num_warps: int,
multi_block_launch: bool,
) -> torch.Tensor:
*batch, n_cols = dy.shape
dy2d = dy.contiguous().view(-1, n_cols)
y2d = y.contiguous().view(-1, n_cols)
n_rows = dy2d.shape[0]
dx2d = torch.empty_like(dy2d)
if not multi_block_launch and n_cols <= BLOCK_SIZE:
_softmax_single_block_backward_kernel[(n_rows,)](
dy2d,
dy2d.stride(0),
y2d,
y2d.stride(0),
dx2d,
dx2d.stride(0),
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
else:
_softmax_multi_block_backward_kernel[(n_rows,)](
dy2d,
dy2d.stride(0),
y2d,
y2d.stride(0),
dx2d,
dx2d.stride(0),
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return dx2d.view(*batch, n_cols)
class LigerSoftmaxFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, input_: torch.Tensor):
y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_)
ctx.save_for_backward(y)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.multi_block_launch = multi_block_launch
return y
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output):
(y,) = ctx.saved_tensors
dx = _softmax_backward(
grad_output,
y,
ctx.BLOCK_SIZE,
ctx.num_warps,
ctx.multi_block_launch,
)
return dx
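# --- Hedged usage sketch (illustrative, not part of the library) ---
def _softmax_usage_example():
    x = torch.randn(4, 1024, device="cuda", requires_grad=True)
    y = LigerSoftmaxFunction.apply(x)
    torch.testing.assert_close(y, torch.softmax(x, dim=-1), rtol=1e-4, atol=1e-5)
    y.pow(2).sum().backward()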
from typing import Tuple
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous
@triton.jit
def _sparsemax_forward_kernel(
x_ptr,
x_stride_row,
sorted_x_ptr,
sorted_x_stride_row,
o_ptr,
o_stride_row,
n_cols,
BLOCK_SIZE: tl.constexpr,
num_warps: tl.constexpr,
):
pid_row = tl.program_id(0)
ptr_x_data_row = x_ptr + pid_row * x_stride_row
ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
ptr_output_row = o_ptr + pid_row * o_stride_row
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < n_cols
z_sorted_block = tl.load(
ptr_sorted_x_data_row + offs,
mask=mask,
other=-float("inf"),
cache_modifier=".cg",
).to(tl.float32)
z_valid = tl.where(mask, z_sorted_block, 0.0)
cssv = tl.cumsum(z_valid, 0)
r = (offs + 1).to(tl.float32)
t_vec = (cssv - 1.0) / r
support = (z_sorted_block > t_vec) & mask
k_int = tl.sum(support.to(tl.int32), 0)
k_clamped_int = tl.maximum(k_int, 1)
k = k_clamped_int.to(tl.float32)
s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)
tau = (s - 1.0) / k
x_block = tl.load(
ptr_x_data_row + offs,
mask=mask,
other=0.0,
cache_modifier=".cg",
).to(tl.float32)
y = tl.maximum(x_block - tau, 0.0)
tl.store(
ptr_output_row + offs,
y.to(ptr_output_row.dtype.element_ty),
mask=mask,
cache_modifier=".cs",
)
@triton.jit
def _sparsemax_backward_kernel(
o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
row = tl.program_id(0)
o_row = o_ptr + row * stride
go_row = go_ptr + row * stride
gi_row = gi_ptr + row * stride
offs = tl.arange(0, BLOCK_SIZE)
supp_cnt = tl.zeros((), tl.float32)
go_sum = tl.zeros((), tl.float32)
for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
offs_iter = i * BLOCK_SIZE + offs
mask_iter = offs_iter < n_cols
o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
supp = o_val > 0.0
go_sum += tl.sum(tl.where(supp, go_val, 0.0))
supp_cnt += tl.sum(supp.to(tl.float32))
for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
offs_iter = i * BLOCK_SIZE + offs
mask_iter = offs_iter < n_cols
o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
supp = o_val > 0.0
gi_val = tl.where(
supp,
go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
0.0,
)
tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
if dim < 0:
dim += x.dim()
x_sw = x.transpose(dim, -1).contiguous()
n_cols = x_sw.size(-1)
n_rows = x_sw.numel() // n_cols
x_flat = x_sw.view(n_rows, n_cols)
x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
out_flat = torch.empty_like(x_flat)
grid = (n_rows,)
_sparsemax_forward_kernel[grid](
x_flat,
x_flat.stride(0),
x_sorted_flat,
x_sorted_flat.stride(0),
out_flat,
out_flat.stride(0),
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
y = out_flat.view_as(x_sw).transpose(dim, -1)
return y, out_flat
def _sparsemax_backward(
grad_out: torch.Tensor,
out_flat: torch.Tensor,
dim: int,
) -> torch.Tensor:
grad_sw = grad_out.transpose(dim, -1).contiguous()
n_cols = grad_sw.size(-1)
n_rows = grad_sw.numel() // n_cols
go_flat = grad_sw.view(n_rows, n_cols)
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
dx_flat = torch.empty_like(go_flat)
grid = (n_rows,)
_sparsemax_backward_kernel[grid](
out_flat,
go_flat,
dx_flat,
out_flat.stride(0),
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dx = dx_flat.view_as(grad_sw).transpose(dim, -1)
return dx
class LigerSparsemaxFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, x: torch.Tensor, dim: int):
y, out_flat = _sparsemax_forward(x, dim)
ctx.save_for_backward(out_flat)
ctx.dim = dim
return y
@staticmethod
@ensure_contiguous
def backward(ctx, grad_out: torch.Tensor):
(out_flat,) = ctx.saved_tensors
dx = _sparsemax_backward(grad_out, out_flat, ctx.dim)
return dx, None
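# --- Hedged usage sketch (illustrative, not part of the library) ---
# sparsemax projects each row onto the probability simplex; unlike softmax it
# can assign exact zeros, and every row still sums to 1.
def _sparsemax_usage_example():
    x = torch.randn(4, 16, device="cuda", requires_grad=True)
    y = LigerSparsemaxFunction.apply(x, -1)
    torch.testing.assert_close(y.sum(dim=-1), torch.ones(4, device="cuda"))
    y.pow(2).sum().backward()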
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous
@triton.jit
def silu(x):
return x * tl.sigmoid(x)
@triton.jit
def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
program_id = tl.program_id(0).to(tl.int64)
# locate start index
a_ptr += program_id * stride
b_ptr += program_id * stride
c_ptr += program_id * stride
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# sigmoid requires type float32
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
c_row = silu(a_row).cast(b_row.dtype) * b_row
tl.store(c_ptr + col_offsets, c_row, mask=mask)
@triton.jit
def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
program_id = tl.program_id(0).to(tl.int64)
# locate start index
dc_ptr += program_id * stride
a_ptr += program_id * stride
b_ptr += program_id * stride
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
# sigmoid requires type float32
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
# recomputation to save memory
sig_a = tl.sigmoid(a_row)
silu_a = a_row * sig_a
db_row = dc_row * silu_a
da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
tl.store(a_ptr + col_offsets, da_row, mask=mask)
tl.store(b_ptr + col_offsets, db_row, mask=mask)
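# --- Hedged derivation note (illustrative, not part of the library) ---
# The da_row expression above follows from
#   d/da [a * sigmoid(a)] = sigmoid(a) + a * sigmoid(a) * (1 - sigmoid(a))
#                         = sigmoid(a) + silu(a) * (1 - sigmoid(a)),
# which is exactly `silu_a * (1 - sig_a) + sig_a` in the kernel. In PyTorch:
def _swiglu_grad_reference(a: torch.Tensor, b: torch.Tensor, dc: torch.Tensor):
    sig = torch.sigmoid(a)
    silu = a * sig
    return dc * (silu * (1 - sig) + sig) * b, dc * silu  # (da, db)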
def swiglu_forward(a, b):
ori_shape = a.shape
n_cols = ori_shape[-1]
a = a.view(-1, n_cols)
b = b.view(-1, n_cols)
c = torch.empty_like(a)
n_rows = a.shape[0]
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
_swiglu_forward_kernel[(n_rows,)](
a,
b,
c,
c.stride(-2),
n_cols=n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return a, b, c.view(*ori_shape)
def swiglu_backward(a, b, dc):
ori_shape = dc.shape
n_cols = ori_shape[-1]
dc = dc.view(-1, n_cols)
n_rows = dc.shape[0]
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
_swiglu_backward_kernel[(n_rows,)](
dc,
a,
b,
dc.stride(-2),
n_cols=n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return a.view(*ori_shape), b.view(*ori_shape)
class LigerSiLUMulFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, a, b):
if isinstance(a, torch.distributed.tensor.DTensor) or isinstance(b, torch.distributed.tensor.DTensor):
device_mesh, placements = (
(a.device_mesh, a.placements)
if isinstance(a, torch.distributed.tensor.DTensor)
else (b.device_mesh, b.placements)
)
# Assume that full tensors have been gathered beforehand and are
# identical across the associated process groups.
if not isinstance(a, torch.distributed.tensor.DTensor):
a = torch.distributed.tensor.distribute_tensor(a, device_mesh=device_mesh, placements=placements)
if not isinstance(b, torch.distributed.tensor.DTensor):
b = torch.distributed.tensor.distribute_tensor(b, device_mesh=device_mesh, placements=placements)
a_local, b_local, c_local = swiglu_forward(a.to_local(), b.to_local())
ctx.save_for_backward(a_local, b_local)
ctx.dtensor_metadata = (device_mesh, placements)
return torch.distributed.tensor.DTensor.from_local(c_local, device_mesh, placements)
else:
a, b, c = swiglu_forward(a, b)
ctx.save_for_backward(a, b)
ctx.dtensor_metadata = None
return c
@staticmethod
@ensure_contiguous
def backward(ctx, dc):
a, b = ctx.saved_tensors
if ctx.dtensor_metadata is not None:
device_mesh, placements = ctx.dtensor_metadata
# Assume that full tensors have been gathered beforehand and are
# identical across the associated process groups.
dc_local = (
dc.to_local()
if isinstance(dc, torch.distributed.tensor.DTensor)
else torch.distributed.tensor.distribute_tensor(dc, device_mesh=device_mesh, placements=placements)
)
a_local, b_local = swiglu_backward(a, b, dc_local)
return (
torch.distributed.tensor.DTensor.from_local(a_local, device_mesh, placements),
torch.distributed.tensor.DTensor.from_local(b_local, device_mesh, placements),
)
a, b = swiglu_backward(a, b, dc)
return a, b
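# --- Hedged usage sketch (illustrative, not part of the library) ---
# Note that the backward kernel overwrites `a` and `b` in place with da/db, so
# the gate and up projections must not be reused after backward.
def _swiglu_usage_example():
    a = torch.randn(2, 16, 64, device="cuda", requires_grad=True)
    b = torch.randn(2, 16, 64, device="cuda", requires_grad=True)
    c = LigerSiLUMulFunction.apply(a, b)
    torch.testing.assert_close(c, torch.nn.functional.silu(a) * b, rtol=1e-4, atol=1e-5)
    c.sum().backward()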
import math
from typing import Callable
from typing import List
from typing import Optional
import torch
from liger_kernel.ops.utils import ensure_contiguous
class LigerTiledMLPFunction(torch.autograd.Function):
"""
Based on DeepSpeed's TiledMLP:
https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
Performs a tiled MLP computation to massively reduce the memory needed to compute the MLP
when using very long sequence lengths.
This module re-computes `forward` during `backward`, so `forward` runs twice each iteration
(and three times if activation checkpointing is also in use).
Args:
fn: the function to call on sharded inputs (e.g., mlp.forward)
mlp_module: the MLP nn.Module object
x: the input to MLP.forward (hidden_states)
shards: how many shards to use
compute_params: a list of weights engaged in the compute
Returns:
the computed hidden_states
"""
@staticmethod
@ensure_contiguous
def forward(
ctx,
fn: Callable,
mlp_module: torch.nn.Module,
x: torch.Tensor,
shards: int,
compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
ctx.fn = fn
ctx.mlp_module = mlp_module
ctx.shards = shards
ctx.save_for_backward(x)
# x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
with torch.no_grad():
output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=-2)
return output_unsharded
@staticmethod
@ensure_contiguous
def backward(ctx, *grads) -> tuple:
fn = ctx.fn
(x,) = ctx.saved_tensors
mlp_module = ctx.mlp_module
shards = ctx.shards
x_requires_grad = x.requires_grad
x = x.detach()
# detach() unsets x.requires_grad, so restore it
x.requires_grad_(x_requires_grad)
# x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
hidden_size = x.shape[-1]
x_shape_orig = x.shape
# flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
x = x.view(-1, hidden_size)
incoming_grad = grads[0].view(-1, hidden_size)
x_grad = torch.zeros_like(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=0))
for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)
# if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
shard_step = x_shards[i].shape[0]
shard_offset = i * x_shards[0].shape[0]
x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
with torch.enable_grad():
output = fn(mlp_module, x_shard)
torch.autograd.backward(output, incoming_grad_shard)
# unflatten
x_grad = x_grad.view(x_shape_orig)
return (None, None, x_grad, None, None)
def apply_tiled_mlp(
fn: Callable,
mlp_module: torch.nn.Module,
x: torch.Tensor,
num_shards: Optional[int] = None,
compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
"""
Apply tiled MLP computation for memory efficiency.
Args:
fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
mlp_module: the MLP nn.Module object
x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
compute_params: list of parameters for DeepSpeed ZeRO optimization
Returns:
output tensor with the same shape as input
"""
if num_shards is None:
# x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
hidden_size = x.shape[-1]
seqlen = x.shape[-2]
num_shards = math.ceil(seqlen / hidden_size)
# Ensure num_shards is at least 1
num_shards = max(1, num_shards)
return LigerTiledMLPFunction.apply(
fn,
mlp_module,
x,
num_shards,
compute_params,
)
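# --- Hedged usage sketch (illustrative, not part of the library) ---
# Shards a toy MLP over the sequence dimension. `fn` receives the module and
# one input shard; here it simply calls the module on the shard.
def _tiled_mlp_usage_example():
    mlp = torch.nn.Sequential(
        torch.nn.Linear(64, 256), torch.nn.GELU(), torch.nn.Linear(256, 64)
    ).cuda()
    x = torch.randn(2, 128, 64, device="cuda", requires_grad=True)
    out = apply_tiled_mlp(lambda module, shard: module(shard), mlp, x, num_shards=4)
    out.sum().backward()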
from typing import Literal
from typing import Optional
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import ensure_contiguous
MAX_FUSED_SIZE = 65536 // 4
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
_str_to_reduction_mode = {
"none": _REDUCTION_MODE_NONE.value,
"sum": _REDUCTION_MODE_SUM.value,
"mean": _REDUCTION_MODE_MEAN.value,
"batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}
def get_num_warps(BLOCK_SIZE):
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32
elif BLOCK_SIZE >= 8192:
num_warps = 16
elif BLOCK_SIZE >= 2048:
num_warps = 8
return num_warps
@triton.jit
def _tv_distance_kernel(
p_ptr,
p_stride,
q_ptr,
q_stride,
loss_ptr,
loss_stride,
grads_ptr,
grads_stride,
label_ptr,
ignore_index: tl.constexpr,
n_cols,
scale, # pre-computed reduction scale for gradients (fused into kernel)
BLOCK_SIZE: tl.constexpr,
HAS_LABEL: tl.constexpr,
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
pid = tl.program_id(0).to(tl.int64)
p_ptr += pid * p_stride
q_ptr += pid * q_stride
loss_ptr += pid * loss_stride
grads_ptr += pid * grads_stride
label_ptr += pid
base_offsets = tl.arange(0, BLOCK_SIZE)
if HAS_LABEL:
label = tl.load(label_ptr)
if label == ignore_index:
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + base_offsets
mask = offsets < n_cols
tl.store(grads_ptr + offsets, 0.0, mask=mask)
if reduction == _REDUCTION_MODE_NONE:
tl.store(loss_ptr + offsets, 0.0, mask=mask)
return
loss_sum = 0.0
for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + base_offsets
mask = offsets < n_cols
p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
# TVD(P || Q) = 0.5 * |P - Q|
tv_loss = 0.5 * tl.abs(p - q)
# Fuse reduction scaling into gradient computation (eliminates separate Python division)
grad_res = tl.where(p > q, 0.5 * scale, -0.5 * scale)
tl.store(grads_ptr + offsets, grad_res, mask=mask)
if reduction == _REDUCTION_MODE_NONE:
tl.store(loss_ptr + offsets, tv_loss, mask=mask)
else:
loss_sum += tl.sum(tv_loss, axis=0)
if reduction != _REDUCTION_MODE_NONE:
# Fuse reduction scaling into loss (same scale as gradients; avoids Python division)
tl.store(loss_ptr, loss_sum * scale)
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
BT, V = p.shape
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
num_warps = get_num_warps(BLOCK_SIZE)
grid = (BT,)
reduction = _str_to_reduction_mode[reduction]
out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
grads = torch.empty_like(p)
n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
# Pre-compute gradient scale factor (fused into kernel to avoid separate division)
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
scale = 1.0 / n_non_ignore
elif reduction == _REDUCTION_MODE_MEAN.value:
scale = 1.0 / (n_non_ignore * V)
else:
scale = 1.0
_tv_distance_kernel[grid](
p,
p.stride(0),
q,
q.stride(0),
output_tensor,
output_tensor.stride(0),
grads,
grads.stride(0),
shift_labels if has_label else torch.empty(1, device=p.device),
ignore_index,
V,
scale,
BLOCK_SIZE=BLOCK_SIZE,
HAS_LABEL=has_label,
num_warps=num_warps,
reduction=reduction,
)
# Loss and gradients are already scaled inside the kernel — no separate division needed
if reduction in (_REDUCTION_MODE_BATCHMEAN.value, _REDUCTION_MODE_MEAN.value):
return output_tensor.sum(), grads
elif reduction == _REDUCTION_MODE_SUM.value:
return output_tensor.sum(dim=0), grads
else:
return output_tensor, grads
def tvd_backward_triton(grad_output, grads):
# If this loss is the final layer, grad_output is 1.0; skip the multiplication in that case.
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return grads
return grads * grad_output
class LigerTVDLossFunction(torch.autograd.Function):
"""
Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
"""
@staticmethod
@ensure_contiguous
def forward(
ctx,
p: torch.Tensor,
q: torch.Tensor,
shift_labels: Optional[torch.Tensor] = None,
reduction: REDUCTION_LITERAL = "batchmean",
ignore_index: int = -100,
) -> torch.Tensor:
"""A forward pass for the Total Variation Distance Loss.
Args:
ctx: Torch autograd context
p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
Returns:
torch.Tensor: The computed Total Variation Distance Loss.
"""
has_label = False
if shift_labels is not None:
assert shift_labels.shape == (p.shape[0],), (
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
)
shift_labels = shift_labels.contiguous()
has_label = True
loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
ctx.save_for_backward(grads)
return loss
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
"""A backward pass for the Total Variation Distance Loss.
Args:
ctx: Torch autograd context
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
Returns:
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
"""
(grads,) = ctx.saved_tensors
grads = tvd_backward_triton(grad_output, grads)
return grads, None, None, None, None
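# --- Hedged usage sketch (illustrative, not part of the library) ---
# p and q are expected to already be probability distributions over V.
def _tvd_usage_example():
    p = torch.softmax(torch.randn(8, 32, device="cuda"), dim=-1).requires_grad_(True)
    q = torch.softmax(torch.randn(8, 32, device="cuda"), dim=-1)
    loss = LigerTVDLossFunction.apply(p, q)  # reduction defaults to "batchmean"
    loss.backward()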