from typing import Optional
import torch
from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
def k3_loss_fn(log_p, log_q):
# computes k3 estimate of KL[q, p]
# ref: http://joschu.net/blog/kl-approx.html
return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
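# Sanity note: with r = exp(log_p - log_q), k3 = r - log(r) - 1 >= 0 pointwise,
# and for samples drawn from q its expectation is E_q[p/q] - E_q[log(p/q)] - 1
# = KL(q || p), which is what makes k3 an always-non-negative, low-variance
# KL estimator.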
def sapo_loss_fn(importance_ratio: torch.Tensor, temperature: float) -> torch.Tensor:
"""SAPO (Soft Adaptive Policy Optimization) loss function.
Replaces hard clipping with a smooth, temperature-controlled gate that
adaptively attenuates off-policy updates while preserving useful learning signals.
Reference: https://huggingface.co/papers/2511.20347
TRL implementation: https://github.com/huggingface/trl/blob/1bd2a52ec2d8344050af736d60cdc735181ae4b8/trl/trainer/grpo_trainer.py#L1913
Args:
importance_ratio: The importance sampling ratio (pi_theta / pi_old).
temperature: Temperature parameter controlling the softness of the gate.
Returns:
The SAPO loss value.
"""
if temperature <= 0:
raise ValueError("sapo_temperature must be > 0.")
sigmoid_input = temperature * (importance_ratio - 1)
sigmoid_smoothed_loss = torch.sigmoid(sigmoid_input)
return sigmoid_smoothed_loss * 4 / temperature
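# Quick check of the gate (illustrative): at importance_ratio == 1 the sigmoid
# sits at 0.5, so the returned value is 2 / temperature regardless of the
# direction of deviation, e.g.:
#   >>> sapo_loss_fn(torch.ones(3), temperature=1.0)
#   tensor([2., 2., 2.])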
def clip_coef_fn(coef, epsilon_low, epsilon_high, loss_type):
if loss_type == "cispo":
# CISPO: clip and detach the importance weights
upper_bound = epsilon_high
lower_bound = None
clipped_coef = torch.clamp(coef, lower_bound, upper_bound).detach()
is_lower_clipped = False
is_upper_clipped = coef > upper_bound
elif loss_type == "sapo":
# SAPO doesn't use clipping metrics
clipped_coef = None
is_lower_clipped = torch.zeros_like(coef, dtype=torch.bool)
is_upper_clipped = torch.zeros_like(coef, dtype=torch.bool)
else:
upper_bound = 1 + epsilon_high
lower_bound = 1 - epsilon_low
clipped_coef = torch.clamp(coef, lower_bound, upper_bound)
is_lower_clipped = coef < lower_bound
is_upper_clipped = coef > upper_bound
return clipped_coef, is_lower_clipped, is_upper_clipped
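# Example (illustrative): for the default PPO-style branch with epsilon_low=0.2
# and epsilon_high=0.28 (a common DAPO "clip-higher" setting), coef is clamped
# to [0.8, 1.28]; for "cispo" only the upper bound epsilon_high is applied and
# the result is detached, so no gradient flows through the clipped weights.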
class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
@staticmethod
def ppo_loss_fn(
log_probs,
selected_token_ids,
attention_mask,
advantages,
full_attention_mask,
ref_per_token_logps=None, # shape: [chunk_size, seq_len]
old_per_token_logps=None,
ref_log_probs=None, # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
epsilon_low=0.2,
epsilon_high=0.2,
beta=0.04,
loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"]
max_completion_length=None, # Required for dr_grpo
importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
sapo_temperature_pos=1.0, # Temperature for positive advantages in SAPO
sapo_temperature_neg=1.05, # Temperature for negative advantages in SAPO
vllm_is_ratio=None, # vLLM importance sampling ratio (chunk_size, seq_len) or (chunk_size, 1) or None
delta=None, # Upper clamp for two-sided clipping (INTELLECT-2)
use_bias_correction_kl=False, # Importance-sampling-corrected KL (DeepSeek-V3.2)
**kwargs,
):
"""GRPO Loss Function matching GRPOTrainer implementation."""
# Validate sequence-level + loss_type combinations
if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"):
raise ValueError(
f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. "
f"Use importance_sampling_level='token' instead."
)
per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
-1
) # (batch_size, seq_len)
# Get reference model probabilities
if ref_per_token_logps is None:
if ref_log_probs is not None:
with torch.no_grad():
ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
-1
)
else:
ref_per_token_logps = per_token_logps.detach()
# Compute policy gradient loss with importance sampling ratio
old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
log_ratio = per_token_logps - old_per_token_logps
if importance_sampling_level == "token":
log_importance_weights = log_ratio
elif importance_sampling_level == "sequence":
log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
log_importance_weights = log_importance_weights.unsqueeze(-1)
else:
raise ValueError(
f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
"and 'sequence'."
)
# From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
# importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
coef_1 = torch.exp(log_importance_weights)
coef_2, is_lower_clipped, is_upper_clipped = clip_coef_fn(coef_1, epsilon_low, epsilon_high, loss_type)
if loss_type == "cispo":
# CISPO: clip and detach the importance weights, multiply by log probs
# Reference: https://github.com/huggingface/trl/blob/035c3ff151b953ca72cdfe0ee966bc1469a26fde/trl/trainer/grpo_trainer.py#L2030
per_token_loss = -coef_2 * advantages.unsqueeze(1) * per_token_logps
elif loss_type == "sapo":
# SAPO: Soft Adaptive Policy Optimization
# Uses sigmoid-based soft gating instead of hard clipping
# Reference: https://huggingface.co/papers/2511.20347
# TRL implementation: https://github.com/huggingface/trl/blob/1bd2a52ec2d8344050af736d60cdc735181ae4b8/trl/trainer/grpo_trainer.py#L2037-L2046
per_token_loss = torch.empty_like(coef_1)
# Expand advantages to match coef_1 shape for masking
advantages_expanded = advantages.unsqueeze(1).expand_as(coef_1)
positive_advantages_mask = advantages_expanded > 0
# Apply different temperatures based on advantage sign
per_token_loss[positive_advantages_mask] = sapo_loss_fn(
coef_1[positive_advantages_mask], sapo_temperature_pos
)
per_token_loss[~positive_advantages_mask] = sapo_loss_fn(
coef_1[~positive_advantages_mask], sapo_temperature_neg
)
per_token_loss = -per_token_loss * advantages_expanded
else:
# Apply delta (two-sided clipping from INTELLECT-2) to coef_1
if delta is not None:
coef_1 = torch.clamp(coef_1, max=delta)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
# Apply vLLM importance sampling correction BEFORE adding KL penalty
if vllm_is_ratio is not None:
per_token_loss = per_token_loss * vllm_is_ratio
if beta != 0.0:
# Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
if use_bias_correction_kl:
# Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1
token_coef_1 = torch.exp(per_token_logps - old_per_token_logps)
kl_div = kl_div * token_coef_1
# Combine losses
per_token_loss = per_token_loss + beta * kl_div
# Note: We normalize by the number of tokens in the batch (using full_attention_mask),
# which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
# and TRL GRPO implementation
# (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
if loss_type == "grpo" or loss_type == "sapo":
# Average per-sequence loss (SAPO uses same normalization as GRPO)
loss = (
(per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
).sum() / full_attention_mask.shape[0]
elif loss_type == "bnpo":
# Batch Normalized Per-token loss (original implementation)
loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
elif loss_type == "dr_grpo":
# Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
if max_completion_length is None:
raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
elif loss_type == "dapo" or loss_type == "cispo":
loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
loss = (per_token_loss * attention_mask).sum() / loss_normalizer
elif loss_type == "luspo":
# LUSPO: loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
# Reformulated as: sum_i(sum_j(per_token_loss_ij) * seq_len_i) / numel
# to avoid (B,T) * (B,1) broadcast which amplifies torch.compile differences.
seq_lens = attention_mask.sum(-1) # (chunk_B,)
per_seq_sum = per_token_loss.sum(-1) # (chunk_B,)
weighted = per_seq_sum * seq_lens # (chunk_B,)
if importance_sampling_level == "sequence" and beta == 0.0:
# per_token_loss stays (B, 1), so .mean() divides by B
loss = weighted.sum() / full_attention_mask.shape[0]
else:
# per_token_loss is (B, T), .mean() divides by B*T
loss = weighted.sum() / (full_attention_mask.shape[0] * full_attention_mask.shape[1])
else:
raise ValueError(f"Unknown loss type: {loss_type}")
# Calculate metrics
metrics = []
if beta != 0.0:
metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
# Adjust clipping metric calculation based on importance sampling level
if importance_sampling_level == "token":
is_clipped = (is_lower_clipped & (advantages.unsqueeze(1) < 0)) | (
is_upper_clipped & (advantages.unsqueeze(1) > 0)
)
else: # sequence level
# For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
is_clipped = (is_lower_clipped & (advantages.unsqueeze(1) < 0)) | (
is_upper_clipped & (advantages.unsqueeze(1) > 0)
)
is_clipped = is_clipped.expand_as(attention_mask)
metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
return loss, metrics
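# metrics layout: mean KL penalty first (present only when beta != 0), then the
# clip ratio, i.e. the fraction of attended tokens whose importance weight was
# clipped in the gradient-relevant direction (lower bound for negative
# advantages, upper bound for positive ones), both normalized by the full
# batch's token count.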
@classmethod
def forward(
cls,
ctx,
_input,
weight,
selected_token_ids,
attention_mask,
advantages,
bias=None,
ref_per_token_logps=None,
old_per_token_logps=None,
ref_input=None,
ref_weight=None,
ref_bias=None,
beta=0.04,
epsilon_low=0.2,
epsilon_high=0.2,
loss_type="dapo",
max_completion_length=None,
importance_sampling_level="token",
sapo_temperature_pos=1.0,
sapo_temperature_neg=1.05,
temperature=1.0,
compiled=True,
use_ref_model=True,
chunk_size=1,
vllm_is_ratio=None,
delta=None,
use_bias_correction_kl=False,
):
"""
Fused linear layer with GRPO loss.
Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
ref_per_token_logps (torch.Tensor, optional): Reference model log probs per token. Shape: (batch_size, seq_len)
old_per_token_logps (torch.Tensor, optional): Log probs per token from the old (behavior) policy. Shape: (batch_size, seq_len)
ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
beta (float): Weight for the KL penalty
epsilon_low (float): Lower clipping bound for the importance sampling ratio
epsilon_high (float): Upper clipping bound for the importance sampling ratio
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo").
Defaults to "dapo".
max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
sapo_temperature_pos (float): Temperature for positive advantages in SAPO. Defaults to 1.0.
sapo_temperature_neg (float): Temperature for negative advantages in SAPO. Defaults to 1.05.
temperature (float): Temperature for the logits
compiled (bool): Whether to use torch compile
use_ref_model (bool): Whether to use a reference model
chunk_size (int): Size of chunks for processing.
vllm_is_ratio (torch.Tensor, optional): vLLM importance sampling ratio (batch_size, seq_len) or (batch_size, 1) or None.
Used to correct for distribution mismatch when using vLLM for generation.
delta (float, optional): Upper clamp for two-sided clipping (INTELLECT-2). None means disabled.
use_bias_correction_kl (bool): If True, multiply the KL term by the importance sampling ratio (DeepSeek-V3.2).
Returns:
torch.Tensor: Computed loss. Auxiliary metrics are returned alongside the loss (see `backward` below).
"""
# Validate before entering torch.compile boundary
if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"):
raise ValueError(
f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. "
f"Use importance_sampling_level='token' instead."
)
return super().forward(
cls=cls,
ctx=ctx,
_input=_input,
weight=weight,
selected_token_ids=selected_token_ids,
attention_mask=attention_mask,
advantages=advantages,
bias=bias,
ref_per_token_logps=ref_per_token_logps,
old_per_token_logps=old_per_token_logps,
ref_input=ref_input,
ref_weight=ref_weight,
ref_bias=ref_bias,
beta=beta,
epsilon_low=epsilon_low,
epsilon_high=epsilon_high,
loss_type=loss_type,
max_completion_length=max_completion_length,
temperature=temperature,
compiled=compiled,
use_ref_model=use_ref_model,
chunk_size=chunk_size,
importance_sampling_level=importance_sampling_level,
sapo_temperature_pos=sapo_temperature_pos,
sapo_temperature_neg=sapo_temperature_neg,
vllm_is_ratio=vllm_is_ratio,
delta=delta,
use_bias_correction_kl=use_bias_correction_kl,
)
@staticmethod
def backward(ctx, grad_output, *grad_metrics):
"""Backward pass for GRPO loss.
Args:
grad_output: Gradient of the loss (scalar)
grad_metrics: Gradients of the metrics (not used in backward computation)
"""
grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
return (
*grads[
:6
], # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
None, # grad_ref_per_token_logps
None, # grad_old_per_token_logps
None, # grad_ref_input
None, # grad_ref_weight
None, # grad_ref_bias
None, # grad_beta
None, # grad_epsilon_low
None, # grad_epsilon_high
None, # grad_loss_type (string, not differentiable)
None, # grad_max_completion_length (int, not differentiable)
None, # grad_importance_sampling_level (string, not differentiable)
None, # grad_sapo_temperature_pos (float, not differentiable)
None, # grad_sapo_temperature_neg (float, not differentiable)
None, # grad_temperature
None, # grad_compiled
None, # grad_use_ref_model
None, # grad_chunk_size
None, # grad_vllm_is_ratio
None, # grad_delta
None, # grad_use_bias_correction_kl
)
class LigerFusedLinearGRPOLoss(torch.nn.Module):
"""Fused linear layer with GRPO loss."""
def __init__(
self,
beta: float = 0.04,
compiled: bool = True,
use_ref_model: bool = True,
chunk_size: int = 1,
epsilon_low: float = 0.2,
epsilon_high: float = 0.2,
loss_type: str = "dapo",
max_completion_length: Optional[int] = None,
importance_sampling_level: str = "token",
sapo_temperature_pos: float = 1.0,
sapo_temperature_neg: float = 1.05,
temperature: float = 1.0,
delta: Optional[float] = None,
use_bias_correction_kl: bool = False,
):
"""
Args:
beta (float): Weight for the KL penalty.
compiled (bool): Whether to use torch compile.
use_ref_model (bool): Whether to use a reference model.
chunk_size (int): Size of chunks for processing.
epsilon_low (float): Lower bound for the importance sampling ratio.
epsilon_high (float): Upper bound for the importance sampling ratio.
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo").
Defaults to "dapo". For "cispo", epsilon_high is typically larger (e.g. 5.0) and
epsilon_low is unused. For "sapo", uses soft gating instead of hard clipping.
max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
sapo_temperature_pos (float): Temperature for positive advantages in SAPO. Defaults to 1.0.
sapo_temperature_neg (float): Temperature for negative advantages in SAPO. Defaults to 1.05.
temperature (float): Temperature for the logits.
delta (float, optional): Upper clamp for two-sided clipping (INTELLECT-2). None means disabled.
use_bias_correction_kl (bool): If True, multiply KL by importance sampling ratio (DeepSeek-V3.2).
"""
super().__init__()
# Validate SAPO temperatures to prevent division by zero or numerical instability
if sapo_temperature_pos <= 0:
raise ValueError(f"sapo_temperature_pos must be positive, got {sapo_temperature_pos}")
if sapo_temperature_neg <= 0:
raise ValueError(f"sapo_temperature_neg must be positive, got {sapo_temperature_neg}")
if delta is not None and delta <= 0:
raise ValueError(f"delta must be positive, got {delta}")
self.beta = beta
self.compiled = compiled
self.use_ref_model = use_ref_model
self.chunk_size = chunk_size
self.epsilon_low = epsilon_low
self.epsilon_high = epsilon_high
self.loss_type = loss_type
self.max_completion_length = max_completion_length
self.importance_sampling_level = importance_sampling_level
self.sapo_temperature_pos = sapo_temperature_pos
self.sapo_temperature_neg = sapo_temperature_neg
self.temperature = temperature
self.delta = delta
self.use_bias_correction_kl = use_bias_correction_kl
def forward(
self,
_input,
lin_weight,
selected_token_ids,
attention_mask,
advantages,
bias=None,
ref_per_token_logps=None,
old_per_token_logps=None,
ref_input=None,
ref_weight=None,
ref_bias=None,
vllm_is_ratio=None,
):
return LigerFusedLinearGRPOFunction.apply(
_input,
lin_weight,
selected_token_ids,
attention_mask,
advantages,
bias,
ref_per_token_logps,
old_per_token_logps,
ref_input,
ref_weight,
ref_bias,
self.beta,
self.epsilon_low,
self.epsilon_high,
self.loss_type,
self.max_completion_length,
self.importance_sampling_level,
self.sapo_temperature_pos,
self.sapo_temperature_neg,
self.temperature,
self.compiled,
self.use_ref_model,
self.chunk_size,
vllm_is_ratio,
self.delta,
self.use_bias_correction_kl,
)
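# Usage sketch (hedged: `lm_head_weight` and the size names B, T, H are
# placeholders; shapes follow the docstrings above):
#
#   loss_fn = LigerFusedLinearGRPOLoss(beta=0.04, loss_type="dapo")
#   hidden = torch.randn(B * T, H, requires_grad=True)  # (batch*seq, hidden)
#   out = loss_fn(hidden, lm_head_weight, selected_token_ids,
#                 attention_mask, advantages)
#   # out carries the scalar loss plus auxiliary metrics (see backward above)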
import math
from typing import Tuple
from typing import Union
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
@staticmethod
def distillation_loss_fn(student_logits, teacher_logits, beta=0.5, target=None, ignore_index=-100):
"""
Compute JSD loss (Jensen-Shannon Divergence Loss).
Args:
student_logits (torch.Tensor): Logits of student tokens. Shape: (chunk_size, vocab_size).
teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (chunk_size, vocab_size).
beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
target (torch.Tensor): Target labels for masking. Shape: (chunk_size,).
ignore_index (int): Index to ignore in loss computation.
Returns:
torch.Tensor: Jensen-Shannon Divergence loss
Note:
- Uses reduction="none" to preserve per-token losses for masking
- KL divergence requires summing over vocab dimension (not mean)
- Masking excludes padding/prompt tokens from loss computation
"""
student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
if beta == 0:
jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
elif beta == 1:
jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
else:
# Log of the beta-weighted mixture distribution (the generalized JSD midpoint)
log_mean_probs = torch.logsumexp(
torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="none", log_target=True)
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="none", log_target=True)
# JSD is the weighted average of the KL divergences
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
# Sum over vocab dimension (KL divergence definition)
jsd_loss = jsd_loss.sum(dim=-1) # (chunk_size,)
# Apply ignore_index mask
if target is not None:
mask = target != ignore_index
jsd_loss = jsd_loss.masked_fill(~mask, 0.0)
return jsd_loss.sum()
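# Interpretation of beta (as implemented above): beta == 0 gives the forward
# KL(teacher || student), beta == 1 the reverse KL(student || teacher), and
# intermediate beta the generalized JSD against the beta-weighted mixture
# M = (1 - beta) * P_student + beta * P_teacher.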
@classmethod
def forward(
cls,
ctx,
student_input: torch.Tensor,
student_weight: torch.Tensor,
teacher_input: torch.Tensor,
teacher_weight: torch.Tensor,
true_labels: torch.LongTensor,
student_bias: torch.Tensor,
teacher_bias: torch.Tensor,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
compiled: bool = True,
chunk_size: int = 1024,
return_soft_hard_loss: bool = False,
):
"""
Fused linear layer with JSD distillation loss.
Args:
student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,)
teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,)
weight_hard_loss (float): Weight for hard loss.
weight_soft_loss (float): Weight for soft loss.
beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
ignore_index (int): Index to ignore in loss computation
temperature (float): Temperature for softening/sharpening distributions
compiled (bool): Whether to use torch compile
chunk_size (int): Size of chunks for processing.
return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
Returns:
torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
"""
return super().forward(
cls=cls,
ctx=ctx,
student_input=student_input,
student_weight=student_weight,
teacher_input=teacher_input,
teacher_weight=teacher_weight,
target=true_labels,
student_bias=student_bias,
teacher_bias=teacher_bias,
chunk_size=chunk_size,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
beta=beta,
ignore_index=ignore_index,
temperature=temperature,
compiled=compiled,
return_soft_hard_loss=return_soft_hard_loss,
)
@staticmethod
def backward(ctx, grad_output, *args):
grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
return (
*grads,
None, # teacher_bias
None, # weight_hard_loss
None, # weight_soft_loss
None, # beta
None, # ignore_index
None, # temperature
None, # compiled
None, # chunk_size
None, # return_soft_hard_loss
)
class LigerFusedLinearJSDLoss(torch.nn.Module):
"""
Fused linear layer with JSD distillation loss.
"""
def __init__(
self,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
compiled: bool = True,
chunk_size: int = 1024,
return_soft_hard_loss: bool = False,
):
"""
Args:
weight_hard_loss (float): Weight for hard loss.
weight_soft_loss (float): Weight for soft loss.
ignore_index (int): Index to ignore in the loss
temperature (float): Temperature for softening distributions
compiled (bool): Whether to use torch compile
beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
chunk_size (int): Size of chunks for processing.
return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
"""
super().__init__()
assert temperature != 0, "Temperature cannot be 0."
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.ignore_index = ignore_index
self.temperature = temperature
self.compiled = compiled
self.beta = beta
self.chunk_size = chunk_size
self.return_soft_hard_loss = return_soft_hard_loss
def forward(
self,
student_input: torch.Tensor,
student_weight: torch.Tensor,
teacher_input: torch.Tensor,
teacher_weight: torch.Tensor,
true_labels: torch.LongTensor,
student_bias: torch.Tensor = None,
teacher_bias: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Compute the JSD distillation loss.
Args:
student_input (torch.Tensor): Student input tensor
student_weight (torch.Tensor): Student weight tensor
teacher_input (torch.Tensor): Teacher input tensor
teacher_weight (torch.Tensor): Teacher weight tensor
true_labels (torch.LongTensor): Target labels tensor
Returns:
torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
If return_soft_hard_loss is False: Computed combined loss
If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
"""
return LigerFusedLinearJSDFunction.apply(
student_input,
student_weight,
teacher_input,
teacher_weight,
true_labels,
student_bias,
teacher_bias,
self.weight_hard_loss,
self.weight_soft_loss,
self.beta,
self.ignore_index,
self.temperature,
self.compiled,
self.chunk_size,
self.return_soft_hard_loss,
)
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase
class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
@staticmethod
def preference_loss_fn(
log_prob_chunk,
preference_labels_chunk,
full_target,
ref_log_prob_chunk=None,
beta=0.1,
kl=None,
):
"""
Implements the Kahneman-Tversky Optimization (KTO) loss function.
Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
https://arxiv.org/abs/2402.01306
KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
from behavioral economics, which models how humans make decisions under uncertainty.
The loss function is asymmetric, treating gains and losses differently, similar to
human decision-making patterns.
Formula:
When y is chosen:
L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
When y is rejected:
L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
Where:
- σ: Sigmoid function
- β: Temperature parameter controlling the strength of the preference signal
- π(x): Policy (current model)
- π₀(x): Reference policy (reference model)
- KL(π||π₀)_y: KL divergence estimated using the rejected response y
The loss encourages the model to:
1. Assign higher probability to chosen responses
2. Assign lower probability to rejected responses
3. Maintain reasonable distance from the reference model
Args:
log_prob_chunk: Log probabilities for the chunk (batch_size,)
preference_labels_chunk: Preference labels for the chunk (batch_size,)
full_target: Non-chunked full target tensor
ref_log_prob_chunk: Reference log probs for the chunk (batch_size,)
beta: Weight for the KTO loss
kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
Returns:
- loss: The KTO loss value
"""
if ref_log_prob_chunk is not None:
logratios_chunk = log_prob_chunk - ref_log_prob_chunk
else:
logratios_chunk = log_prob_chunk
multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
if kl is not None:
losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
else:
losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
rewards = beta * logratios_chunk
chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()
return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
@classmethod
def forward(
cls,
ctx,
_input,
weight,
target,
preference_labels,
bias=None,
ref_input=None,
ref_weight=None,
ref_bias=None,
kl=None,
ignore_index=-100,
beta=0.1,
compiled=True,
use_ref_model=True,
average_log_prob=False,
chunk_size=1,
):
"""
Fused linear layer with KTO loss.
Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
ignore_index (int): Index to ignore in loss computation
beta (float): Temperature parameter for the KTO loss
compiled (bool): Whether to use torch compile
use_ref_model (bool): Whether to use a reference model
average_log_prob (bool): Whether to average the log probability per non-masked token
chunk_size (int): Size of chunks for processing
Returns:
torch.Tensor: Computed loss
"""
return super().forward(
cls=cls,
ctx=ctx,
_input=_input,
weight=weight,
target=target,
preference_labels=preference_labels,
bias=bias,
ignore_index=ignore_index,
beta=beta,
compiled=compiled,
use_ref_model=use_ref_model,
ref_input=ref_input,
ref_weight=ref_weight,
ref_bias=ref_bias,
average_log_prob=average_log_prob,
kl=kl,
chunk_size=chunk_size,
)
@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
return (
*grads,
None,  # grad_ref_input
None,  # grad_ref_weight
None,  # grad_ref_bias
None,  # grad_kl
None,  # grad_ignore_index
None,  # grad_beta
None,  # grad_compiled
None,  # grad_use_ref_model
None,  # grad_average_log_prob
None,  # grad_chunk_size
)
class LigerFusedLinearKTOLoss(torch.nn.Module):
"""
Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
"""
def __init__(
self,
ignore_index: int = -100,
beta: float = 0.1,
compiled: bool = True,
use_ref_model: bool = False,
average_log_prob: bool = False,
chunk_size: int = 1,
):
"""
Args:
ignore_index (int): Index to ignore in the loss calculation
beta (float): Temperature parameter for the KTO loss
compiled (bool): Whether to use compiled operations
use_ref_model (bool): Whether to use a reference model for the KTO loss.
average_log_prob (bool): Whether to average the log probability per non-masked token
chunk_size (int): Size of chunks for processing
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compiled = compiled
self.use_ref_model = use_ref_model
self.average_log_prob = average_log_prob
self.chunk_size = chunk_size
def forward(
self,
_input,
lin_weight,
target,
bias=None,
preference_labels=None,
ref_input=None,
ref_weight=None,
ref_bias=None,
kl=None,
):
return LigerFusedLinearKTOFunction.apply(
_input,
lin_weight,
target,
preference_labels,
bias,
ref_input,
ref_weight,
ref_bias,
kl,
self.ignore_index,
self.beta,
self.compiled,
self.use_ref_model,
self.average_log_prob,
self.chunk_size,
)
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
@staticmethod
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
"""
Paper: https://arxiv.org/pdf/2403.07691
Formula:
Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
Where:
- P_θ(y|x): Policy (model) probability
- y_w: Chosen sequence
- y_l: Rejected sequence
- σ: Sigmoid function
- β: Weight for the odds ratio loss
- odds_θ: Odds function for the policy
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
full_target (torch.Tensor): Non-chunked full target tensor
beta (float): Weight for the odds ratio loss.
"""
log_odds = (chosen_logps - rejected_logps) - (
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
ratio = F.logsigmoid(log_odds)
loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
chosen_rewards = beta * chosen_logps
rejected_rewards = beta * rejected_logps
log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2)
log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2)
return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
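# Note on the odds computation above: odds(p) = p / (1 - p), so
# log(odds) = logps - log1p(-exp(logps)); this is numerically valid because
# chosen/rejected logps are average log probabilities (<= 0).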
@classmethod
def forward(
cls,
ctx,
_input,
weight,
target,
bias=None,
ignore_index=-100,
beta=0.1,
compute_nll_loss=True,
nll_target=None,
compiled=True,
chunk_size=1,
):
"""
Fused linear layer with ORPO loss.
Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
ignore_index (int): Index to ignore in loss computation
beta (float): Weight for the odds ratio loss
compute_nll_loss (bool): Whether to compute the NLL loss
nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
compiled (bool): Whether to use torch compile
chunk_size (int): Size of chunks for processing
Returns:
torch.Tensor: Computed loss
"""
return super().forward(
cls=cls,
ctx=ctx,
_input=_input,
weight=weight,
target=target,
bias=bias,
ignore_index=ignore_index,
beta=beta,
compute_nll_loss=compute_nll_loss,
nll_target=nll_target,
compiled=compiled,
chunk_size=chunk_size,
)
@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None
class LigerFusedLinearORPOLoss(torch.nn.Module):
"""
Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
"""
def __init__(
self,
ignore_index: int = -100,
beta: float = 0.1,
compute_nll_loss: bool = True,
compiled: bool = True,
chunk_size: int = 1,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
compute_nll_loss (bool): Whether to compute the NLL loss.
compiled (bool): Whether to use the torch compiled kernel.
chunk_size (int): Size of chunks for processing.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.chunk_size = chunk_size
def forward(
self,
lin_weight,
_input,
target,
bias=None,
nll_target=None,
):
return LigerFusedLinearORPOFunction.apply(
_input,
lin_weight,
target,
bias,
self.ignore_index,
self.beta,
self.compute_nll_loss,
nll_target,
self.compiled,
self.chunk_size,
)
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
@staticmethod
def preference_loss_fn(
chosen_logps,
rejected_logps,
full_target,
beta=0.1,
gamma=0.5,
label_smoothing=0.0,
):
"""
Paper: https://arxiv.org/pdf/2405.14734
Formula:
L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
Where:
- π_θ(y|x): Policy (model) probability
- y_w: Chosen sequence
- y_l: Rejected sequence
- |y_w|, |y_l|: Sequence lengths
- σ: Sigmoid function
- β: Reward scaling factor (beta)
- γ: Target reward margin (gamma)
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
full_target: Non-chunked full target tensor
beta (float): Reward scaling factor β.
gamma (float): Target reward margin γ.
label_smoothing (float): Label smoothing factor; the loss reduces to the equation above as label_smoothing -> 0.
"""
logits = beta * (chosen_logps - rejected_logps) - gamma
loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
full_target.shape[0] // 2
)
chosen_rewards = beta * chosen_logps
rejected_rewards = beta * rejected_logps
return loss, chosen_rewards, rejected_rewards
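# Note: with label_smoothing == 0 the loss reduces to -logsigmoid(logits), the
# standard SimPO objective; label_smoothing > 0 mixes in the flipped-label term
# -logsigmoid(-logits) as a conservative smoothing of the preference signal.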
@classmethod
def forward(
cls,
ctx,
_input,
weight,
target,
bias=None,
ignore_index=-100,
beta=0.1,
alpha=1.0,
label_smoothing=0.0,
compute_nll_loss=False,
compiled=True,
gamma=0.5,
chunk_size=1,
):
"""
Fused linear layer with SimPO loss.
Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
ignore_index (int): Index to ignore in loss computation
beta (float): Reward scaling factor for the SimPO loss
alpha (float): Weight for the auxiliary NLL loss
label_smoothing (float): Label smoothing factor
compute_nll_loss (bool): Whether to compute the NLL loss
compiled (bool): Whether to use torch compile
gamma (float): Target reward margin
chunk_size (int): Size of chunks for processing
Returns:
torch.Tensor: Computed loss
"""
return super().forward(
cls=cls,
ctx=ctx,
_input=_input,
weight=weight,
target=target,
bias=bias,
ignore_index=ignore_index,
alpha=alpha,
beta=beta,
label_smoothing=label_smoothing,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
gamma=gamma,
chunk_size=chunk_size,
)
@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None, None, None
class LigerFusedLinearSimPOLoss(torch.nn.Module):
"""
Fused linear layer with SimPO loss.
"""
def __init__(
self,
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
label_smoothing: float = 0.0,
compute_nll_loss: bool = True,
compiled: bool = True,
gamma: float = 0.5,
chunk_size: int = 1,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Reward scaling factor for the SimPO loss.
alpha (float): Weight for the auxiliary NLL loss.
label_smoothing (float): Label smoothing factor.
compute_nll_loss (bool): Whether to compute the NLL loss.
compiled (bool): Whether to use the torch compiled kernel.
gamma (float): Target reward margin.
chunk_size (int): Size of chunks for processing.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.label_smoothing = label_smoothing
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.gamma = gamma
self.chunk_size = chunk_size
def forward(
self,
lin_weight,
_input,
target,
bias=None,
):
return LigerFusedLinearSimPOFunction.apply(
_input,
lin_weight,
target,
bias,
self.ignore_index,
self.beta,
self.alpha,
self.label_smoothing,
self.compute_nll_loss,
self.compiled,
self.gamma,
self.chunk_size,
)
import platform
import sys
from importlib.metadata import version
def print_env_report():
"""
Prints a report of the environment. Useful for debugging and reproducibility.
Usage:
```
python -m liger_kernel.env_report
```
"""
print("Environment Report:")
print("-------------------")
print(f"Operating System: {platform.platform()}")
print(f"Python version: {sys.version.split()[0]}")
try:
print(f"Liger Kernel version: {version('liger-kernel')}")
except ImportError:
print("Liger Kernel: Not installed")
try:
import torch
print(f"PyTorch version: {torch.__version__}")
cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
print(f"CUDA version: {cuda_version}")
hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
print(f"HIP(ROCm) version: {hip_version}")
except ImportError:
print("PyTorch: Not installed")
print("CUDA version: Unable to query")
print("HIP(ROCm) version: Unable to query")
try:
import triton
print(f"Triton version: {triton.__version__}")
except ImportError:
print("Triton: Not installed")
try:
import transformers
print(f"Transformers version: {transformers.__version__}")
except ImportError:
print("Transformers: Not installed")
try:
xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
print(f"XPU version: {xpu_version}")
except (ImportError, AttributeError, NameError):
# torch may be missing entirely (NameError) or built without XPU support (AttributeError)
print("XPU version: Unable to query")
if __name__ == "__main__":
print_env_report()
"""
Liger-Kernel operators with automatic vendor-specific replacement.
This module provides two ways to import operators:
1. Import from this package (recommended for Function classes):
from liger_kernel.ops import LigerGELUMulFunction
This automatically uses vendor-specific implementation if available.
2. Import from submodules (for kernel functions or specific access):
from liger_kernel.ops.geglu import geglu_forward, geglu_backward
This always uses the default implementation (no auto-replacement).
The replacement mechanism:
1. Default implementations are imported from individual modules (e.g., geglu.py)
2. On module load, device is detected via infer_device()
3. If running on a supported vendor device (npu, xpu, etc.), the default
implementations are replaced with vendor-specific ones
4. All subsequent imports from this package get the replaced versions
Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...)
are NOT affected by the replacement mechanism.
"""
# =============================================================================
# Import default implementations
# Both Function classes and kernel functions are imported here.
# All of these can be replaced by vendor-specific implementations.
# =============================================================================
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction # noqa: F401
from liger_kernel.ops.cross_entropy import cross_entropy_backward # noqa: F401
from liger_kernel.ops.cross_entropy import cross_entropy_forward # noqa: F401
from liger_kernel.ops.dyt import LigerDyTFunction # noqa: F401
from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction # noqa: F401
from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction # noqa: F401
from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_backward # noqa: F401
from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_forward # noqa: F401
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction # noqa: F401
from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward # noqa: F401
from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward # noqa: F401
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction # noqa: F401
from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward # noqa: F401
from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward # noqa: F401
from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction # noqa: F401
from liger_kernel.ops.geglu import LigerGELUMulFunction # noqa: F401
from liger_kernel.ops.geglu import geglu_backward # noqa: F401
from liger_kernel.ops.geglu import geglu_forward # noqa: F401
from liger_kernel.ops.group_norm import LigerGroupNormFunction # noqa: F401
from liger_kernel.ops.group_norm import group_norm_backward # noqa: F401
from liger_kernel.ops.group_norm import group_norm_forward # noqa: F401
from liger_kernel.ops.grpo_loss import GrpoLossFunction # noqa: F401
from liger_kernel.ops.jsd import LigerJSDFunction # noqa: F401
from liger_kernel.ops.jsd import jsd_backward # noqa: F401
from liger_kernel.ops.jsd import jsd_forward # noqa: F401
from liger_kernel.ops.kl_div import LigerKLDivLossFunction # noqa: F401
from liger_kernel.ops.layer_norm import LigerLayerNormFunction # noqa: F401
from liger_kernel.ops.layer_norm import layer_norm_backward # noqa: F401
from liger_kernel.ops.layer_norm import layer_norm_forward # noqa: F401
from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction # noqa: F401
from liger_kernel.ops.mhc import LigerMHCCoeffsFunction # noqa: F401
from liger_kernel.ops.mhc import LigerMHCPostResFunction # noqa: F401
from liger_kernel.ops.mhc import LigerMHCPreFunction # noqa: F401
from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction # noqa: F401
from liger_kernel.ops.poly_norm import LigerPolyNormFunction # noqa: F401
from liger_kernel.ops.poly_norm import poly_norm_backward # noqa: F401
from liger_kernel.ops.poly_norm import poly_norm_forward # noqa: F401
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction # noqa: F401
from liger_kernel.ops.rms_norm import LigerRMSNormFunction # noqa: F401
from liger_kernel.ops.rms_norm import rms_norm_backward # noqa: F401
from liger_kernel.ops.rms_norm import rms_norm_forward # noqa: F401
from liger_kernel.ops.rope import LigerRopeFunction # noqa: F401
from liger_kernel.ops.rope import rope_backward # noqa: F401
from liger_kernel.ops.rope import rope_forward # noqa: F401
from liger_kernel.ops.softmax import LigerSoftmaxFunction # noqa: F401
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction # noqa: F401
from liger_kernel.ops.swiglu import LigerSiLUMulFunction # noqa: F401
from liger_kernel.ops.swiglu import swiglu_backward # noqa: F401
from liger_kernel.ops.swiglu import swiglu_forward # noqa: F401
from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction # noqa: F401
from liger_kernel.ops.tiled_mlp import apply_tiled_mlp # noqa: F401
from liger_kernel.ops.tvd import LigerTVDLossFunction # noqa: F401
# NOTE: __all__ is intentionally NOT defined.
# - Import from this package (liger_kernel.ops) -> subject to vendor replacement
# - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation
# =============================================================================
# Vendor-specific replacement logic
# =============================================================================
def _replace_with_vendor_ops():
"""
Replace/add vendor-specific operator implementations.
This function is called automatically on module load. It:
1. Detects the current device (cuda, npu, xpu, etc.)
2. Looks up the vendor for that device via VENDOR_REGISTRY
3. Loads and applies vendor-specific implementations
Vendor implementations should be placed in:
liger_kernel/ops/backends/_<vendor>/ops/
If the vendor module defines __all__, only those symbols are exported.
Otherwise, all public symbols (not starting with _) are auto-discovered.
Note: Vendor can both override existing ops AND add new vendor-specific ops.
"""
from liger_kernel.ops.backends import get_vendor_for_device
from liger_kernel.utils import infer_device
device = infer_device()
# Look up vendor info for this device
vendor_info = get_vendor_for_device(device)
if vendor_info is None:
return
try:
import importlib
vendor_ops = importlib.import_module(vendor_info.module_path)
# Get names to export: use __all__ if defined, otherwise auto-discover
names_to_export = getattr(vendor_ops, "__all__", None)
if names_to_export is None:
# Auto-discover: find all public symbols (classes and functions)
names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")]
# Replace or add to this module's globals
for name in names_to_export:
globals()[name] = getattr(vendor_ops, name)
except ImportError:
# Vendor module not available, use default implementations
pass
_replace_with_vendor_ops()
# Adding a New Vendor Backend
This directory contains vendor-specific operator implementations that automatically replace the default (CUDA) implementations when running on the corresponding device.
## Concepts
- **Vendor**: Chip manufacturer (e.g., `ascend`, `intel`, `nvidia`)
- **Device**: Device type (e.g., `npu`, `xpu`, `cuda`)
- **VendorInfo**: Defines the mapping between vendor and device
## Directory Structure
```
backends/
├── README.md
├── __init__.py
├── registry.py # VendorInfo, register_vendor(), VENDOR_REGISTRY
├── _ascend/ # Ascend (Huawei) vendor - supports NPU
│ ├── __init__.py # Registers VendorInfo for NPU
│ └── ops/
│ ├── __init__.py # Exports vendor-specific implementations
│ └── geglu.py # NPU-specific GEGLU implementation
└── _<vendor>/ # Your new vendor backend
└── ...
```
## How It Works
1. When `liger_kernel.ops.backends` is imported, it imports all vendor packages (e.g., `_ascend`)
2. Each vendor's `__init__.py` calls `register_vendor()` to register itself
3. When `liger_kernel.ops` is imported, `_replace_with_vendor_ops()` is called
4. It detects the current device via `infer_device()` and looks up the vendor
5. Vendor implementations replace/add to the `liger_kernel.ops` namespace
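The registry these steps rely on can be pictured roughly as follows (a minimal sketch: the `module_path` default and exact field set are inferred from how `register_vendor()` and `_replace_with_vendor_ops()` are used in this commit, and may differ from the real `registry.py`):

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class VendorInfo:
    vendor: str            # e.g. "ascend"
    device: str            # e.g. "npu"
    module_path: str = ""  # where vendor ops live; assumed default below

    def __post_init__(self):
        if not self.module_path:
            self.module_path = f"liger_kernel.ops.backends._{self.vendor}.ops"

# device name -> VendorInfo
VENDOR_REGISTRY: Dict[str, VendorInfo] = {}

def register_vendor(info: VendorInfo) -> None:
    VENDOR_REGISTRY[info.device] = info

def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    return VENDOR_REGISTRY.get(device)
```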
## Adding a New Vendor
### Step 1: Create Directory Structure
```bash
mkdir -p backends/_<vendor>/ops
touch backends/_<vendor>/__init__.py
touch backends/_<vendor>/ops/__init__.py
```
### Step 2: Register Your Vendor
In `backends/_<vendor>/__init__.py`, register your vendor:
```python
"""
<Vendor> backend for Liger-Kernel.
"""
from liger_kernel.ops.backends.registry import VendorInfo, register_vendor
register_vendor(
VendorInfo(
vendor="<vendor>",
device="<device>",
)
)
```
### Step 3: Ensure Device Detection Works
Make sure `infer_device()` in `liger_kernel/utils.py` can detect your device:
```python
def infer_device():
if torch.cuda.is_available():
return "cuda"
if is_npu_available():
return "npu"
# Add your device detection here
if is_<device>_available():
return "<device>"
return "cpu"
```
### Step 4: Implement Vendor-Specific Operators
Create operator files in `backends/_<vendor>/ops/`. For example, `geglu.py`:
```python
import torch
class LigerGELUMulFunction(torch.autograd.Function):
"""
Vendor-specific LigerGELUMulFunction implementation.
"""
@staticmethod
def forward(ctx, a, b):
# Your vendor-specific forward implementation
...
@staticmethod
def backward(ctx, dc):
# Your vendor-specific backward implementation
...
# Optional: vendor-specific kernel functions
def geglu_forward_vendor(a, b):
...
def geglu_backward_vendor(a, b, dc):
...
```
### Step 5: Export in `ops/__init__.py`
In `backends/_<vendor>/ops/__init__.py`, export your implementations:
```python
"""
<Vendor>-specific operator implementations.
"""
from .<module> import (
LigerGELUMulFunction,
geglu_forward_vendor as geglu_forward, # Rename to match default API
geglu_backward_vendor as geglu_backward,
)
# Explicitly declare what to export (recommended)
__all__ = [
"LigerGELUMulFunction",
"geglu_forward",
"geglu_backward",
]
```
## Key Points
### Incremental Override
You **don't need to implement all operators**. Only implement the ones that require vendor-specific adaptations. Unimplemented operators will automatically fall back to the default (CUDA) implementation.
### Vendor-Specific Additions
Vendors can also **add new operators** that don't exist in the default implementation. These will be exported to `liger_kernel.ops` namespace for users to import.
### Naming Convention
- Use the **same class/function names** as the default implementations for overrides
- This allows seamless replacement without changing user code
- Use `as` imports to rename if your internal naming differs
## Example: Ascend NPU Backend
See `_ascend/` directory for a complete example of the Ascend NPU backend implementation.
import importlib
import pkgutil
from liger_kernel.ops.backends.registry import VENDOR_REGISTRY # noqa: F401
from liger_kernel.ops.backends.registry import VendorInfo # noqa: F401
from liger_kernel.ops.backends.registry import get_vendor_for_device # noqa: F401
from liger_kernel.ops.backends.registry import register_vendor # noqa: F401
# Auto-import all _<vendor> subpackages to trigger registration
# Each vendor's __init__.py calls register_vendor() when imported
for _, modname, ispkg in pkgutil.iter_modules(__path__):
if ispkg and modname.startswith("_"):
importlib.import_module(f"{__name__}.{modname}")
from liger_kernel.ops.backends.registry import VendorInfo
from liger_kernel.ops.backends.registry import register_vendor
# Register Ascend vendor for NPU device
register_vendor(VendorInfo(vendor="ascend", device="npu"))
# Ascend NPU UB Manager Design Document
## Overview
The UB Manager (Unified Buffer Manager) is a core component in **Liger-Kernel** responsible for managing the Unified Buffer (UB) capacity on Ascend NPUs. By automatically detecting UB capacity and providing unified tiling strategy computation, it helps Triton kernels avoid UB overflow errors while maintaining high performance.
## Design Goals
1. **Automated UB Management**: Automatically detect device UB capacity without manual configuration
2. **Unified Strategy System**: Use a single unified strategy function for all kernels, abstracting memory calculations
3. **Flexible Parameters**: Support different memory multipliers and safety margins for different kernels
4. **Easy to Use**: Simple interface that directly computes tiling results
## Architecture Design
### Core Components
```
┌─────────────────────────────────────────────────────────┐
│ UB Manager System │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────────┐ │
│ │ UBManager │ │ Default Strategy │ │
│ │ (Singleton)│────────▶│ Function │ │
│ └──────────────┘ └──────────────────┘ │
│ │ │ │
│ │ │ │
│ ▼ ▼ │
│ ┌──────────────┐ ┌──────────────────┐ │
│ │ Capacity │ │ compute_default │ │
│ │ Detection │ │ _tiling_strategy│ │
│ └──────────────┘ └──────────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘
│ │
│ │
▼ ▼
┌──────────────┐ ┌──────────────────┐
│ GEGLU │ │ ROPE │
│ Kernel │ │ Kernel │
└──────────────┘ └──────────────────┘
```
### Class Diagram
```
┌──────────────────────────────────────┐
│ UBManager │
├──────────────────────────────────────┤
│ - _npu_model: str │
│ - _ub_capacity_bits: int │
├──────────────────────────────────────┤
│ + ub_capacity_bits: int │
│ + ub_capacity_bytes: int │
│ + npu_model: str │
│ - _detect_npu_model() │
│ - _detect_ub_capacity() │
│ (raises RuntimeError if fails) │
└──────────────────────────────────────┘
┌──────────────────────────────────────┐
│ compute_default_tiling_strategy │
├──────────────────────────────────────┤
│ + safety_margin: float │
│ + dtype_size: int │
│ + memory_multiplier: float │
│ + shapes: Tuple[Tuple[int, ...], ...]│
│ + tiling_dims: Tuple │
├──────────────────────────────────────┤
│ Returns: Tuple[Tuple[int, ...], ...] │
│ (same structure as shapes) │
└──────────────────────────────────────┘
┌──────────────────────────────────────┐
│ _normalize_tiling_dims │
├──────────────────────────────────────┤
│ Helper function to normalize │
│ tiling_dim (int or tuple) to set │
└──────────────────────────────────────┘
```
## Core Functionality
### 1. UB Capacity Detection
The UB Manager detects UB capacity in the following priority order:
1. **Environment Variable**: `ASCEND_UB_CAPACITY_BITS` (in bits)
- If set, this value is used directly
- Must be a positive integer representing UB capacity in bits
2. **get_soc_spec**: Query UB size from CANN's `get_soc_spec("UB_SIZE")`
- Returns UB size in bytes
- Automatically converted to bits (bytes * 8)
- Requires CANN environment to be sourced (e.g., `source /usr/local/Ascend/ascend-toolkit/set_env.sh`)
3. **Error Handling**: If neither method succeeds, raises `RuntimeError` with clear instructions
```python
# Detection flow:
# 1. Check ASCEND_UB_CAPACITY_BITS env var (bits)
# 2. Try get_soc_spec("UB_SIZE") (bytes) -> convert to bits
# 3. Raise RuntimeError if both fail
```
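A minimal sketch of this detection order (the `get_soc_spec` import path is an assumption; only the environment variable name and the bytes-to-bits conversion are taken from the description above):

```python
import os

def _detect_ub_capacity_bits() -> int:
    # 1. Environment variable takes precedence; the value is already in bits.
    env = os.environ.get("ASCEND_UB_CAPACITY_BITS")
    if env is not None:
        bits = int(env)
        if bits <= 0:
            raise RuntimeError("ASCEND_UB_CAPACITY_BITS must be a positive integer (bits)")
        return bits
    # 2. Fall back to CANN's get_soc_spec, which reports the UB size in bytes.
    try:
        from tbe.common.platform import get_soc_spec  # assumed import path
        return int(get_soc_spec("UB_SIZE")) * 8
    except Exception as exc:
        # 3. Neither source worked: fail loudly with actionable instructions.
        raise RuntimeError(
            "Unable to detect UB capacity. Set ASCEND_UB_CAPACITY_BITS or source "
            "the CANN environment (e.g. /usr/local/Ascend/ascend-toolkit/set_env.sh)."
        ) from exc
```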
### 2. Unified Strategy System
All kernels use a single unified strategy function `_default_strategy` that abstracts memory calculations:
```
Memory Formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
```
Where `unit_param` is automatically calculated as the product of all fixed (non-tiling) dimensions in each shape.
The strategy function:
- Takes UB capacity, safety margin, dtype size, memory multiplier, shapes, and tiling dimension specifications
- For each shape, identifies which dimensions can be tiled (from `tiling_dims`)
- Calculates `unit_param` as the product of fixed (non-tiling) dimensions
- Calculates the maximum safe block size that fits within UB capacity
- Returns a tuple of max_safe_block_size values (one for each shape)
The `compute_default_tiling_strategy` function:
- Calls `_default_strategy` to get max_safe_block_size for each shape
- For each tiling dimension, computes desired block size using `triton.next_power_of_2(original_dim)`
- Returns the final result with the same structure as the input shapes: tiling dimensions replaced with the computed block sizes, non-tiling dimensions padded to the next power of 2
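A worked example of the formula (the 192 KiB UB capacity is an assumption for illustration):

```python
ub_capacity_bits = 192 * 1024 * 8          # assumed capacity: 1,572,864 bits
safe_bits = ub_capacity_bits * 0.80        # safety_margin=0.80 -> ~1,258,291 bits
# GEGLU forward: shapes=((4096,),), tiling_dims=(0,), so unit_param == 1
max_block = safe_bits / (7.0 * 1 * 2 * 8)  # memory_multiplier=7.0, float16 -> ~11,234
# largest power of 2 <= 11,234 is 8192; the desired block is
# triton.next_power_of_2(4096) == 4096, so the returned tile shape is ((4096,),)
```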
### 3. Parameter Structure
The unified strategy uses the following parameters:
- **`safety_margin`**: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
- **`dtype_size`**: Size of data type in bytes (e.g., 2 for float16, 4 for float32)
- **`memory_multiplier`**: Memory multiplier for estimating peak memory usage
- For GEGLU: typically 10.0 for backward, 7.0 for forward
- For ROPE: typically 3.0
- **`shapes`**: Tuple of full shapes. Each shape is a tuple of dimension sizes.
- For ROPE: `((n_q_head, hd), (n_kv_head, hd))`
- For GEGLU: `((n_cols,),)`
- Can pass original shapes (will handle padding internally) or padded shapes
- **`tiling_dims`**: Tuple specifying which dimensions can be tiled for each shape.
- Each element can be:
- `int`: single dimension index (e.g., `0` for first dimension)
- `tuple of ints`: multiple dimensions that can be tiled together (non-empty)
- For ROPE: `(0, 0)` means first dimension of each shape can be tiled
- For GEGLU: `(0,)` means first dimension of the shape can be tiled
- Length must match `len(shapes)`
- Fixed dimensions (non-tiling) are automatically extracted from shapes and multiplied to get `unit_param`
- **Validation**: Raises `ValueError` (illustrated in the sketch after this list) if:
- Any `tiling_dim` is empty or invalid (e.g., empty tuple)
- Any dimension index is out of bounds (negative or >= shape length)
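A hedged illustration of these rules (only `shapes` and `tiling_dims` matter here; the other argument values are placeholders):
```python
shapes = ((32, 128),)

# OK: dim 0 is tiled, dim 1 (= 128) is fixed, so unit_param = 128
compute_default_tiling_strategy(dtype_size=4, memory_multiplier=3.0, shapes=shapes, tiling_dims=(0,))

# ValueError: empty tuple is not a valid tiling_dim
compute_default_tiling_strategy(dtype_size=4, memory_multiplier=3.0, shapes=shapes, tiling_dims=((),))

# ValueError: dimension index 2 is out of bounds for shape (32, 128)
compute_default_tiling_strategy(dtype_size=4, memory_multiplier=3.0, shapes=shapes, tiling_dims=(2,))
```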
### 4. Strategy Computation Flow
```
User calls compute_default_tiling_strategy()
Get UB manager instance
Validate shapes and tiling_dims (lengths must match)
Set defaults for dtype_size (4) and memory_multiplier (10.0)
Call _default_strategy() with:
- ub_capacity_bits
- safety_margin
- dtype_size
- memory_multiplier
- shapes
- tiling_dims
For each (shape, tiling_dim) pair:
Normalize tiling_dim to set of dimension indices
Validate tiling dimensions are within shape bounds
(Raises ValueError if invalid)
Calculate unit_param:
unit_param = product of all non-tiling dimensions
Calculate max_block_size:
SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin
max_block_size = SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)
Find largest power of 2 <= max_block_size
Return tuple of max_safe_block_size (one per shape)
Build result with same structure as shapes:
For each (shape, tiling_dim, max_safe):
For each tiling dimension:
desired = triton.next_power_of_2(original_dim)
final = min(desired, max_safe)
final = max(1, final)
For each non-tiling dimension:
pad to triton.next_power_of_2(original_dim)
Return tuple of tiled shapes
```
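As a worked example, assume the 192KB UB of an Atlas 800I A2 (the empirical reference cited elsewhere in this document) and the GEGLU forward parameters used in the Basic Usage section below:
```python
ub_capacity_bits = 192 * 1024 * 8          # 1,572,864 bits (assumed 192KB UB)
safe_bits = ub_capacity_bits * 0.80        # 1,258,291.2 bits
# GEGLU forward: memory_multiplier=7.0, unit_param=1, dtype_size=2 (float16)
max_block = safe_bits / (7.0 * 1 * 2 * 8)  # ~= 11,234.7
# Largest power of 2 <= 11,234 -> max_safe_block_size = 8192
# desired = triton.next_power_of_2(4096) = 4096
# final = min(4096, 8192) = 4096 -> result: ((4096,),)
```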
## Usage Examples
### Basic Usage
```python
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
# GEGLU forward
shapes = ((4096,),)
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.80,
dtype_size=2, # float16
memory_multiplier=7.0,
shapes=shapes,
tiling_dims=(0,) # First dimension can be tiled
)
if tile_shapes is not None and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
# Call kernel with block_size
# ROPE forward
shapes = ((32, 128), (32, 128)) # (n_q_head, hd), (n_kv_head, hd)
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.90,
dtype_size=4, # float32
memory_multiplier=3.0,
shapes=shapes,
tiling_dims=(0, 0) # First dimension of each shape can be tiled
)
if tile_shapes is not None and len(tile_shapes) == len(shapes):
q_tile_shape, k_tile_shape = tile_shapes
BLOCK_Q, _ = q_tile_shape # Tiled dimension
BLOCK_K, _ = k_tile_shape # Tiled dimension
# Call kernel with BLOCK_Q and BLOCK_K
```
## Strategy Function Details
### `_normalize_tiling_dims` Helper Function
A helper function that normalizes tiling dimension specifications:
```python
def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
"""
Normalize tiling dimension specification to a set of dimension indices.
Args:
tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions).
Returns:
Set of dimension indices that can be tiled.
"""
```
This function handles the conversion of `tiling_dim` from either an `int` or `tuple` to a `set` for consistent processing.
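A minimal sketch of this helper (emptiness and bounds checks are performed by the callers, as described below):
```python
from typing import Tuple, Union

def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
    # A single int becomes a one-element set; a tuple becomes a set of indices.
    if isinstance(tiling_dim, int):
        return {tiling_dim}
    return set(tiling_dim)
```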
### `_default_strategy` Function
The core strategy function that calculates maximum safe block size:
```python
def _default_strategy(
ub_capacity_bits: int,
safety_margin: float,
dtype_size: int,
memory_multiplier: float,
shapes: Tuple[Tuple[int, ...], ...],
tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
"""
Calculate maximum safe block size based on UB capacity.
Memory formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
For each shape, fixed dimensions (non-tiling) are multiplied together to get unit_param.
Returns:
Tuple of max_safe_block_size (power of 2), one for each shape.
Raises:
ValueError: If any tiling_dim is empty or invalid, or if any dimension
index is out of bounds for the corresponding shape.
"""
```
**Key Steps:**
1. For each `(shape, tiling_dim)` pair:
- Normalize `tiling_dim` to a set of dimension indices using `_normalize_tiling_dims`
- Validate tiling dimensions are within shape bounds
- Raises `ValueError` if `tiling_dim` is empty or invalid
- Raises `ValueError` if any dimension index is out of bounds
- Calculate `unit_param` as the product of all non-tiling dimensions
- If all dimensions are tiling, `unit_param = 1.0`
2. Calculate `SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin`
3. Solve for max_block_size: `SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)`
4. Find largest power of 2 <= max_block_size
5. Return tuple with one max_safe_block_size per shape
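A minimal sketch of these steps (hedged: error messages and edge-case handling are abbreviated relative to the real implementation):
```python
import math

def _default_strategy_sketch(ub_capacity_bits, safety_margin, dtype_size, memory_multiplier, shapes, tiling_dims):
    safe_bits = ub_capacity_bits * safety_margin
    results = []
    for shape, tiling_dim in zip(shapes, tiling_dims):
        dims = _normalize_tiling_dims(tiling_dim)
        if not dims or any(d < 0 or d >= len(shape) for d in dims):
            raise ValueError(f"Invalid tiling_dim {tiling_dim!r} for shape {shape}")
        unit_param = 1.0  # product of fixed (non-tiling) dimensions
        for i, size in enumerate(shape):
            if i not in dims:
                unit_param *= size
        max_block = safe_bits / (memory_multiplier * unit_param * dtype_size * 8)
        # Largest power of 2 <= max_block, floored at 1
        results.append(2 ** max(0, int(math.floor(math.log2(max(max_block, 1.0))))))
    return tuple(results)
```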
### `compute_default_tiling_strategy` Function
The public interface that computes final tiling results:
```python
def compute_default_tiling_strategy(
safety_margin: float = 0.80,
dtype_size: Optional[int] = None,
memory_multiplier: Optional[float] = None,
shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
) -> Optional[Tuple[Tuple[int, ...], ...]]:
"""
Compute tiling strategy using the default strategy function.
Returns tuple of tiled shapes with same structure as input shapes.
Tiling dimensions are replaced with computed block sizes (power of 2),
while non-tiling dimensions are padded to next power of 2.
Returns:
Tuple of tiled shapes, or None if shapes/tiling_dims are empty or
lengths don't match.
Raises:
ValueError: If any tiling_dim is empty or invalid, or if any dimension
index is out of bounds for the corresponding shape.
"""
```
**Key Steps:**
1. Get UB manager instance
2. Validate `shapes` and `tiling_dims` (lengths must match, cannot be empty)
- Returns `None` if validation fails (empty or mismatched lengths)
3. Set defaults for `dtype_size` (4) and `memory_multiplier` (10.0) if not provided
4. Call `_default_strategy` to get `max_supported` (tuple of max_safe_block_size, one per shape)
- May raise `ValueError` if `tiling_dims` are invalid (see `_default_strategy` documentation)
5. For each `(shape, tiling_dim, max_safe)`:
- Normalize `tiling_dim` to a set of dimension indices
- Validate tiling dimensions are within shape bounds
- Raises `ValueError` if `tiling_dim` is empty or invalid
- Raises `ValueError` if any dimension index is out of bounds
- For each tiling dimension:
- Compute `desired = triton.next_power_of_2(original_dim)`
- Compute `final = min(desired, max_safe)`
- Ensure `final >= 1`
- Replace dimension with `final`
- For each non-tiling dimension:
- Pad to `triton.next_power_of_2(original_dim)`
6. Return tuple of tiled shapes (same structure as input `shapes`)
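A hedged sketch of steps 5-6, the per-shape assembly (`_assemble_tiled_shapes` is a hypothetical name for what is an inline loop in the real function):
```python
import triton

def _assemble_tiled_shapes(shapes, tiling_dims, max_supported):  # hypothetical helper
    result = []
    for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
        dims = _normalize_tiling_dims(tiling_dim)
        tiled = []
        for i, size in enumerate(shape):
            if i in dims:
                desired = triton.next_power_of_2(size)
                tiled.append(max(1, min(desired, max_safe)))  # block size for tiled dims
            else:
                tiled.append(triton.next_power_of_2(size))  # pad fixed dims
        result.append(tuple(tiled))
    return tuple(result)
```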
## Memory Analysis Examples
### GEGLU Forward
```
Memory analysis:
- Inputs: a, b
- Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a
- Output: c
- Total: ~7x * BLOCK_SIZE * dtype_size
Strategy:
- shapes: ((n_cols,),)
- tiling_dims: (0,) # First dimension can be tiled
- Fixed dimensions: none (all dimensions are tiling)
- unit_param = 1 (product of fixed dimensions)
- memory_multiplier = 7.0
- Formula: 7.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
- Returns: ((block_size,),)
```
### GEGLU Backward
```
Memory analysis:
- More intermediates for gradient computation
- Total: ~10x * BLOCK_SIZE * dtype_size
Strategy:
- shapes: ((n_cols,),)
- tiling_dims: (0,) # First dimension can be tiled
- Fixed dimensions: none (all dimensions are tiling)
- unit_param = 1 (product of fixed dimensions)
- memory_multiplier = 10.0
- Formula: 10.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits
- Returns: ((block_size,),)
```
### ROPE Forward/Backward
```
Memory analysis (based on optimized ROPE kernel):
- cos_vals and sin_vals: pad_hd // 2 elements each (shared)
- In q heads loop (peak memory):
* q_left, q_right, new_left, new_right: 2 * BLOCK_Q * pad_hd elements
- In k heads loop (peak memory):
* k_left, k_right, new_left, new_right: 2 * BLOCK_K * pad_hd elements
- Plus shared cos/sin: pad_hd elements
- Conservative estimate: 3 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
Strategy:
- shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
- tiling_dims: (0, 0) # First dimension of each shape can be tiled
- Fixed dimensions: pad_hd (second dimension, non-tiling)
- unit_param = pad_hd (product of fixed dimensions)
- memory_multiplier = 3.0
- Formula: 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
- Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
```
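Plugging in the ROPE numbers from the Basic Usage section (32 heads, head dim 128, float32, `safety_margin=0.90`), and again assuming a 192KB UB:
```python
safe_bits = 192 * 1024 * 8 * 0.90            # 1,415,577.6 bits (assumed 192KB UB)
# unit_param = pad_hd = 128, memory_multiplier = 3.0, dtype_size = 4
max_block = safe_bits / (3.0 * 128 * 4 * 8)  # ~= 115.2 -> max_safe_block_size = 64
# desired = triton.next_power_of_2(32) = 32; final = min(32, 64) = 32
# Result: ((32, 128), (32, 128)) -> BLOCK_Q = BLOCK_K = 32
```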
## Extension Guide
### Adding a New Kernel
To add tiling support for a new kernel:
1. **Analyze memory usage**:
- Identify peak memory usage in the kernel
- Determine memory multiplier (e.g., 7.0, 10.0, 3.0)
- Identify which dimensions can be tiled and which are fixed
- Fixed dimensions will be automatically extracted and multiplied to get `unit_param`
2. **Use `compute_default_tiling_strategy`** in your kernel:
```python
def my_kernel_forward(x):
    # Prepare parameters
    n_cols = x.shape[-1]
    dtype_size = x.element_size()

    # Compute strategy
    # Example 1: Simple case (all dimensions can be tiled)
    shapes = ((n_cols,),)
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.80,
        dtype_size=dtype_size,
        memory_multiplier=7.0,  # Based on your memory analysis
        shapes=shapes,
        tiling_dims=(0,),  # First dimension can be tiled
    )
    if tile_shapes is not None and len(tile_shapes) > 0:
        block_size = tile_shapes[0][0]
    else:
        block_size = triton.next_power_of_2(n_cols)  # Fallback

    # Example 2: Multiple shapes with fixed dimensions
    # shapes = ((M, K), (K, N))
    # tiling_dims = (0, 1)  # First shape: dim 0 can be tiled, dim 1 is fixed
    #                       # Second shape: dim 0 is fixed, dim 1 can be tiled
    # Returns: ((block_M, K), (K, block_N))

    # Launch the kernel (`my_kernel` and `grid_size` are defined elsewhere)
    my_kernel[(grid_size,)](
        x,
        BLOCK_SIZE=block_size,
    )
```
3. **Document memory analysis** in comments:
```python
# My kernel tiling strategy:
# - Memory analysis:
# * Input: input
# * Intermediates: intermediate1, intermediate2
# * Output: output
# * Total: ~7x * BLOCK_SIZE * dtype_size
# - shapes: ((n_cols,),)
# - tiling_dims: (0,) means first dimension can be tiled
# - Fixed dimensions: none (all dimensions are tiling)
# - unit_param = 1 (product of fixed dimensions)
# - Uses memory_multiplier=7.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
# - compute_default_tiling_strategy returns: ((block_size,),)
# where block_size = min(triton.next_power_of_2(n_cols), max_safe_block_size)
```
## Future Improvements
1. **Strategy Variants**: If needed, could add specialized strategy functions for specific kernels while keeping the unified interface
2. **Multi-dimensional Tiling**: Could extend to support more complex tiling patterns if needed
"""
Ascend NPU operator implementations.
This module exports Ascend NPU-optimized implementations that will automatically
replace the default implementations when running on NPU devices.
Both Function classes and kernel functions can be exported here.
To add a new operator:
1. Create the implementation file (e.g., rms_norm.py)
2. Import the Function class and/or kernel functions here
3. Optionally add to __all__ for explicit control
If __all__ is not defined, all public symbols will be auto-discovered.
"""
from liger_kernel.ops.backends._ascend.ops.cross_entropy import LigerCrossEntropyFunction
from liger_kernel.ops.backends._ascend.ops.cross_entropy import cross_entropy_backward
from liger_kernel.ops.backends._ascend.ops.cross_entropy import cross_entropy_forward
from liger_kernel.ops.backends._ascend.ops.dyt import LigerDyTFunction
from liger_kernel.ops.backends._ascend.ops.dyt import liger_dyt_bwd
from liger_kernel.ops.backends._ascend.ops.dyt import liger_dyt_fwd
from liger_kernel.ops.backends._ascend.ops.embedding import LigerEmbeddingFunction
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_backward
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_forward
from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import fused_add_rms_norm_backward
from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import fused_add_rms_norm_forward
from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_backward
from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_forward
from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
from liger_kernel.ops.backends._ascend.ops.group_norm import LigerGroupNormFunction
from liger_kernel.ops.backends._ascend.ops.group_norm import group_norm_backward
from liger_kernel.ops.backends._ascend.ops.group_norm import group_norm_forward
from liger_kernel.ops.backends._ascend.ops.grpo_loss import GrpoLossFunction
from liger_kernel.ops.backends._ascend.ops.grpo_loss import grpo_loss_backward_triton
from liger_kernel.ops.backends._ascend.ops.grpo_loss import grpo_loss_forward_triton
from liger_kernel.ops.backends._ascend.ops.jsd import LigerJSDFunction
from liger_kernel.ops.backends._ascend.ops.jsd import jsd_backward
from liger_kernel.ops.backends._ascend.ops.jsd import jsd_forward
from liger_kernel.ops.backends._ascend.ops.kl_div import LigerKLDivLossFunction
from liger_kernel.ops.backends._ascend.ops.kl_div import kldiv_backward_triton
from liger_kernel.ops.backends._ascend.ops.kl_div import kldiv_forward_triton
from liger_kernel.ops.backends._ascend.ops.layer_norm import LigerLayerNormFunction
from liger_kernel.ops.backends._ascend.ops.layer_norm import layer_norm_backward
from liger_kernel.ops.backends._ascend.ops.layer_norm import layer_norm_forward
from liger_kernel.ops.backends._ascend.ops.llama4_rope import LigerLlama4RopeFunction
from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_backward
from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_forward
from liger_kernel.ops.backends._ascend.ops.poly_norm import LigerPolyNormFunction
from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_backward
from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_forward
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward
from liger_kernel.ops.backends._ascend.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.ops.backends._ascend.ops.rms_norm import rms_norm_backward
from liger_kernel.ops.backends._ascend.ops.rms_norm import rms_norm_forward
from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction
from liger_kernel.ops.backends._ascend.ops.rope import rope_backward
from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
from liger_kernel.ops.backends._ascend.ops.softmax import LigerSoftmaxFunction
from liger_kernel.ops.backends._ascend.ops.softmax import _softmax_backward
from liger_kernel.ops.backends._ascend.ops.softmax import _softmax_forward
from liger_kernel.ops.backends._ascend.ops.sparsemax import LigerSparsemaxFunction
from liger_kernel.ops.backends._ascend.ops.sparsemax import sparsemax_backward
from liger_kernel.ops.backends._ascend.ops.sparsemax import sparsemax_forward
from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward
from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction
from liger_kernel.ops.backends._ascend.ops.tvd import tv_distance_forward_triton
from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton
__all__ = [
"LigerEmbeddingFunction",
"embedding_forward",
"embedding_backward",
"LigerFusedAddRMSNormFunction",
"fused_add_rms_norm_forward",
"fused_add_rms_norm_backward",
"LigerGELUMulFunction",
"geglu_forward",
"geglu_backward",
"LigerQwen2VLMRopeFunction",
"qwen2vl_mrope_forward",
"qwen2vl_mrope_backward",
"LigerRMSNormFunction",
"rms_norm_forward",
"rms_norm_backward",
"LigerRopeFunction",
"rope_forward",
"rope_backward",
"LigerSiLUMulFunction",
"swiglu_forward",
"swiglu_backward",
"LigerTVDLossFunction",
"tv_distance_forward_triton",
"tvd_backward_triton",
"LigerLlama4RopeFunction",
"llama4_rope_forward",
"llama4_rope_backward",
"LigerPolyNormFunction",
"poly_norm_forward",
"poly_norm_backward",
"LigerDyTFunction",
"liger_dyt_fwd",
"liger_dyt_bwd",
"LigerKLDivLossFunction",
"kldiv_forward_triton",
"kldiv_backward_triton",
"LigerLayerNormFunction",
"layer_norm_backward",
"layer_norm_forward",
"LigerSoftmaxFunction",
"_softmax_forward",
"_softmax_backward",
"LigerJSDFunction",
"jsd_forward",
"jsd_backward",
"LigerCrossEntropyFunction",
"cross_entropy_backward",
"cross_entropy_forward",
"GrpoLossFunction",
"grpo_loss_forward_triton",
"grpo_loss_backward_triton",
"LigerFusedLinearJSDFunction",
"fused_linear_jsd_forward",
"fused_linear_jsd_backward",
"LigerGroupNormFunction",
"group_norm_forward",
"group_norm_backward",
"LigerSparsemaxFunction",
"sparsemax_forward",
"sparsemax_backward",
]
from typing import Optional
import torch
import triton
import triton.language as tl
from triton.language.math import tanh
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import get_npu_core_count
@triton.jit
def liger_cross_entropy_kernel(
X_ptr,
X_stride,
Y_ptr,
Y_stride,
weight_ptr,
loss_ptr,
z_loss_ptr,
loss_stride,
token_accuracy_ptr,
token_accuracy_stride,
predicted_tokens_ptr,
predicted_tokens_stride,
n_cols,
n_rows,
n_non_ignore,
sum_non_ignore_weight,
weight_sum,
ignore_index,
lse_square_scale: tl.constexpr,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
RETURN_TOKEN_ACCURACY: tl.constexpr,
RETURN_PREDICTED_TOKENS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
HAS_GRADIENTS: tl.constexpr,
):
"""
This kernel computes both cross entropy loss and the gradient of the input.
We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
Parameters:
X_ptr: Pointer to input tensor.
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
weight_ptr: Pointer to weight tensor.
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
token_accuracy_stride (int): The stride of the token accuracy tensor.
predicted_tokens_ptr: Pointer to tensor to store the per-token predicted class indices. No operation if RETURN_PREDICTED_TOKENS is 0.
predicted_tokens_stride (int): The stride of the predicted tokens tensor.
n_cols (int): The number of columns in the input tensor.
n_rows (int): The total number of rows to process.
n_non_ignore (float): The number of non-ignored elements in the batch.
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
weight_sum (float): The sum of weight tensor.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
RETURN_PREDICTED_TOKENS (int): The boolean value to decide whether to store per-token predicted indices to predicted_tokens_ptr or not. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
"""
# Grid-Stride Loop: each program processes multiple rows
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
start_row = pid
stride = num_progs
for row_idx in range(start_row, n_rows, stride):
# https://github.com/triton-lang/triton/issues/1058
# If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
program_id = row_idx.to(tl.int64)
# 1. Load Y_ptr first because if the target is ignore_index, we can return right away
Y_ptr_offset = program_id * Y_stride
y = tl.load(Y_ptr + Y_ptr_offset)
# 2. locate the start index
X_ptr_offset = program_id * X_stride
is_ignored = y == ignore_index
if is_ignored:
# set all X_ptr as 0
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_ptr_offset + X_offsets, 0.0, mask=X_offsets < n_cols)
# For ignored tokens, set token accuracy to 0
if RETURN_TOKEN_ACCURACY:
token_accuracy_ptr_offset = program_id * token_accuracy_stride
tl.store(token_accuracy_ptr + token_accuracy_ptr_offset, 0.0)
if RETURN_PREDICTED_TOKENS:
predicted_tokens_ptr_offset = program_id * predicted_tokens_stride
tl.store(predicted_tokens_ptr + predicted_tokens_ptr_offset, -1)
else:
loss_ptr_offset = program_id * loss_stride
if RETURN_Z_LOSS:
z_loss_ptr_offset = program_id * loss_stride
if RETURN_TOKEN_ACCURACY:
token_accuracy_ptr_offset = program_id * token_accuracy_stride
if RETURN_PREDICTED_TOKENS:
predicted_tokens_ptr_offset = program_id * predicted_tokens_stride
if HAS_WEIGHT:
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
# 3. [Online softmax] first pass: find max + sum
m = float("-inf") # m is the max value. use the notation from the paper
d = 0.0 # d is the sum. use the notation from the paper
argmax_idx = 0 # Track the index of the maximum value for token accuracy / predicted tokens computation
ori_X_y = tl.load(X_ptr + X_ptr_offset + y).cast(
tl.float32
) # we need to store the original value of X_y for the loss calculation
if HAS_SOFTCAPPING:
ori_X_y = softcap * tanh(ori_X_y / softcap)
# Label smoothing is a general case of normal cross entropy
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
scaled_x_sum = 0.0
eps = label_smoothing / n_cols
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_ptr_offset + X_offsets,
mask=X_offsets < n_cols,
other=float("-inf"),
# Ensure float32 precision for softmax calculation
).cast(tl.float32)
if HAS_SOFTCAPPING:
X_block = softcap * tanh(X_block / softcap)
block_max = tl.max(X_block)
# Track argmax for accuracy / predicted tokens computation
if RETURN_TOKEN_ACCURACY or RETURN_PREDICTED_TOKENS:
# Find the index of the maximum value in this block
is_max_mask = X_block == block_max
# Mask out invalid indices with a value larger than n_cols
masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
# Get the first (smallest) index where max occurs
current_block_argmax_idx = tl.min(masked_offsets)
is_new_max = block_max > m
argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
if HAS_WEIGHT:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
else:
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new
# log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
# = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
# = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
lse = m + tl.log(d)
# 4. [Online Softmax] Second pass: compute gradients
# For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# For label smoothing:
# dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
# = dx_i - (1 - label_smoothing) / N
# With Z loss:
# dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
# dx_y = dx_i - (1 - label_smoothing) / N
# For 'sum' reduction, no normalization is applied:
# dx_y = softmax(x_y) - 1
# dx_i = softmax(x_i), for i ≠ y
if HAS_GRADIENTS:
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_ptr_offset + X_offsets,
mask=X_offsets < n_cols,
other=float("-inf"),
# Ensure float32 precision for softmax calculation
).cast(tl.float32)
if HAS_SOFTCAPPING:
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate
if not HAS_WEIGHT:
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# special handle dx_y
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / n_non_ignore
else:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
softmax_X = tl.exp(X_block - m) / d
# derivative of original_loss
dloss_ori = (1 - label_smoothing) * softmax_X
# specially handle dx_y
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
dloss_ori = dloss_ori * weight_y
# derivative of smooth_loss
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
# derivative of z-loss
dz_loss = 2 * lse_square_scale * lse * softmax_X
# reduction scale
if reduction == "mean":
dloss_ori = dloss_ori / sum_non_ignore_weight
dloss_smooth = dloss_smooth / sum_non_ignore_weight
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
dz_loss = dz_loss / n_non_ignore
# derivative of total_loss
X_block = dloss_ori + dloss_smooth + dz_loss
# chain rule softcapping
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
X_block = X_block * (1 - intermediate * intermediate)
tl.store(X_ptr + X_ptr_offset + X_offsets, X_block, mask=X_offsets < n_cols)
# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
tl.debug_barrier()
# 5. Calculate the loss
# loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
# = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
# = X_y - m - log d = X_y - lse
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
# So we can safely calculate log (softmax(X_y)) without overflow
loss = lse - ori_X_y
if HAS_WEIGHT:
loss = weight_y * loss
# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
# = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
# By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
# = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
if HAS_WEIGHT:
smooth_loss = scaled_x_sum + eps * lse * weight_sum
else:
smooth_loss = scaled_x_sum + label_smoothing * lse
loss = loss * (1 - label_smoothing) + smooth_loss
# An auxiliary loss, z_loss
# Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
z_loss = lse_square_scale * lse * lse
# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
if HAS_WEIGHT:
loss = loss / sum_non_ignore_weight
else:
loss = loss / n_non_ignore
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
z_loss = z_loss / n_non_ignore
loss += z_loss
tl.store(loss_ptr + loss_ptr_offset, loss)
if RETURN_Z_LOSS:
tl.store(z_loss_ptr + z_loss_ptr_offset, z_loss)
if RETURN_TOKEN_ACCURACY:
# Store 1.0 if prediction is correct, 0.0 otherwise
is_correct = tl.where(argmax_idx == y, 1.0, 0.0)
tl.store(token_accuracy_ptr + token_accuracy_ptr_offset, is_correct)
if RETURN_PREDICTED_TOKENS:
tl.store(predicted_tokens_ptr + predicted_tokens_ptr_offset, argmax_idx)
def get_optimal_block_size(n_cols, has_gradients=True):
"""
Calculate optimal Block Size using compute_default_tiling_strategy
"""
# Cross entropy is more memory-intensive than swiglu because it computes a softmax:
# the forward pass needs an online softmax, and the backward pass needs extra memory
# for intermediate variables.
# 12.0 and 8.0 are empirical multipliers based on Atlas 800I A2 UB (192KB)
multiplier = 12.0 if has_gradients else 8.0
# Call calculation function
# Treat input as 1D (n_cols,), only tiling on dim 0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((n_cols,),), tiling_dims=(0,)
)
# Parse result
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return block_size
else:
return 2048
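# Worked example (hedged, assuming the 192KB Atlas 800I A2 UB referenced above):
#   safe_bits = 0.9 * 192 * 1024 * 8 = 1,415,577.6
#   has_gradients=True:  1,415,577.6 / (12.0 * 1 * 4 * 8) ~= 3686 -> BLOCK_SIZE = 2048
#   has_gradients=False: 1,415,577.6 / ( 8.0 * 1 * 4 * 8) ~= 5530 -> BLOCK_SIZE = 4096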
def cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap,
return_z_loss,
return_token_accuracy=False,
return_predicted_tokens=False,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_token_accuracy, bool), (
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
)
assert isinstance(return_predicted_tokens, bool), (
f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}"
)
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = get_optimal_block_size(V, has_gradients=_input.requires_grad)
# unreduced loss
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
token_accuracy_1d = (
torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
)
predicted_tokens_1d = (
torch.full((n_rows,), -1, dtype=torch.int64, device=_input.device) if return_predicted_tokens else None
)
target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
assert (target * target_mask).max() < _input.shape[-1], (
f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
)
assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
sum_non_ignore_weight = n_non_ignore
weight_sum = 0.0
if weight is not None:
assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
assert torch.is_floating_point(weight), (
f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
)
sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
weight_sum = weight.sum().item()
# ensure weight is contiguous
if weight.stride(-1) != 1:
weight = weight.contiguous()
# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
_input = _input.contiguous()
if target.stride(-1) != 1:
target = target.contiguous()
# NPU-optimized grid configuration
num_cores = get_npu_core_count()
grid_size = min(num_cores, n_rows)
# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
liger_cross_entropy_kernel[(grid_size,)](
X_ptr=_input,
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight, # dummy if None
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
token_accuracy_ptr=token_accuracy_1d,
token_accuracy_stride=token_accuracy_1d.stride(-1)
if return_token_accuracy
else 0, # always 1 if accuracy is enabled
predicted_tokens_ptr=predicted_tokens_1d,
predicted_tokens_stride=predicted_tokens_1d.stride(-1)
if return_predicted_tokens
else 0, # always 1 if predicted tokens is enabled
n_cols=V,
n_rows=n_rows,
n_non_ignore=n_non_ignore,
sum_non_ignore_weight=sum_non_ignore_weight,
ignore_index=ignore_index,
weight_sum=weight_sum,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
RETURN_TOKEN_ACCURACY=return_token_accuracy,
RETURN_PREDICTED_TOKENS=return_predicted_tokens,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
HAS_GRADIENTS=_input.requires_grad,
)
if reduction == "none":
loss = loss_1d
z_loss = z_loss_1d if return_z_loss else None
token_accuracy = token_accuracy_1d if return_token_accuracy else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
# For accuracy, we compute the mean across all non-ignored tokens
token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
return loss, z_loss, token_accuracy, predicted_tokens, _input
def cross_entropy_backward(_input, grad_output):
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
pass
# If reduction is 'none'
elif grad_output.ndim > 0:
_input = _input * grad_output.unsqueeze(dim=1)
# If reduction is ['mean', 'sum'], grad_output is just a scalar
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
else:
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(2048, triton.next_power_of_2(V))
element_mul_kernel[(n_rows,)](
_input,
_input.stride(-2),
grad_output,
V,
BLOCK_SIZE=BLOCK_SIZE,
)
return _input
class LigerCrossEntropyFunction(torch.autograd.Function):
"""
This class implements a custom autograd function for the Liger Cross Entropy loss.
It overrides the forward and backward methods of the torch.autograd.Function class.
"""
@staticmethod
def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.FloatTensor],
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_token_accuracy: bool = False,
return_predicted_tokens: bool = False,
):
"""
The forward pass of the Liger Cross Entropy loss.
Parameters:
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
ignore_index (int): The index to ignore in the target.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy, predicted_tokens) instead of (loss, None, None, None). Default: `False`
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
Returns:
tuple: A tuple with the computed losses, accuracy, and predicted tokens: (loss, z_loss, token_accuracy, predicted_tokens). z_loss, token_accuracy, and predicted_tokens are None if not requested.
"""
input_requires_grad = _input.requires_grad
loss, z_loss, token_accuracy, predicted_tokens, _input = cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap,
return_z_loss,
return_token_accuracy,
return_predicted_tokens,
)
if input_requires_grad:
ctx.save_for_backward(_input.detach())
ctx.return_z_loss = return_z_loss
ctx.return_token_accuracy = return_token_accuracy
ctx.return_predicted_tokens = return_predicted_tokens
return loss, z_loss, token_accuracy, predicted_tokens
@staticmethod
def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4):
"""
The backward pass of the Liger Cross Entropy loss.
Parameters:
ctx : The context object with saved tensors.
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
grad_output4 (tensor): No use. Gradient for predicted_tokens (not used as predicted_tokens is only for metrics).
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
if ctx.return_z_loss:
del grad_output2 # z_loss is only for logging
if ctx.return_token_accuracy:
del grad_output3 # token_accuracy is only for metrics
if ctx.return_predicted_tokens:
del grad_output4 # predicted_tokens is only for metrics
(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
return (
_input,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
import torch
import triton
import triton.language as tl
from triton.language.math import tanh
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
# -----------------------------------------------------------------------------
# Forward Kernel
# -----------------------------------------------------------------------------
@triton.jit
def _dyt_fwd_kernel(
X,
Y,
Alpha,
Gamma,
Beta,
HAVE_BETA: tl.constexpr,
M: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""
Forward kernel for DYT: y = tanh(α·x) · γ + β
Grid: (num_col_blocks, num_row_programs)
Each program processes multiple rows using grid-stride loop
"""
pid_n = tl.program_id(0)
pid_m = tl.program_id(1)
num_row_programs = tl.num_programs(1)
col_start = pid_n * BLOCK_N
col_offsets = col_start + tl.arange(0, BLOCK_N)
col_mask = col_offsets < N
alpha = tl.load(Alpha).to(tl.float32)
gamma = tl.load(Gamma + col_offsets, mask=col_mask, other=0.0).to(tl.float32)
if HAVE_BETA:
beta = tl.load(Beta + col_offsets, mask=col_mask, other=0.0).to(tl.float32)
# Grid-stride loop over rows
for row_idx in range(pid_m, M, num_row_programs):
row_offset = row_idx * N
x = tl.load(X + row_offset + col_offsets, mask=col_mask, other=0.0).to(tl.float32)
# Compute: y = tanh(α·x) · γ + β
tanh_x = tanh(alpha * x)
y = tanh_x * gamma
if HAVE_BETA:
y += beta
tl.store(Y + row_offset + col_offsets, y, mask=col_mask)
# -----------------------------------------------------------------------------
# Backward Kernel
# -----------------------------------------------------------------------------
@triton.jit
def _dyt_bwd_kernel(
DY,
DX,
DA,
DG,
DB,
X,
Alpha,
Gamma,
HAVE_BETA: tl.constexpr,
M: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""
Backward kernel for DYT
Grid: (num_col_blocks, num_row_programs)
Each program processes multiple rows using grid-stride loop
Accumulates gradients in local buffers, then stores to global memory
"""
pid_n = tl.program_id(0)
pid_m = tl.program_id(1)
num_row_programs = tl.num_programs(1)
col_start = pid_n * BLOCK_N
col_offsets = col_start + tl.arange(0, BLOCK_N)
col_mask = col_offsets < N
alpha = tl.load(Alpha).to(tl.float32)
gamma = tl.load(Gamma + col_offsets, mask=col_mask, other=0.0).to(tl.float32)
da_vec = tl.zeros((BLOCK_N,), dtype=tl.float32)
dg_acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAVE_BETA:
db_acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
# Grid-stride loop over rows
for row_idx in range(pid_m, M, num_row_programs):
row_offset = row_idx * N
x = tl.load(X + row_offset + col_offsets, mask=col_mask, other=0.0).to(tl.float32)
dy = tl.load(DY + row_offset + col_offsets, mask=col_mask, other=0.0).to(tl.float32)
tanh_x = tanh(alpha * x)
if HAVE_BETA:
db_acc += dy
dg_acc += dy * tanh_x
# Compute intermediate: tmp = (1 - tanh²) · dy · γ
tmp = (1.0 - tanh_x * tanh_x) * dy * gamma
# Accumulate dα = Σ(x · tmp)
da_vec += x * tmp
# Compute dx = α · tmp
dx = alpha * tmp
tl.store(DX + row_offset + col_offsets, dx, mask=col_mask)
da_acc = tl.sum(da_vec, 0)
da_offset = pid_m * triton.cdiv(N, BLOCK_N) + pid_n
tl.store(DA + da_offset, da_acc)
dg_offset = pid_m * N + col_offsets
tl.store(DG + dg_offset, dg_acc, mask=col_mask)
if HAVE_BETA:
db_offset = pid_m * N + col_offsets
tl.store(DB + db_offset, db_acc, mask=col_mask)
def get_optimal_block_size(total_elements, is_backward=False):
"""
Calculate optimal Block Size using compute_default_tiling_strategy
"""
multiplier = 8.0 if is_backward else 4.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
)
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return block_size
else:
return 2048
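# Worked example (hedged, assuming a 192KB UB; dtype_size=4 since accumulation is float32):
#   safe_bits = 0.9 * 192 * 1024 * 8 = 1,415,577.6
#   backward (multiplier=8.0): 1,415,577.6 / (8.0 * 4 * 8) ~= 5530  -> BLOCK_N = 4096
#   forward  (multiplier=4.0): 1,415,577.6 / (4.0 * 4 * 8) ~= 11235 -> BLOCK_N = 8192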
def _compute_grid_size(n_cols, n_rows, block_n):
"""
Compute grid size to avoid launching idle programs
Args:
n_cols: Number of columns
n_rows: Number of rows
block_n: Block size for column dimension
Returns:
(num_col_blocks, num_row_programs)
"""
num_cores = get_npu_core_count()
num_col_blocks = triton.cdiv(n_cols, block_n)
num_row_blocks = n_rows
num_row_programs = min(max(1, (num_cores // num_col_blocks)), num_row_blocks)
return num_col_blocks, num_row_programs
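# Worked example (hedged; the core count is an assumption for illustration):
#   with num_cores=40, n_cols=4096, block_n=2048, n_rows=1024:
#   num_col_blocks = ceil(4096 / 2048) = 2
#   num_row_programs = min(max(1, 40 // 2), 1024) = 20 -> grid = (2, 20)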
# -----------------------------------------------------------------------------
# Python Wrapper Functions
# -----------------------------------------------------------------------------
def liger_dyt_fwd(x, alpha, gamma, beta):
"""
Forward pass of DYT: y = tanh(α·x) · γ + β
Args:
x: Input tensor of shape [..., N]
alpha: Scalar parameter
gamma: Vector parameter of shape [N]
beta: Vector parameter of shape [N] (optional)
Returns:
y: Output tensor of same shape as x
"""
assert x.is_contiguous()
HAVE_BETA = beta is not None
# Flatten to 2D
input_shape = x.shape
x = x.view(-1, input_shape[-1])
M, N = x.shape
# Allocate output
y = torch.empty_like(x)
block_n = get_optimal_block_size(N, is_backward=False)
# Compute grid size
num_col_blocks, num_row_programs = _compute_grid_size(N, M, block_n)
grid = (num_col_blocks, num_row_programs)
# Launch kernel
_dyt_fwd_kernel[grid](x, y, alpha, gamma, beta, HAVE_BETA, M, N, BLOCK_N=block_n)
return y.view(input_shape)
def liger_dyt_bwd(dy, x, alpha, gamma, beta):
"""
Backward pass of DYT
Args:
dy: Upstream gradient of shape [..., N]
x: Input tensor of shape [..., N]
alpha: Scalar parameter
gamma: Vector parameter of shape [N]
beta: Vector parameter of shape [N] (optional)
Returns:
dx: Gradient w.r.t. x
dalpha: Gradient w.r.t. alpha
dgamma: Gradient w.r.t. gamma
dbeta: Gradient w.r.t. beta (or None)
"""
assert dy.is_contiguous()
HAVE_BETA = beta is not None
# Flatten to 2D
input_shape = x.shape
x = x.view(-1, input_shape[-1])
dy = dy.view(-1, input_shape[-1])
M, N = x.shape
block_n = get_optimal_block_size(N, is_backward=True)
# Compute grid size
num_col_blocks, num_row_programs = _compute_grid_size(N, M, block_n)
grid = (num_col_blocks, num_row_programs)
da = torch.zeros(num_row_programs, triton.cdiv(N, block_n), dtype=torch.float32, device=x.device)
dg = torch.empty(num_row_programs, N, dtype=torch.float32, device=x.device)
db = torch.empty(num_row_programs, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
dx = torch.empty_like(dy)
_dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, BLOCK_N=block_n)
da = da.sum().to(x.dtype).unsqueeze(0)
dg = dg.sum(0).to(gamma.dtype)
db = db.sum(0).to(x.dtype) if HAVE_BETA else None
return dx.view(input_shape), da, dg, db
# -----------------------------------------------------------------------------
# Autograd Function
# -----------------------------------------------------------------------------
class LigerDyTFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, x, alpha, gamma, beta):
y = liger_dyt_fwd(x, alpha, gamma, beta)
ctx.save_for_backward(x, alpha, gamma, beta)
return y
@staticmethod
@ensure_contiguous
def backward(ctx, dy):
x, alpha, gamma, beta = ctx.saved_tensors
dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
return dx, dalpha, dgamma, dbeta
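# Example usage (hedged sketch; the "npu" device string assumes the torch_npu backend):
#   x = torch.randn(2, 16, 1024, device="npu", requires_grad=True)
#   alpha = torch.ones(1, device="npu", requires_grad=True)
#   gamma = torch.ones(1024, device="npu", requires_grad=True)
#   beta = torch.zeros(1024, device="npu", requires_grad=True)
#   y = LigerDyTFunction.apply(x, alpha, gamma, beta)  # y = tanh(alpha * x) * gamma + beta
#   y.sum().backward()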
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
@triton.jit
def embedding_forward_kernel(
embeddings_ptr,
indices_ptr,
output_ptr,
n_elements,
embedding_dim: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
total_2d_blocks = grid_m * grid_n
for block_idx in tl.range(pid, total_2d_blocks, num_progs):
block_m = block_idx // grid_n
block_n = block_idx % grid_n
start_m = block_m * BLOCK_SIZE_M
start_n = block_n * BLOCK_SIZE_N
offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < n_elements
indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
mask_n = offsets_n < embedding_dim
block_mask = mask_m[:, None] & mask_n[None, :]
embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
embeddings = tl.load(
embeddings_ptr + embedding_offsets,
mask=block_mask,
other=0.0,
)
output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
tl.store(
output_ptr + output_offsets,
embeddings,
mask=block_mask,
)
@triton.jit
def embedding_backward_kernel(
grad_output_ptr,
grad_weight_ptr,
indices_ptr,
n_elements,
embedding_dim: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
total_2d_blocks = grid_m * grid_n
for block_idx in tl.range(pid, total_2d_blocks, num_progs):
block_m = block_idx // grid_n
block_n = block_idx % grid_n
start_m = block_m * BLOCK_SIZE_M
start_n = block_n * BLOCK_SIZE_N
offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < n_elements
indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
mask_n = offsets_n < embedding_dim
block_mask = mask_m[:, None] & mask_n[None, :]
grad_output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
grad_output = tl.load(
grad_output_ptr + grad_output_offsets,
mask=block_mask,
other=0.0,
)
grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
tl.atomic_add(
grad_weight_ptr + grad_weight_offsets,
grad_output,
mask=block_mask,
)
def get_optimal_block_size(total_elements, dtype_size, BLOCK_SIZE_N):
# 1. Set memory multiplier
# 3.0 is an empirical value based on Atlas 800I A2 UB (192KB):
# embedding_offsets and output_offsets each occupy BLOCK_SIZE_M * BLOCK_SIZE_N elements
# (2 * BLOCK_SIZE_M * BLOCK_SIZE_N in total), and one more unit is reserved for the
# remaining one-dimensional UB buffers, giving a conservative total of
# 3 * BLOCK_SIZE_M * BLOCK_SIZE_N.
multiplier = 3.0
# 2. Call calculation function
# Shape is (total_elements, BLOCK_SIZE_N); only dim 0 is tiled, so unit_param = BLOCK_SIZE_N
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=dtype_size,
memory_multiplier=multiplier,
shapes=((total_elements, BLOCK_SIZE_N),),
tiling_dims=(0,),
)
# 3. Parse result
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return block_size
else:
return triton.next_power_of_2(min(128, total_elements))
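# Worked example (hedged, assuming a 192KB UB): with BLOCK_SIZE_N=128 and float16
# embeddings (dtype_size=2):
#   safe_bits = 0.9 * 192 * 1024 * 8 = 1,415,577.6
#   1,415,577.6 / (3.0 * 128 * 2 * 8) ~= 230 -> BLOCK_SIZE_M = 128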
def embedding_forward(embeddings, indices):
ori_shape = indices.shape
indices = indices.view(-1)
n_elements = indices.numel()
embedding_dim = embeddings.shape[1]
output = torch.empty(
indices.shape[0],
embeddings.shape[1],
device=indices.device,
dtype=embeddings.dtype,
)
# Because tiling is two-dimensional, BLOCK_SIZE_M and BLOCK_SIZE_N compete for
# the same UB budget. Since embedding_dim is usually the smaller dimension,
# BLOCK_SIZE_N is fixed first and the largest feasible BLOCK_SIZE_M is then
# derived from the remaining capacity.
BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
num_cores = get_npu_core_count()
total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
grid = min(num_cores, total_blocks)
embedding_forward_kernel[(grid,)](
embeddings,
indices,
output,
n_elements,
embedding_dim=embedding_dim,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
return output.view(*ori_shape, -1)
def embedding_backward(embeddings, indices, grad_output):
grad_output = grad_output.contiguous().view(-1, embeddings.shape[1])
grad_weight = torch.zeros_like(embeddings)
n_elements = indices.numel()
embedding_dim = embeddings.shape[1]
BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
num_cores = get_npu_core_count()
total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
grid = min(num_cores, total_blocks)
embedding_backward_kernel[(grid,)](
grad_output,
grad_weight,
indices,
n_elements,
embedding_dim=embedding_dim,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
return grad_weight
class LigerEmbeddingFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
output = embedding_forward(embeddings, indices)
ctx.save_for_backward(indices, embeddings)
return output
@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor):
indices, embeddings = ctx.saved_tensors
grad_weight = embedding_backward(embeddings, indices, grad_output)
return grad_weight, None
import torch
import triton
import triton.language as tl
from triton.language.math import rsqrt
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import torch_to_triton_dtype
_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
def torch_dtype_to_triton(dtype):
mapping = {
torch.float32: tl.float32,
torch.bfloat16: tl.bfloat16,
}
return mapping.get(dtype, tl.float32)
# -----------------------------------------------------------------------------
# Forward Kernel - No Tiling (for n_cols <= 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _fused_add_rms_norm_forward_kernel_no_tiling(
Y_ptr,
Y_row_stride,
S_ptr, # output residual
S_row_stride,
X_ptr,
X_row_stride,
R_ptr, # input residual
R_row_stride,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
n_rows,
n_cols,
eps,
offset,
casting_mode: tl.constexpr,
X_DTYPE: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
"""
NPU-optimized fused_add_rms_norm forward kernel for small n_cols (<= 2048).
Performance optimizations:
1. Keep S_row in registers, avoid reload from memory
2. Minimize type conversions by careful ordering
3. Use optimal cache policies
4. Preload W_row while computing rstd (instruction-level parallelism)
5. Use 2D vector loading to maximize UB utilization (e.g., (1,2048), (2,1024), (4,512))
Used when n_cols <= 2048 to avoid the overhead of column blocking.
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
if casting_mode == _CASTING_MODE_NONE:
eps = eps.to(X_DTYPE)
offset = offset.to(X_DTYPE)
# Grid-stride loop setup for 2D blocks
grid_stride = num_progs * BLOCK_SIZE_M
num_iterations = tl.cdiv(n_rows, grid_stride)
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
# Grid-stride loop over row blocks
for i in tl.range(num_iterations):
row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets
row_mask = row_idx < n_rows
block_mask = row_mask[:, None] & col_mask[None, :]
# Load multiple rows at once using 2D indexing
X_rows = tl.load(
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
eviction_policy="evict_first",
)
R_rows = tl.load(
R_ptr + row_idx[:, None] * R_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
eviction_policy="evict_first",
)
S_rows = X_rows + R_rows
# Compute sum_square for all rows
if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA:
S_rows = S_rows.to(tl.float32)
sum_squares = tl.sum(tl.where(block_mask, S_rows * S_rows, 0.0), axis=1)
# Compute rstd for all rows
mean_squares = sum_squares / n_cols
rstd_rows = rsqrt(mean_squares + eps)
# Store S_rows and rstd_rows
tl.store(
S_ptr + row_idx[:, None] * S_row_stride + col_offsets[None, :],
S_rows,
mask=block_mask,
cache_modifier=".cg",
)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd_rows, mask=row_mask)
# Normalize and apply weight - optimized for each casting mode
if casting_mode == _CASTING_MODE_GEMMA:
Y_rows = ((S_rows * rstd_rows[:, None]) * (offset + W_row[None, :])).to(X_DTYPE)
elif casting_mode == _CASTING_MODE_LLAMA:
S_normalized = (S_rows * rstd_rows[:, None]).to(X_DTYPE)
Y_rows = S_normalized * (offset + W_row[None, :])
else:
Y_rows = (S_rows * rstd_rows[:, None]) * (offset + W_row[None, :])
# Store results
tl.store(Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], Y_rows, mask=block_mask)
# -----------------------------------------------------------------------------
# Forward Kernel - With Tiling (for n_cols > 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _fused_add_rms_norm_forward_kernel_npu(
Y_ptr,
Y_row_stride,
S_ptr, # output residual
S_row_stride,
X_ptr,
X_row_stride,
R_ptr, # input residual
R_row_stride,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
n_rows,
n_cols,
eps,
offset,
casting_mode: tl.constexpr,
X_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
NPU-optimized fused_add_rms_norm forward kernel.
This kernel processes rows using a grid-stride loop pattern:
1. Each program handles multiple rows
2. For each row, we process it in column chunks of BLOCK_SIZE_N
3. Grid size is limited to NPU core count to avoid resource overflow
This solves two problems:
1. UB overflow when n_cols is too large (original kernel used n_cols as BLOCK_SIZE_N)
2. Efficient multi-row processing within a single kernel launch
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
if casting_mode == _CASTING_MODE_NONE:
eps = eps.to(X_DTYPE)
offset = offset.to(X_DTYPE)
offsets = tl.arange(0, BLOCK_SIZE)
# Grid-stride loop over rows
for row_idx in tl.range(pid, n_rows, num_progs):
Y_row_ptr = Y_ptr + row_idx * Y_row_stride
S_row_ptr = S_ptr + row_idx * S_row_stride
X_row_ptr = X_ptr + row_idx * X_row_stride
R_row_ptr = R_ptr + row_idx * R_row_stride
RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
# Accumulator for mean_square computation across all column blocks
sum_square = 0.0
# First pass: compute S_row = X_row + R_row and accumulate sum of squares
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
R_block = tl.load(R_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
S_block = X_block + R_block
# Store S_row
tl.store(S_row_ptr + col_offsets, S_block, mask=mask, cache_modifier=".cg")
if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA:
S_block = S_block.to(tl.float32)
# Accumulate sum of squares (only for valid elements)
sum_square += tl.sum(tl.where(mask, S_block * S_block, 0.0))
# Compute rstd for this row
mean_square = sum_square / n_cols
rstd = rsqrt(mean_square + eps)
# Store rstd
tl.store(RSTD_row_ptr, rstd)
# Second pass: normalize and multiply by weight
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
# Load S_block (already computed in first pass)
S_block = tl.load(S_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".ca")
W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
# Apply casting based on mode
if casting_mode == _CASTING_MODE_GEMMA:
S_block = S_block.to(tl.float32)
W_block = W_block.to(tl.float32)
elif casting_mode == _CASTING_MODE_LLAMA:
S_block = S_block.to(tl.float32)
# Normalize
S_block = S_block * rstd
# Cast back for Llama mode before weight multiplication
if casting_mode == _CASTING_MODE_LLAMA:
S_block = S_block.to(X_DTYPE)
# Apply weight
Y_block = S_block * (offset + W_block)
# Cast back for Gemma mode
if casting_mode == _CASTING_MODE_GEMMA:
Y_block = Y_block.to(X_DTYPE)
# Store result
tl.store(Y_row_ptr + col_offsets, Y_block, mask=mask)
# -----------------------------------------------------------------------------
# Backward Kernel - No Tiling (for n_cols <= 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _fused_add_rms_norm_backward_kernel_no_tiling(
dY_ptr,
dY_row_stride,
dS_out_ptr,
dS_out_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
casting_mode: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
has_dS_out: tl.constexpr,
):
"""
NPU-optimized fused_add_rms_norm backward kernel for small n_cols (<= 2048).
Performance optimizations:
1. Keep all data in registers, minimize conversions
2. Reuse X_normalized (X * rstd) for both dX and dW
3. Optimize computation order to reduce register pressure
4. Combine operations where possible
5. Use 2D vector loading to maximize UB utilization (e.g., (1,2048), (2,1024), (4,512))
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# Grid-stride loop setup for 2D blocks
grid_stride = num_progs * BLOCK_SIZE_M
num_iterations = tl.cdiv(n_rows, grid_stride)
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
# Load W once for all iterations
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
W_offset = W_row + offset
# Grid-stride loop over row blocks
for i in tl.range(num_iterations):
row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets
row_mask = row_idx < n_rows
block_mask = row_mask[:, None] & col_mask[None, :]
dY_rows = tl.load(
dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
eviction_policy="evict_first",
)
X_rows = tl.load(
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
mask=block_mask,
other=0.0,
eviction_policy="evict_first",
)
# Load rstd for all rows in the block
rstd_rows = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, mask=row_mask, other=0.0)
# Convert X to fp32 once
X_rows = X_rows.to(tl.float32)
# Compute X_normalized (reused in dX and dW)
X_normalized = X_rows * rstd_rows[:, None]
# Compute m based on casting mode (optimized for each mode)
if casting_mode == _CASTING_MODE_LLAMA:
m_rows = (dY_rows * W_offset[None, :]).to(tl.float32)
# For dW in Llama mode, we need X_normalized in original dtype
X_normalized_for_dW = X_normalized.to(X_dtype)
elif casting_mode == _CASTING_MODE_GEMMA:
m_rows = dY_rows.to(tl.float32) * W_offset[None, :]
X_normalized_for_dW = X_normalized
else:
m_rows = dY_rows * W_offset[None, :]
X_normalized_for_dW = X_normalized
# Compute sum(m * X) for correction factor
sum_m_X = tl.sum(tl.where(block_mask, m_rows * X_rows, 0.0), axis=1)
# Compute correction factor
correction_factors = -(1.0 / n_cols) * rstd_rows * rstd_rows * sum_m_X
# Compute dX = rstd * m + rstd * correction_factor * X
dX_rows = rstd_rows[:, None] * m_rows + rstd_rows[:, None] * correction_factors[:, None] * X_rows
# Add dS_out gradient if present
if has_dS_out:
dS_out_rows = tl.load(
dS_out_ptr + row_idx[:, None] * dS_out_row_stride + col_offsets[None, :], mask=block_mask, other=0.0
)
dX_rows += dS_out_rows
# Store dX
tl.store(dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :], dX_rows.to(X_dtype), mask=block_mask)
# Compute dW contribution: dY * X_normalized
dW_rows = (dY_rows * X_normalized_for_dW).to(tl.float32)
# Accumulate to per-program dW buffer
dW_row_ptr = dW_ptr + pid * dW_row_stride
existing_dW = tl.load(dW_row_ptr + col_offsets, mask=col_mask, other=0.0)
new_dW = existing_dW + tl.sum(tl.where(block_mask, dW_rows, 0.0), axis=0)
tl.store(dW_row_ptr + col_offsets, new_dW, mask=col_mask)
# -----------------------------------------------------------------------------
# Backward Kernel - With Tiling (for n_cols > 2048)
# -----------------------------------------------------------------------------
@triton.jit
def _fused_add_rms_norm_backward_kernel_npu(
dY_ptr,
dY_row_stride,
dS_out_ptr,
dS_out_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
casting_mode: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
has_dS_out: tl.constexpr,
):
"""
NPU-optimized fused_add_rms_norm backward kernel.
Each program processes multiple rows using grid-stride loop.
For each row, we process columns in blocks to avoid UB overflow.
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# dW partial sums are accumulated per program into rows of the dW scratch buffer (reduced on the host)
num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
offsets = tl.arange(0, BLOCK_SIZE)
# Grid-stride loop over rows
for row_idx in tl.range(pid, n_rows, num_progs):
# Base pointers for this row
dY_row_ptr = dY_ptr + row_idx * dY_row_stride
dX_row_ptr = dX_ptr + row_idx * dX_row_stride
X_row_ptr = X_ptr + row_idx * X_row_stride
RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
# Load rstd for this row
rstd = tl.load(RSTD_row_ptr)
# First pass: compute sum(m * X) for the correction term
sum_m_X = 0.0
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0)
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0)
W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
# Convert to fp32 for computation
X_block = X_block.to(tl.float32)
# Compute m based on casting mode
W_offset = W_block + offset
if casting_mode == _CASTING_MODE_LLAMA:
m = (dY_block * W_offset).to(tl.float32)
elif casting_mode == _CASTING_MODE_GEMMA:
dY_block = dY_block.to(tl.float32)
m = dY_block * W_offset
else:
m = dY_block * W_offset
# Accumulate sum(m * X)
sum_m_X += tl.sum(tl.where(mask, m * X_block, 0.0))
# Compute the correction factor
correction_factor = -(1.0 / n_cols) * rstd * rstd * sum_m_X
# Second pass: compute gradients
for col_block_idx in range(num_col_blocks):
col_start = col_block_idx * BLOCK_SIZE
col_offsets = col_start + offsets
mask = col_offsets < n_cols
dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0)
X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0)
W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
X_block = X_block.to(tl.float32)
# Compute m based on casting mode
W_offset = W_block + offset
if casting_mode == _CASTING_MODE_LLAMA:
m = (dY_block * W_offset).to(tl.float32)
elif casting_mode == _CASTING_MODE_GEMMA:
dY_block = dY_block.to(tl.float32)
m = dY_block * W_offset
else:
m = dY_block * W_offset
# Compute dX
dX_block = rstd * m + rstd * correction_factor * X_block
# Add dS_out gradient if present
if has_dS_out:
dS_out_row_ptr = dS_out_ptr + row_idx * dS_out_row_stride
dS_out_block = tl.load(dS_out_row_ptr + col_offsets, mask=mask, other=0.0)
dX_block += dS_out_block
# Store dX
tl.store(dX_row_ptr + col_offsets, dX_block.to(X_dtype), mask=mask)
# Compute dW contribution (accumulate per program)
if casting_mode == _CASTING_MODE_LLAMA:
dW_block = dY_block * (X_block * rstd).to(X_dtype)
else:
dW_block = dY_block * (X_block * rstd)
# Accumulate into this program's private dW row; no atomics are needed
# because each program owns its own scratch row
dW_row_ptr = dW_ptr + pid * dW_row_stride
# Load existing dW, add contribution, store back
existing_dW = tl.load(dW_row_ptr + col_offsets, mask=mask, other=0.0)
new_dW = existing_dW + dW_block.to(tl.float32)
tl.store(dW_row_ptr + col_offsets, new_dW, mask=mask)
# -----------------------------------------------------------------------------
# Helper Functions
# -----------------------------------------------------------------------------
def get_optimal_block_size(n_cols, is_forward: bool):
"""
Calculate the optimal block size for the forward or backward pass using compute_default_tiling_strategy.
Memory analysis for forward pass (per row):
- Load: X_block, R_block, W_block (3 blocks)
- Store: S_block, Y_block (2 blocks)
- Compute: S_block, Y_block intermediate (2 blocks)
- Total: conservative estimate 8 blocks of memory
Memory analysis for backward pass (per row):
- Load: dY_block, X_block, W_block, existing_dW (4 blocks)
- Store: dX_block, new_dW (2 blocks)
- Compute: m, dX_block intermediate, dW_block intermediate (3 blocks)
- Additional: dS_out_block if present (1 block)
- Total: conservative estimate 12 blocks of memory
Args:
n_cols: Number of columns in the tensor
is_forward: True for the forward pass, False for the backward pass (which needs more memory per block)
Returns:
Optimal block size
"""
if n_cols <= 2048:
return triton.next_power_of_2(n_cols)
memory_multiplier = 8.0 if is_forward else 12.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=4,
memory_multiplier=memory_multiplier,
shapes=((n_cols,),),
tiling_dims=(0,),
)
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return max(2048, block_size)
else:
return 2048
# -----------------------------------------------------------------------------
# Forward and Backward Functions
# -----------------------------------------------------------------------------
_str_to_casting_mode = {
"llama": _CASTING_MODE_LLAMA.value,
"gemma": _CASTING_MODE_GEMMA.value,
"none": _CASTING_MODE_NONE.value,
}
def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode):
if not isinstance(casting_mode, int):
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
casting_mode = _str_to_casting_mode[casting_mode]
else:
assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
R = R.view(-1, dim)
n_rows, n_cols = X.shape
X_DTYPE = torch_dtype_to_triton(X.dtype)
# Get optimal block size for column processing
BLOCK_SIZE = get_optimal_block_size(n_cols, True)
BLOCK_SIZE_M = 2048 // BLOCK_SIZE
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
S = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
# RSTD is always fp32 for Llama/Gemma modes
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
# Check constraints
assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension"
# Grid size capped at twice the NPU core count
num_cores = get_npu_core_count()
grid_size = min(num_cores * 2, n_rows)
# Choose kernel based on n_cols
if n_cols <= 2048:
# Use no-tiling kernel for small n_cols
_fused_add_rms_norm_forward_kernel_no_tiling[(grid_size,)](
Y,
Y.stride(0),
S,
S.stride(0),
X,
X.stride(0),
R,
R.stride(0),
W,
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
offset,
casting_mode,
X_DTYPE,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE,
)
else:
# Use tiled kernel for large n_cols
_fused_add_rms_norm_forward_kernel_npu[(grid_size,)](
Y,
Y.stride(0),
S,
S.stride(0),
X,
X.stride(0),
R,
R.stride(0),
W,
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
offset,
casting_mode,
X_DTYPE,
BLOCK_SIZE=BLOCK_SIZE,
)
return Y.view(*shape), S.view(*shape), RSTD, casting_mode
def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, in_place):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
if dS_out is not None:
dS_out = dS_out.view(-1, dim)
S = S.view(-1, dim)
n_rows, n_cols = dY.shape
# Get NPU core count for grid size
num_cores = get_npu_core_count()
grid_size = min(num_cores * 2, n_rows)
# Get optimal block size for backward pass
BLOCK_SIZE = get_optimal_block_size(n_cols, False)
BLOCK_SIZE_M = 2048 // BLOCK_SIZE
# fp32 for numerical stability
_dW = torch.zeros((grid_size, n_cols), dtype=torch.float32, device=W.device)
if in_place:
dX = dY
else:
dX = torch.empty_like(dY)
# Choose kernel based on n_cols
if n_cols <= 2048:
# Use no-tiling kernel for small n_cols
_fused_add_rms_norm_backward_kernel_no_tiling[(grid_size,)](
dY,
dY.stride(0),
dS_out if dS_out is not None else dY, # Dummy pointer if dS_out is None
dS_out.stride(0) if dS_out is not None else 0,
dX,
dX.stride(0),
S,
S.stride(0),
torch_to_triton_dtype[S.dtype],
W,
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
offset,
casting_mode,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE,
has_dS_out=dS_out is not None,
)
else:
# Use tiled kernel for large n_cols
_fused_add_rms_norm_backward_kernel_npu[(grid_size,)](
dY,
dY.stride(0),
dS_out if dS_out is not None else dY, # Dummy pointer if dS_out is None
dS_out.stride(0) if dS_out is not None else 0,
dX,
dX.stride(0),
S,
S.stride(0),
torch_to_triton_dtype[S.dtype],
W,
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
offset,
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
has_dS_out=dS_out is not None,
)
dX = dX.view(*shape)
dW = _dW.sum(dim=0).to(W.dtype)
return dX, dX, dW # dR is equal to dX
# -----------------------------------------------------------------------------
# Autograd Function
# -----------------------------------------------------------------------------
class LigerFusedAddRMSNormFunction(torch.autograd.Function):
"""
NPU-optimized fused operation for residual addition and RMSNorm.
This implementation solves two key issues:
1. UB overflow when n_cols is too large (by using column-wise blocking)
2. Efficient multi-row processing (by using grid-stride loop with core count limit)
"""
@staticmethod
@ensure_contiguous
def forward(ctx, X, R, W, eps, offset=0.0, casting_mode="llama", in_place=False):
"""
X: (B, T, H) or (BxT, H)
R: (B, T, H) or (BxT, H)
W: (H,)
"""
Y, S, RSTD, casting_mode = fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode)
ctx.offset = offset
ctx.casting_mode = casting_mode
ctx.in_place = in_place
ctx.save_for_backward(S, W, RSTD)
return Y, S
@staticmethod
@ensure_contiguous
def backward(ctx, dY, dS_out):
"""
dY: (B, T, H) or (BxT, H)
dS_out: (B, T, H) or (BxT, H)
"""
S, W, RSTD = ctx.saved_tensors
dX, dR, dW = fused_add_rms_norm_backward(
dY,
dS_out,
S,
W,
RSTD,
ctx.offset,
ctx.casting_mode,
ctx.in_place,
)
return dX, dR, dW, None, None, None, None
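# Dense reference sketch (illustrative only, not part of the kernel API): what
# the fused kernels above compute under the "llama" casting mode — S = X + R,
# then RMSNorm with fp32 statistics, casting the normalized value back to the
# input dtype before the weight multiply. Handy as a small-shape correctness oracle.
def _fused_add_rms_norm_reference(X, R, W, eps, offset=0.0):
    S = X + R
    S_fp32 = S.to(torch.float32)
    # rstd = 1 / sqrt(mean(S^2) + eps), computed in fp32
    rstd = torch.rsqrt(S_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Llama-style casting: normalize in fp32, cast back, then apply the weight
    Y = (S_fp32 * rstd).to(X.dtype) * (offset + W)
    return Y, S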
from typing import Optional
import torch
import triton
from liger_kernel.ops.backends._ascend.ops.jsd import _jsd_kernel
from liger_kernel.ops.utils import amp_custom_bwd
from liger_kernel.ops.utils import amp_custom_fwd
from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import get_npu_core_count
MAX_FUSED_SIZE = 4096
def fused_linear_jsd_forward(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
has_label,
temperature,
):
device = student_input.device
dtype = student_input.dtype
# inputs have shape: BT x H
# materialized activations will have shape: BT x V
# the increase in memory = BT x V
# reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
# for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
# inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
# for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
BT, H = student_input.shape
V = student_weight.shape[0]
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # next power of 2 of ceil(BT / inc_factor)
num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
grad_input = torch.zeros_like(student_input)
# we use fp32 for loss accumulator
loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
if has_label:
n_non_ignore = (shift_labels != ignore_index).sum().item()
else:
n_non_ignore = BT
num_cores = get_npu_core_count()
for chunk_id in range(num_chunks):
start_idx = chunk_id * chunk_size
end_idx = min((chunk_id + 1) * chunk_size, BT)
# chunk both inputs, shape: chunk_size x H
student_input_chunk = student_input[start_idx:end_idx]
teacher_input_chunk = teacher_input[start_idx:end_idx]
# shape: chunk_size x V
# For anything starting from logits to the final JSD loss, we do computation
# in FP32 to avoid losing numerical stability.
student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
chunk_n_rows = student_logits_chunk.shape[0]
# unreduced loss
loss_1d_slice = loss_1d[start_idx:end_idx]  # shape: chunk_size x V
# log-softmax with temperature
student_logits_chunk = student_logits_chunk / temperature
teacher_logits_chunk = teacher_logits_chunk / temperature
student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1)
teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1)
# ensure _input and target are contiguous
student_prob_chunk = student_prob_chunk.contiguous()
teacher_prob_chunk = teacher_prob_chunk.contiguous()
# Here we calculate the gradient of prob_chunk in place so we can save memory.
# Grid size is capped at NPU core count; the kernel uses a grid-stride loop
# to process multiple rows per program, consistent with the NPU backend pattern.
grid_size = min(num_cores, chunk_n_rows)
_jsd_kernel[(grid_size,)](
X_ptr=student_prob_chunk,
X_stride=student_prob_chunk.stride(-2),
Y_ptr=teacher_prob_chunk,
Y_stride=teacher_prob_chunk.stride(-2),
loss_ptr=loss_1d_slice,
loss_stride=loss_1d_slice.stride(-2),
dX_ptr=student_prob_chunk,
dX_stride=student_prob_chunk.stride(-2),
label_ptr=(
shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
), # dummy ptr if no label
beta=jsd_beta,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
n_rows=chunk_n_rows,
n_cols=V,
BLOCK_SIZE=BLOCK_SIZE,
HAS_LABEL=has_label,
)
# convert the in-place gradient of prob_chunk into the gradient of
# logits_chunk (chain rule through log_softmax), shape: chunk_size x V
student_logits_chunk = (
student_prob_chunk
- torch.softmax(student_logits_chunk, dim=-1)
* student_prob_chunk.sum(dim=-1, keepdim=True).expand_as(student_prob_chunk)
) / temperature
# now we traverse back to grad w.r.t. input to `lm_head` and grad
# w.r.t. `lm_head` which should be computed in original dtype
student_logits_chunk = student_logits_chunk.to(dtype)
grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight
if grad_weight is not None:
grad_weight.add_(student_logits_chunk.t() @ student_input_chunk)
loss = torch.sum(loss_1d)
return loss, grad_input, grad_weight
def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
# If JSD is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
# We use a Triton kernel instead of a PyTorch op because scaling the stored gradients
# in place (and possibly running backward more than once) trips PyTorch's autograd
# in-place checks, while the Triton kernel does not.
BT, H = grad_input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
element_mul_kernel[(n_rows,)](
grad_input,
grad_input.stride(-2),
grad_output,
H,
BLOCK_SIZE=BLOCK_SIZE,
)
# handle grad_weight
if grad_weight is not None:
V, H = grad_weight.shape
n_rows = V
element_mul_kernel[(n_rows,)](
grad_weight,
grad_weight.stride(-2),
grad_output,
H,
BLOCK_SIZE=BLOCK_SIZE,
)
return grad_input, grad_weight
class LigerFusedLinearJSDFunction(torch.autograd.Function):
"""
Fusing the last linear layer with generalized JSD
Handle the forward and backward pass of the final linear layer via JSD by avoiding
the materialization of the large logits tensor. Since JSD is the last layer, we can
compute the gradient at the forward pass.
"""
@staticmethod
@amp_custom_fwd
def forward(
ctx,
student_input: torch.Tensor,
student_weight: torch.Tensor,
teacher_input: torch.Tensor,
teacher_weight: torch.Tensor,
shift_labels: Optional[torch.Tensor] = None,
jsd_beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
"""
Args:
student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size
teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
ignore_index (int): the index to ignore. Default: -100
temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
Returns:
loss (torch.Tensor): generalized JSD
"""
has_label = False
if shift_labels is not None:
assert shift_labels.shape == (teacher_input.shape[0],), (
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
)
shift_labels = shift_labels.contiguous()
has_label = True
loss, grad_input, grad_weight = fused_linear_jsd_forward(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
has_label,
temperature,
)
# store the pre-computed gradients for the backward pass
ctx.save_for_backward(
grad_input.detach(),
grad_weight.detach() if grad_weight is not None else None,
)
return loss
@staticmethod
@amp_custom_bwd
def backward(ctx, grad_output):
(grad_input, grad_weight) = ctx.saved_tensors
grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
return (grad_input, grad_weight, None, None, None, None, None, None)
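# Dense reference sketch (illustrative; assumes the common mixture convention
# M = beta * P + (1 - beta) * Q for 0 < beta < 1, with the documented pure-KL
# behavior at beta = 0 and beta = 1). Useful as a small-shape oracle for the
# chunked forward above; names here are illustrative, not part of the API.
def _generalized_jsd_reference(student_logits, teacher_logits, beta=0.5, temperature=1.0):
    log_q = torch.log_softmax(student_logits.float() / temperature, dim=-1)  # student
    log_p = torch.log_softmax(teacher_logits.float() / temperature, dim=-1)  # teacher
    if beta == 0.0:
        loss = torch.exp(log_p) * (log_p - log_q)  # forward KL(P || Q)
    elif beta == 1.0:
        loss = torch.exp(log_q) * (log_q - log_p)  # reverse KL(Q || P)
    else:
        m = beta * torch.exp(log_p) + (1 - beta) * torch.exp(log_q)
        loss = beta * torch.exp(log_p) * (log_p - m.log()) + (1 - beta) * torch.exp(log_q) * (log_q - m.log())
    return loss.sum()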
import torch
import triton
import triton.language as tl
from triton.language.math import tanh
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
@triton.jit
def _geglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr):
"""
High-performance GEGLU forward kernel using flatten 1D approach.
Uses grid-stride loop pattern for optimal performance on NPU.
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
# Grid-Stride Loop
start_idx = pid * BLOCK_SIZE
stride = num_progs * BLOCK_SIZE
# Constants for GELU tanh approximation
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
gelu_coeff = 0.044715
for idx in tl.range(start_idx, total_elements, stride):
offsets = idx + tl.arange(0, BLOCK_SIZE)
mask = offsets < total_elements
a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0)
# tanh approximation form of GELU is computed with:
# 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
a_cubed = a_val * a_val * a_val
tanh_arg = sqrt_2_over_pi * (a_val + gelu_coeff * a_cubed)
tanh_result = tanh(tanh_arg)
geglu_a = 0.5 * a_val * (1.0 + tanh_result)
c_row = geglu_a.cast(b_val.dtype) * b_val
tl.store(c_ptr + offsets, c_row, mask=mask)
@triton.jit
def _geglu_backward_kernel_flat(dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr):
"""
High-performance GEGLU backward kernel using flatten 1D approach.
Uses grid-stride loop pattern for optimal performance on NPU.
"""
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
start_idx = pid * BLOCK_SIZE
stride = num_progs * BLOCK_SIZE
# Constants for GELU tanh approximation
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
gelu_coeff = 0.044715
for idx in tl.range(start_idx, total_elements, stride):
offsets = idx + tl.arange(0, BLOCK_SIZE)
mask = offsets < total_elements
dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0)
a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
# recomputation to save memory
a_cubed = a * a * a
tanh_arg = sqrt_2_over_pi * (a + gelu_coeff * a_cubed)
tanh_result = tanh(tanh_arg)
geglu_a = 0.5 * a * (1 + tanh_result)
# round-trip through the upstream-grad dtype so the recomputed activation
# matches the rounding used in the forward pass
geglu_a = geglu_a.to(dc.dtype).to(tl.float32)
db = dc.cast(tl.float32) * geglu_a
# Gradient w.r.t. a can be computed with:
# b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
# where z = sqrt(2/pi) * (a + 0.044715 * a^3)
term1 = 0.5 * (1.0 + tanh_result)
tanh_sq = tanh_result * tanh_result
a_sq = a * a
term2 = 0.5 * a * (1.0 - tanh_sq) * (sqrt_2_over_pi * (1.0 + 3.0 * gelu_coeff * a_sq))
da = dc * b * (term1 + term2)
tl.store(da_ptr + offsets, da, mask=mask)
tl.store(db_ptr + offsets, db.to(dc.dtype), mask=mask)
def get_optimal_block_size(total_elements, is_backward=False):
"""
Calculate optimal Block Size using compute_default_tiling_strategy.
Args:
total_elements: Total number of elements to process
is_backward: Whether this is for backward pass (requires more memory)
Returns:
Optimal block size for the kernel
"""
# Memory multiplier based on peak memory usage analysis
if is_backward:
memory_multiplier = 6.0
else:
memory_multiplier = 3.0
# Call calculation function
# Treat input as 1D (total_elements,), only tiling on dim 0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=4,
memory_multiplier=memory_multiplier,
shapes=((total_elements,),),
tiling_dims=(0,),
)
# Parse result
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return max(256, block_size)
else:
return 2048
def geglu_forward(a, b):
"""
High-performance GEGLU forward pass for NPU using flatten 1D approach.
"""
if not a.is_contiguous():
a = a.contiguous()
if not b.is_contiguous():
b = b.contiguous()
total_elements = a.numel()
c = torch.empty_like(a)
block_size = get_optimal_block_size(total_elements, is_backward=False)
num_cores = get_npu_core_count()
grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
_geglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size)
return c
def geglu_backward(a, b, dc):
"""
High-performance GEGLU backward pass for NPU using flatten 1D approach.
"""
if not dc.is_contiguous():
dc = dc.contiguous()
if not a.is_contiguous():
a = a.contiguous()
if not b.is_contiguous():
b = b.contiguous()
total_elements = dc.numel()
grad_a = torch.empty_like(a)
grad_b = torch.empty_like(b)
block_size = get_optimal_block_size(total_elements, is_backward=True)
num_cores = get_npu_core_count()
grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
_geglu_backward_kernel_flat[(grid_size,)](dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size)
return grad_a, grad_b
class LigerGELUMulFunction(torch.autograd.Function):
"""High-performance GEGLU function for Ascend NPU."""
@staticmethod
@ensure_contiguous
def forward(ctx, a, b):
c = geglu_forward(a, b)
ctx.save_for_backward(a, b)
return c
@staticmethod
@ensure_contiguous
def backward(ctx, dc):
a, b = ctx.saved_tensors
grad_a, grad_b = geglu_backward(a, b, dc)
return grad_a, grad_b
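# Dense reference sketch (illustrative only): the kernels above compute
# c = GELU_tanh(a) * b. torch's tanh-approximate GELU implements the same
# 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) formula, so it makes
# a convenient small-shape correctness oracle for the flat kernels.
def _geglu_reference(a, b):
    gelu_a = torch.nn.functional.gelu(a.to(torch.float32), approximate="tanh")
    # mirror the forward kernel: gelu in fp32, cast to b's dtype, then multiply
    return gelu_a.to(b.dtype) * b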
import torch
import triton
import triton.language as tl
from triton.language.math import rsqrt
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
# -----------------------------------------------------------------------------
# Kernels (2D row/col tiling + persistent programs)
# -----------------------------------------------------------------------------
@triton.jit
def _group_norm_forward_kernel(
Y_ptr, # pointer to output, shape (B, G, hidden_size)
Y_row_stride, # stride of each batch row in Y
Y_col_stride, # stride of each group row in Y
X_ptr, # pointer to input, shape (B, G, hidden_size)
X_row_stride, # stride of each batch row in X
X_col_stride, # stride of each group row in X
Mean_ptr, # pointer to mean output, shape (B, G)
Mean_row_stride, # stride of each batch row in Mean
Mean_col_stride, # stride of each group row in Mean
RSTD_ptr, # pointer to rstd output, shape (B, G)
RSTD_row_stride, # stride of each batch row in RSTD
RSTD_col_stride, # stride of each group row in RSTD
W_ptr, # pointer to affine scale weights, shape (C)
B_ptr, # pointer to affine bias weights, shape (C)
n_rows, # total logical rows = B * G
hidden_size,
channels_per_group,
num_groups,
SINGLE_CHANNEL_TILE: tl.constexpr,
eps,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
grid_m = tl.cdiv(n_rows, BLOCK_SIZE_M)
num_col_blocks = tl.cdiv(hidden_size, BLOCK_SIZE_N)
hidden_size_per_channel = hidden_size // channels_per_group
hidden_size_inv = 1.0 / hidden_size
row_offsets = tl.arange(0, BLOCK_SIZE_M)
col_offsets_base = tl.arange(0, BLOCK_SIZE_N)
# Persistent-program loop over row tiles.
for block_m in tl.range(pid, grid_m, num_progs):
row_idx = block_m * BLOCK_SIZE_M + row_offsets
row_mask = row_idx < n_rows
batch_idx = row_idx // num_groups
group_idx = row_idx % num_groups
row_sum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
row_square_sum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
# Pass 1: accumulate E[x] and E[x^2] for each row tile.
for cb in range(num_col_blocks):
col_offsets = cb * BLOCK_SIZE_N + col_offsets_base
col_mask = col_offsets < hidden_size
mask = row_mask[:, None] & col_mask[None, :]
X_ptrs = (
X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :]
)
X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32)
row_sum += tl.sum(X_block, axis=1)
row_square_sum += tl.sum(X_block * X_block, axis=1)
mean = row_sum * hidden_size_inv
var = row_square_sum * hidden_size_inv - mean * mean
rstd = rsqrt(tl.maximum(var, 0.0) + eps)
mean_ptrs = Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride
rstd_ptrs = RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride
tl.store(mean_ptrs, mean, mask=row_mask)
tl.store(rstd_ptrs, rstd, mask=row_mask)
# Pass 2: normalize + affine transform.
# SINGLE_CHANNEL_TILE indicates the current column tile maps to one channel,
# so W/B can be loaded once per row and broadcast to the tile.
for cb in range(num_col_blocks):
col_offsets = cb * BLOCK_SIZE_N + col_offsets_base
col_mask = col_offsets < hidden_size
mask = row_mask[:, None] & col_mask[None, :]
X_ptrs = (
X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :]
)
X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32)
if SINGLE_CHANNEL_TILE:
local_channel = (cb * BLOCK_SIZE_N) // hidden_size_per_channel
global_channel = group_idx * channels_per_group + local_channel
W_block = tl.load(W_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None]
B_block = tl.load(B_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None]
else:
local_channel = col_offsets // hidden_size_per_channel
global_channel = group_idx[:, None] * channels_per_group + local_channel[None, :]
W_block = tl.load(W_ptr + global_channel, mask=mask, other=0.0).to(tl.float32)
B_block = tl.load(B_ptr + global_channel, mask=mask, other=0.0).to(tl.float32)
Y_block = (X_block - mean[:, None]) * rstd[:, None] * W_block + B_block
Y_ptrs = (
Y_ptr + batch_idx[:, None] * Y_row_stride + group_idx[:, None] * Y_col_stride + col_offsets[None, :]
)
tl.store(Y_ptrs, Y_block, mask=mask)
@triton.jit
def _group_norm_backward_kernel(
X_ptr, # pointer to input, shape (B, G, hidden_size)
X_row_stride, # stride of each batch row in X
X_col_stride, # stride of each group row in X
W_ptr, # pointer to affine scale weights, shape (C)
Mean_ptr, # pointer to saved group mean, shape (B, G)
Mean_row_stride, # stride of each batch row in Mean
Mean_col_stride, # stride of each group row in Mean
RSTD_ptr, # pointer to saved reciprocal std, shape (B, G)
DX_ptr, # pointer to input gradients, shape (B, G, hidden_size)
DW_scratch_ptr, # pointer to scratch buffer for dW partial sums, shape (grid, C)
DW_scratch_stride, # row stride for DW_scratch
DB_scratch_ptr, # pointer to scratch buffer for dB partial sums, shape (grid, C)
DB_scratch_stride, # row stride for DB_scratch
DY_ptr, # pointer to upstream gradients, shape (B, G, hidden_size)
DY_row_stride, # stride of each batch row in DY
DY_col_stride, # stride of each group row in DY
n_rows, # total logical rows = B * G
hidden_size,
channels_per_group,
num_groups,
SINGLE_CHANNEL_TILE: tl.constexpr,
COMPUTE_PARAM_GRAD: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)
grid_m = tl.cdiv(n_rows, BLOCK_SIZE_M)
num_col_blocks = tl.cdiv(hidden_size, BLOCK_SIZE_N)
hidden_size_per_channel = hidden_size // channels_per_group
N_inv = 1.0 / hidden_size
row_offsets = tl.arange(0, BLOCK_SIZE_M)
col_offsets_base = tl.arange(0, BLOCK_SIZE_N)
if COMPUTE_PARAM_GRAD:
DW_scratch_base = DW_scratch_ptr + pid * DW_scratch_stride
DB_scratch_base = DB_scratch_ptr + pid * DB_scratch_stride
# Persistent-program loop over row tiles.
for block_m in tl.range(pid, grid_m, num_progs):
row_idx = block_m * BLOCK_SIZE_M + row_offsets
row_mask = row_idx < n_rows
batch_idx = row_idx // num_groups
group_idx = row_idx % num_groups
mean = tl.load(
Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride,
mask=row_mask,
other=0.0,
).to(tl.float32)
# RSTD shares Mean's (B, G) layout, so Mean's strides are reused here
rstd = tl.load(
RSTD_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride,
mask=row_mask,
other=0.0,
).to(tl.float32)
sum_x_hat_wdy = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
sum_wdy = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
# Pass 1: compute row-wise reduction terms (c1, c2).
for cb in range(num_col_blocks):
col_offsets = cb * BLOCK_SIZE_N + col_offsets_base
col_mask = col_offsets < hidden_size
mask = row_mask[:, None] & col_mask[None, :]
X_ptrs = (
X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :]
)
DY_ptrs = (
DY_ptr + batch_idx[:, None] * DY_row_stride + group_idx[:, None] * DY_col_stride + col_offsets[None, :]
)
X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32)
DY_block = tl.load(DY_ptrs, mask=mask, other=0.0).to(tl.float32)
if SINGLE_CHANNEL_TILE:
local_channel = (cb * BLOCK_SIZE_N) // hidden_size_per_channel
global_channel = group_idx * channels_per_group + local_channel
W_block = tl.load(W_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None]
else:
local_channel = col_offsets // hidden_size_per_channel
global_channel = group_idx[:, None] * channels_per_group + local_channel[None, :]
W_block = tl.load(W_ptr + global_channel, mask=mask, other=0.0).to(tl.float32)
x_hat = (X_block - mean[:, None]) * rstd[:, None]
wdy = W_block * DY_block
sum_x_hat_wdy += tl.sum(tl.where(mask, x_hat * wdy, 0.0), axis=1)
sum_wdy += tl.sum(tl.where(mask, wdy, 0.0), axis=1)
c1 = sum_x_hat_wdy * N_inv
c2 = sum_wdy * N_inv
# Pass 2: compute DX and optionally accumulate DW/DB.
# COMPUTE_PARAM_GRAD=False is used to skip expensive atomics in cases
# where host-side dense reduction is faster/more stable.
for cb in range(num_col_blocks):
col_offsets = cb * BLOCK_SIZE_N + col_offsets_base
col_mask = col_offsets < hidden_size
mask = row_mask[:, None] & col_mask[None, :]
X_ptrs = (
X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :]
)
DY_ptrs = (
DY_ptr + batch_idx[:, None] * DY_row_stride + group_idx[:, None] * DY_col_stride + col_offsets[None, :]
)
X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32)
DY_block = tl.load(DY_ptrs, mask=mask, other=0.0).to(tl.float32)
if SINGLE_CHANNEL_TILE:
local_channel = (cb * BLOCK_SIZE_N) // hidden_size_per_channel
global_channel = group_idx * channels_per_group + local_channel
W_block = tl.load(W_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None]
else:
local_channel = col_offsets // hidden_size_per_channel
global_channel = group_idx[:, None] * channels_per_group + local_channel[None, :]
W_block = tl.load(W_ptr + global_channel, mask=mask, other=0.0).to(tl.float32)
x_hat = (X_block - mean[:, None]) * rstd[:, None]
wdy = W_block * DY_block
DX_block = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd[:, None]
DX_ptrs = (
DX_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :]
)
tl.store(DX_ptrs, DX_block.to(X_ptr.dtype.element_ty), mask=mask)
if COMPUTE_PARAM_GRAD:
if SINGLE_CHANNEL_TILE:
dW_partial = tl.sum(tl.where(mask, DY_block * x_hat, 0.0), axis=1)
dB_partial = tl.sum(tl.where(mask, DY_block, 0.0), axis=1)
tl.atomic_add(DW_scratch_base + global_channel, dW_partial, mask=row_mask)
tl.atomic_add(DB_scratch_base + global_channel, dB_partial, mask=row_mask)
# -----------------------------------------------------------------------------
# Helper: call compute_default_tiling_strategy
# -----------------------------------------------------------------------------
def get_optimal_block_size(n_rows, dtype_size, BLOCK_SIZE_N, is_backward: bool = False):
# Backward keeps larger live-state than forward in this kernel.
multiplier = 7.0 if is_backward else 6.0
# Use fp32-size as conservative UB estimate for tiling.
dtype_size = max(dtype_size, 4)
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=dtype_size,
memory_multiplier=multiplier,
shapes=((n_rows, BLOCK_SIZE_N),),
tiling_dims=(0,),
)
if tile_shapes and len(tile_shapes) > 0:
return tile_shapes[0][0]
return triton.next_power_of_2(min(128, n_rows))
def group_norm_forward(X, num_channels, num_groups, W, B, eps):
shape = X.shape
batch_size = shape[0]
channels_per_group = num_channels // num_groups
# Reshape X so that the mean / std are computed across each group
X = X.view(batch_size, num_groups, -1).contiguous()
hidden_size = X.shape[-1]
hidden_size_per_channel = hidden_size // channels_per_group
n_rows = batch_size * num_groups
BLOCK_SIZE_N = min(128, triton.next_power_of_2(hidden_size))
BLOCK_SIZE_M = get_optimal_block_size(n_rows, X.element_size(), BLOCK_SIZE_N)
# Fast path condition: each column tile must lie entirely inside one channel
# segment of length `hidden_size_per_channel`.
#
# Layout of a row:
# | channel0 | channel1 | channel2 | ...
# |----Hc----|----Hc----|
# Hc = hidden_size_per_channel
#
# The kernel processes tiles of shape (BLOCK_SIZE_M, BLOCK_SIZE_N).
# Channel boundaries exist only along the column dimension, because
# each row corresponds to a different (batch, group).
#
# Therefore only BLOCK_SIZE_N matters for whether a tile crosses
# channel boundaries; BLOCK_SIZE_M does not affect channel mapping.
#
# If BLOCK_SIZE_N divides Hc and is <= Hc, each column tile belongs
# to exactly one channel. In that case W/B can be loaded once and
# broadcast across the tile (fast path).
#
# Otherwise a tile may span multiple channels, requiring per-element
# channel index computation and parameter loads (slow path).
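# Example (hypothetical sizes): hidden_size_per_channel = 512 with
# BLOCK_SIZE_N = 128 takes the fast path (128 <= 512 and 512 % 128 == 0);
# hidden_size_per_channel = 96 with BLOCK_SIZE_N = 128 would take the slow path.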
single_channel_tile = BLOCK_SIZE_N <= hidden_size_per_channel and hidden_size_per_channel % BLOCK_SIZE_N == 0
num_cores = get_npu_core_count()
grid = min(num_cores, triton.cdiv(n_rows, BLOCK_SIZE_M))
Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
Mean = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device)
RSTD = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device)
_group_norm_forward_kernel[(grid,)](
Y,
Y.stride(0),
Y.stride(1),
X,
X.stride(0),
X.stride(1),
Mean,
Mean.stride(0),
Mean.stride(1),
RSTD,
RSTD.stride(0),
RSTD.stride(1),
W,
B,
n_rows,
hidden_size,
channels_per_group,
num_groups,
SINGLE_CHANNEL_TILE=single_channel_tile,
eps=eps,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
return Y.view(*shape), X.view(*shape), Mean, RSTD
def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
shape = dY.shape
batch_size = shape[0]
channels_per_group = num_channels // num_groups
X_grouped = X.view(batch_size, num_groups, -1)
dY_grouped = dY.view(batch_size, num_groups, -1)
hidden_size = dY_grouped.shape[-1]
hidden_size_per_channel = hidden_size // channels_per_group
n_rows = batch_size * num_groups
BLOCK_SIZE_N = min(128, triton.next_power_of_2(hidden_size))
BLOCK_SIZE_M = get_optimal_block_size(
n_rows,
X.element_size(),
BLOCK_SIZE_N,
is_backward=True,
)
# Same condition as forward:
# if true, each BLOCK_SIZE_N tile maps cleanly to one channel segment.
single_channel_tile = BLOCK_SIZE_N <= hidden_size_per_channel and hidden_size_per_channel % BLOCK_SIZE_N == 0
num_cores = get_npu_core_count()
grid = min(num_cores, triton.cdiv(n_rows, BLOCK_SIZE_M))
# For non-single-channel tiles, per-element atomic updates are costly.
# In that case, kernel computes DX only and DW/DB are reduced on host side.
compute_param_grad = single_channel_tile
DX = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
if compute_param_grad:
DW_scratch = torch.zeros((grid, num_channels), dtype=torch.float32, device=W.device)
DB_scratch = torch.zeros((grid, num_channels), dtype=torch.float32, device=W.device)
else:
# Not used when COMPUTE_PARAM_GRAD=False.
# Intentionally set to None to enforce fail-fast behavior if accidentally accessed.
DW_scratch = None
DB_scratch = None
_group_norm_backward_kernel[(grid,)](
X_grouped,
X_grouped.stride(0),
X_grouped.stride(1),
W,
Mean,
Mean.stride(0),
Mean.stride(1),
RSTD,
DX,
DW_scratch,
0 if not compute_param_grad else DW_scratch.stride(0),
DB_scratch,
0 if not compute_param_grad else DB_scratch.stride(0),
dY_grouped,
dY_grouped.stride(0),
dY_grouped.stride(1),
n_rows,
hidden_size,
channels_per_group,
num_groups,
SINGLE_CHANNEL_TILE=single_channel_tile,
COMPUTE_PARAM_GRAD=compute_param_grad,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
# Precision note:
# - In-kernel atomic_add on floating-point values is order-dependent under parallel
# scheduling (non-associative summation), which can introduce run-to-run numerical
# differences in DW/DB for contention-heavy shapes.
# - Host-side dense reduction provides a more stable accumulation pattern for these
# difficult layouts.
if compute_param_grad:
DW = DW_scratch.sum(dim=0).to(W.dtype)
DB = DB_scratch.sum(dim=0).to(W.dtype)
else:
# Fallback path to avoid severe atomic contention when SINGLE_CHANNEL_TILE=False.
# Layout: [B, G, hidden_size] -> [B, G, C_per_G, hidden_per_channel]
X4 = X_grouped.reshape(batch_size, num_groups, channels_per_group, hidden_size_per_channel).to(torch.float32)
dY4 = dY_grouped.reshape(batch_size, num_groups, channels_per_group, hidden_size_per_channel).to(torch.float32)
mean4 = Mean.reshape(batch_size, num_groups, 1, 1).to(torch.float32)
rstd4 = RSTD.reshape(batch_size, num_groups, 1, 1).to(torch.float32)
x_hat4 = (X4 - mean4) * rstd4
DW = (dY4 * x_hat4).sum(dim=(0, 3)).reshape(-1).to(W.dtype)
DB = dY4.sum(dim=(0, 3)).reshape(-1).to(W.dtype)
return DX.view(*shape), DW, DB
class LigerGroupNormFunction(torch.autograd.Function):
"""
Group Normalization autograd function for Ascend NPU.
Forward computes, for each sample/group:
y = (x - mean) * rstd * weight + bias
where:
mean = E[x], rstd = 1 / sqrt(Var[x] + eps)
The kernel uses row/column tiling with persistent programs. Backward computes
input gradients in Triton and computes parameter gradients either via Triton
atomics (fast path) or host-side dense reduction (fallback path), depending
on the tile/channel layout.
"""
@staticmethod
@ensure_contiguous
def forward(
ctx,
X,
affine_scaling_weight,
affine_shifting_bias,
num_channels,
num_groups,
eps,
):
Y, X, Mean, RSTD = group_norm_forward(
X,
num_channels,
num_groups,
affine_scaling_weight,
affine_shifting_bias,
eps,
)
ctx.num_channels = num_channels
ctx.num_groups = num_groups
ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, W, B, Mean, RSTD = ctx.saved_tensors
DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
return DX, DW, DB, None, None, None
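# Dense reference sketch (illustrative only, not part of the kernel API): the
# same normalization the kernels above compute, in plain PyTorch. Statistics
# are taken per (batch, group); the affine transform is per channel.
def _group_norm_reference(X, W, B, num_groups, eps):
    shape = X.shape
    Xg = X.view(shape[0], num_groups, -1).to(torch.float32)
    mean = Xg.mean(dim=-1, keepdim=True)
    var = Xg.var(dim=-1, unbiased=False, keepdim=True)
    Y = ((Xg - mean) * torch.rsqrt(var + eps)).view(*shape)
    # broadcast per-channel weight/bias over the trailing spatial dims
    param_shape = (1, -1) + (1,) * (len(shape) - 2)
    return (Y * W.view(param_shape) + B.view(param_shape)).to(X.dtype)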
import torch
import triton
import triton.language as tl
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
# Loss type mapping for Triton constexpr branching
# GRPO/DAPO/BNPO/DR_GRPO share identical per-token loss computation (standard PPO clipping)
_TYPE_GRPO: tl.constexpr = tl.constexpr(0)
_TYPE_CISPO: tl.constexpr = tl.constexpr(1)
_TYPE_SAPO: tl.constexpr = tl.constexpr(2)
_str_to_loss_type = {
"grpo": _TYPE_GRPO.value,
"dapo": _TYPE_GRPO.value,
"bnpo": _TYPE_GRPO.value,
"dr_grpo": _TYPE_GRPO.value,
"luspo": _TYPE_GRPO.value,
"cispo": _TYPE_CISPO.value,
"sapo": _TYPE_SAPO.value,
}
def calculate_tile_count_2d(batch_size, seq_len, num_cores):
"""Compute optimal grid configuration for parallel processing."""
grid_batch = batch_size
cores_per_sample = min(seq_len, num_cores // batch_size)
cores_per_sample = max(1, cores_per_sample)
grid_seq = cores_per_sample
total = grid_batch * grid_seq
if total > num_cores:
grid_seq = max(1, num_cores // grid_batch)
return (grid_batch, grid_seq)
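# Worked example for calculate_tile_count_2d (hypothetical sizes): batch_size=4,
# seq_len=1024, num_cores=40 gives cores_per_sample = min(1024, 40 // 4) = 10,
# i.e. a (4, 10) grid: one grid row per sample, and 10 programs striding over
# each sample's tokens.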
def compute_block_size_softmax(seq_vocab_size):
"""Determine optimal block size for selective log-softmax kernel."""
multiplier = 6.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((seq_vocab_size,),), tiling_dims=(0,)
)
if tile_shapes and len(tile_shapes) > 0:
return tile_shapes[0][0]
return 2048
def compute_block_size_forward(seq_vocab_size):
"""Determine optimal block size for forward pass kernel."""
multiplier = 10.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((seq_vocab_size,),), tiling_dims=(0,)
)
if tile_shapes and len(tile_shapes) > 0:
return tile_shapes[0][0]
return 2048
def compute_block_size_backward(seq_vocab_size):
"""Determine optimal block size for backward pass kernel."""
multiplier = 12.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((seq_vocab_size,),), tiling_dims=(0,)
)
if tile_shapes and len(tile_shapes) > 0:
return tile_shapes[0][0]
return 2048
@triton.jit
def _selective_log_softmax_kernel(
LOGITS,
INPUT_IDS,
LOG_P,
MASK,
TEMPERATURE,
stride_input_ids_b,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
num_progs_l = tl.num_programs(1)
batch_start = pid_b * L
batch_end = batch_start + L
start_token = batch_start + pid_l
stride = num_progs_l
for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
should_process = 1
if MASK is not None:
MASK_local = MASK + off_b * stride_input_ids_b + off_l
not_skip = tl.load(MASK_local)
should_process = not_skip
if should_process != 0:
# logits are laid out as (B, L+1, N); the trailing position (no next token) is skipped
LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N
INPUT_IDS_local = INPUT_IDS + off_b * stride_input_ids_b + off_l
LOG_P_local = LOG_P + token_idx
m_i = float("-inf")
l_i = 0.0
for start in range(0, N, BLOCK_N):
cols = start + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
new_m_i = tl.maximum(m_i, tl.max(logits))
alpha = tl.exp(m_i - new_m_i)
l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
m_i = new_m_i
lse = m_i + tl.log(l_i)
ids = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + ids).to(tl.float32) / TEMPERATURE
logp = x - lse
tl.store(LOG_P_local, logp)
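# Dense reference sketch (illustrative): per-token log-probability of the
# selected ids with temperature, matching the streaming logsumexp above. The
# kernel reads logits laid out as (B, L+1, N) and skips the trailing position;
# this reference assumes logits already sliced to the L scored tokens.
def _selective_log_softmax_reference(logits, input_ids, temperature=1.0):
    # logits: (B, L, V); input_ids: (B, L)
    logp = torch.log_softmax(logits.to(torch.float32) / temperature, dim=-1)
    return logp.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)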
@triton.jit
def _grpo_loss_fwd_kernel(
LOGITS,
OLD_LOGP,
REF_LOGP,
INPUT_IDS,
COMPLETION_MASK,
ADVANTAGES,
VLLM_IS_RATIO,
VLLM_IS_RATIO_STRIDE,
LOSS,
LSE,
KL,
IS_CLIPPED,
TEMPERATURE,
BETA: tl.constexpr,
EPS_LOW,
EPS_HIGH,
LOSS_TYPE: tl.constexpr,
SAPO_TEMP_POS,
SAPO_TEMP_NEG,
DELTA,
USE_BIAS_CORRECTION_KL: tl.constexpr,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
num_progs_l = tl.num_programs(1)
batch_start = pid_b * L
batch_end = batch_start + L
start_token = batch_start + pid_l
stride = num_progs_l
for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
should_process = 1
if COMPLETION_MASK is not None:
COMPLETION_MASK_local = COMPLETION_MASK + off_b * L + off_l
not_skip = tl.load(COMPLETION_MASK_local)
should_process = not_skip
if should_process != 0:
LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N
INPUT_IDS_local = INPUT_IDS + off_b * L + off_l
ADVANTAGES_local = ADVANTAGES + off_b
LOSS_local = LOSS + token_idx
LSE_local = LSE + token_idx
IS_CLIPPED_local = IS_CLIPPED + token_idx
m_i = float("-inf")
l_i = 0.0
for start in range(0, N, BLOCK_N):
cols = start + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
new_m_i = tl.maximum(m_i, tl.max(logits))
alpha = tl.exp(m_i - new_m_i)
l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
m_i = new_m_i
lse = m_i + tl.log(l_i)
idx = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE
logp = x - lse
if OLD_LOGP is None:
old_logp = logp
else:
OLD_LOGP_local = OLD_LOGP + token_idx
old_logp = tl.load(OLD_LOGP_local).to(tl.float32)
coef_1 = tl.exp(logp - old_logp)
advantage = tl.load(ADVANTAGES_local).to(tl.float32)
if LOSS_TYPE == 0: # GRPO/DAPO/BNPO/DR_GRPO
coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
is_clipped = is_low_clipped | is_high_clipped
if DELTA != 0.0:
coef_1 = tl.minimum(coef_1, DELTA)
per_token_loss1 = coef_1 * advantage
per_token_loss2 = coef_2 * advantage
per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
elif LOSS_TYPE == 1: # CISPO
coef_2 = tl.minimum(coef_1, EPS_HIGH)
per_token_loss = -coef_2 * advantage * logp
is_clipped = (coef_1 > EPS_HIGH) & (advantage > 0)
elif LOSS_TYPE == 2: # SAPO
temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG)
sigmoid_input = temperature * (coef_1 - 1.0)
sapo_coef = tl.sigmoid(sigmoid_input) * 4.0 / temperature
per_token_loss = -sapo_coef * advantage
is_clipped = 0.0
if VLLM_IS_RATIO is not None:
# stride trick: for a (B, 1) ratio the stride is 1 and off_l % 1 == 0
# (one ratio per sequence); for a (B, L) ratio the stride is L and
# off_l % L == off_l (one ratio per token)
vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
tl.float32
)
per_token_loss = per_token_loss * vllm_is_ratio
if BETA != 0.0:
REF_LOGP_local = REF_LOGP + token_idx
KL_local = KL + token_idx
ref_logp = tl.load(REF_LOGP_local).to(tl.float32)
kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
if USE_BIAS_CORRECTION_KL:
kl = kl * tl.exp(logp - old_logp)
per_token_loss += BETA * kl
tl.store(KL_local, kl)
tl.store(LOSS_local, per_token_loss)
tl.store(LSE_local, lse)
tl.store(IS_CLIPPED_local, is_clipped)
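# Dense per-token reference sketch (illustrative) for the three LOSS_TYPE
# branches above. `logp` and `old_logp` are per-token log-probs and `advantage`
# is already broadcast to per-token shape; names and defaults here are
# illustrative, not part of the kernel API.
def _per_token_policy_loss_reference(
    logp,
    old_logp,
    advantage,
    loss_type,
    eps_low=0.2,
    eps_high=0.2,
    delta=0.0,
    sapo_temp_pos=1.0,
    sapo_temp_neg=1.05,
):
    coef_1 = torch.exp(logp - old_logp)
    if loss_type == "grpo":  # shared by grpo/dapo/bnpo/dr_grpo
        coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
        c1 = torch.clamp(coef_1, max=delta) if delta != 0.0 else coef_1  # INTELLECT-2 upper clamp
        return -torch.minimum(c1 * advantage, coef_2 * advantage)
    if loss_type == "cispo":  # clipped, detached importance weight times logp
        return -torch.clamp(coef_1, max=eps_high).detach() * advantage * logp
    if loss_type == "sapo":  # smooth sigmoid gate instead of hard clipping
        t = torch.where(
            advantage > 0,
            torch.full_like(advantage, sapo_temp_pos),
            torch.full_like(advantage, sapo_temp_neg),
        )
        return -torch.sigmoid(t * (coef_1 - 1.0)) * 4.0 / t * advantage
    raise ValueError(f"unknown loss_type: {loss_type}")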
@triton.jit
def _grpo_loss_fwd_kernel_seq(
LOGITS,
OLD_LOGP,
REF_LOGP,
INPUT_IDS,
COMPLETION_MASK,
ADVANTAGES,
COEF_1,
COEF_2,
IS_CLIPPED_SEQ,
VLLM_IS_RATIO,
VLLM_IS_RATIO_STRIDE,
LOSS,
LSE,
KL,
IS_CLIPPED,
TEMPERATURE,
BETA: tl.constexpr,
USE_BIAS_CORRECTION_KL: tl.constexpr,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
num_progs_l = tl.num_programs(1)
batch_start = pid_b * L
batch_end = batch_start + L
start_token = batch_start + pid_l
stride = num_progs_l
for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
should_process = 1
if COMPLETION_MASK is not None:
COMPLETION_MASK_local = COMPLETION_MASK + off_b * L + off_l
not_skip = tl.load(COMPLETION_MASK_local)
should_process = not_skip
if should_process != 0:
LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N
INPUT_IDS_local = INPUT_IDS + off_b * L + off_l
ADVANTAGES_local = ADVANTAGES + off_b
COEF_1_local = COEF_1 + off_b
COEF_2_local = COEF_2 + off_b
IS_CLIPPED_SEQ_local = IS_CLIPPED_SEQ + off_b
LOSS_local = LOSS + token_idx
LSE_local = LSE + token_idx
IS_CLIPPED_local = IS_CLIPPED + token_idx
m_i = float("-inf")
l_i = 0.0
for start in range(0, N, BLOCK_N):
cols = start + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
new_m_i = tl.maximum(m_i, tl.max(logits))
alpha = tl.exp(m_i - new_m_i)
l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
m_i = new_m_i
lse = m_i + tl.log(l_i)
idx = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE
logp = x - lse
coef_1 = tl.load(COEF_1_local).to(tl.float32)
coef_2 = tl.load(COEF_2_local).to(tl.float32)
is_clipped_seq = tl.load(IS_CLIPPED_SEQ_local)
advantage = tl.load(ADVANTAGES_local).to(tl.float32)
per_token_loss1 = coef_1 * advantage
per_token_loss2 = coef_2 * advantage
per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
if VLLM_IS_RATIO is not None:
vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
tl.float32
)
per_token_loss = per_token_loss * vllm_is_ratio
if BETA != 0.0:
REF_LOGP_local = REF_LOGP + token_idx
KL_local = KL + token_idx
ref_logp = tl.load(REF_LOGP_local).to(tl.float32)
kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
if USE_BIAS_CORRECTION_KL:
if OLD_LOGP is None:
old_logp = logp
else:
old_logp = tl.load(OLD_LOGP + token_idx).to(tl.float32)
kl = kl * tl.exp(logp - old_logp)
per_token_loss += BETA * kl
tl.store(KL_local, kl)
tl.store(LOSS_local, per_token_loss)
tl.store(LSE_local, lse)
tl.store(IS_CLIPPED_local, is_clipped_seq)
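# Host-side sketch (illustrative; assumes the GSPO-style convention of a
# length-normalized, i.e. geometric-mean, sequence-level importance ratio) of
# how the per-sequence COEF_1 / COEF_2 consumed by the kernel above could be formed.
def _sequence_level_coef_reference(logp, old_logp, completion_mask, eps_low=0.2, eps_high=0.2):
    # logp, old_logp, completion_mask: (B, L)
    log_ratio = (logp - old_logp) * completion_mask
    seq_log_ratio = log_ratio.sum(dim=-1) / completion_mask.sum(dim=-1).clamp(min=1.0)
    coef_1 = torch.exp(seq_log_ratio)
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
    return coef_1, coef_2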
@triton.jit
def _grpo_loss_bwd_kernel_seq(
DLOSS,
DLOSS_SUM,
DLOGITS,
LOGITS,
OLD_LOGP,
REF_LOGP,
INPUT_IDS,
ADVANTAGES,
COMPLETION_MASK,
LSE,
COEF_1,
SEQ_LEN,
TEMPERATURE,
BETA: tl.constexpr,
USE_BIAS_CORRECTION_KL: tl.constexpr,
EPS_LOW,
EPS_HIGH,
DELTA,
loss_stride0,
loss_stride1,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
num_progs_l = tl.num_programs(1)
batch_start = pid_b * L
batch_end = batch_start + L
start_token = batch_start + pid_l
stride = num_progs_l
for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
DLOGITS_local = DLOGITS + off_b * (L + 1) * N + off_l * N
should_process = 1
if COMPLETION_MASK is not None:
COMPLETION_MASK_local = COMPLETION_MASK + off_b * L + off_l
not_skip = tl.load(COMPLETION_MASK_local)
should_process = not_skip
if should_process == 0:
for start in range(0, N, BLOCK_N):
cols = tl.arange(0, BLOCK_N) + start
tl.store(DLOGITS_local + cols, 0.0, mask=cols < N)
else:
LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N
DLOSS_local = DLOSS + off_b * loss_stride0 + off_l * loss_stride1
DLOSS_SUM_local = DLOSS_SUM + off_b
INPUT_IDS_local = INPUT_IDS + off_b * L + off_l
ADVANTAGES_local = ADVANTAGES + off_b
LSE_local = LSE + token_idx
COEF_1_local = COEF_1 + off_b
SEQ_LEN_local = SEQ_LEN + off_b
dloss = tl.load(DLOSS_local).to(tl.float32)
dloss_sum = tl.load(DLOSS_SUM_local).to(tl.float32)
lse = tl.load(LSE_local).to(tl.float32)
coef_1 = tl.load(COEF_1_local).to(tl.float32)
seq_len = tl.load(SEQ_LEN_local).to(tl.float32)
idx = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE
logp = x - lse
advantage = tl.load(ADVANTAGES_local).to(tl.float32)
coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
if DELTA != 0.0:
coef_1_for_loss = tl.minimum(coef_1, DELTA)
else:
coef_1_for_loss = coef_1
per_token_loss1 = coef_1_for_loss * advantage
per_token_loss2 = coef_2 * advantage
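            # subgradient of min(loss1, loss2): the policy gradient flows through
            # coef_1 only where the unclipped branch attains the minimum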
is_unclipped = per_token_loss2 >= per_token_loss1
dlogp = -coef_1 * advantage / seq_len * is_unclipped * dloss_sum
if DELTA != 0.0:
dlogp = dlogp * (coef_1 <= DELTA)
if BETA != 0.0:
REF_LOGP_local = REF_LOGP + token_idx
ref_logp = tl.load(REF_LOGP_local).to(tl.float32)
if USE_BIAS_CORRECTION_KL:
if OLD_LOGP is None:
old_logp = logp
else:
old_logp = tl.load(OLD_LOGP + token_idx).to(tl.float32)
token_coef_1 = tl.exp(logp - old_logp)
dlogp += BETA * token_coef_1 * (logp - ref_logp) * dloss
else:
dlogp += BETA * (1 - tl.exp(ref_logp - logp)) * dloss
dlogp = dlogp / TEMPERATURE
for start_n in tl.range(0, N, BLOCK_N):
cols = start_n + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
probs = tl.exp(logits - lse)
dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
tl.store(DLOGITS_local + cols, dlogits, mask=cols < N)
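

# Token-level backward kernel. LOSS_TYPE selects the gradient formula:
# 0 = GRPO/DAPO/BNPO/DR_GRPO (clipped surrogate), 1 = CISPO (detached clamped
# weight, so the clamped coefficient enters the gradient directly), 2 = SAPO
# (smooth sigmoid gate in place of hard clipping).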
@triton.jit
def _grpo_loss_bwd_kernel(
DLOSS,
DLOGITS,
LOGITS,
OLD_LOGP,
REF_LOGP,
INPUT_IDS,
ADVANTAGES,
COMPLETION_MASK,
LSE,
VLLM_IS_RATIO,
VLLM_IS_RATIO_STRIDE,
TEMPERATURE,
BETA: tl.constexpr,
EPS_LOW,
EPS_HIGH,
LOSS_TYPE: tl.constexpr,
SAPO_TEMP_POS,
SAPO_TEMP_NEG,
DELTA,
USE_BIAS_CORRECTION_KL: tl.constexpr,
loss_stride0,
loss_stride1,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
num_progs_l = tl.num_programs(1)
batch_start = pid_b * L
batch_end = batch_start + L
start_token = batch_start + pid_l
stride = num_progs_l
for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
DLOGITS_local = DLOGITS + off_b * (L + 1) * N + off_l * N
should_process = 1
if COMPLETION_MASK is not None:
COMPLETION_MASK_local = COMPLETION_MASK + off_b * L + off_l
not_skip = tl.load(COMPLETION_MASK_local)
should_process = not_skip
if should_process == 0:
for start in range(0, N, BLOCK_N):
cols = tl.arange(0, BLOCK_N) + start
tl.store(DLOGITS_local + cols, 0.0, mask=cols < N)
else:
LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N
DLOSS_local = DLOSS + off_b * loss_stride0 + off_l * loss_stride1
INPUT_IDS_local = INPUT_IDS + off_b * L + off_l
ADVANTAGES_local = ADVANTAGES + off_b
LSE_local = LSE + token_idx
dloss = tl.load(DLOSS_local).to(tl.float32)
lse = tl.load(LSE_local).to(tl.float32)
idx = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE
logp = x - lse
if OLD_LOGP is None:
old_logp = logp
else:
OLD_LOGP_local = OLD_LOGP + token_idx
old_logp = tl.load(OLD_LOGP_local).to(tl.float32)
coef_1 = tl.exp(logp - old_logp)
advantage = tl.load(ADVANTAGES_local).to(tl.float32)
if LOSS_TYPE == 0: # GRPO/DAPO/BNPO/DR_GRPO
coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
if DELTA != 0.0:
coef_1_for_loss = tl.minimum(coef_1, DELTA)
else:
coef_1_for_loss = coef_1
per_token_loss1 = coef_1_for_loss * advantage
per_token_loss2 = coef_2 * advantage
mask = per_token_loss2 >= per_token_loss1
dlogp = -coef_1 * advantage * mask
if DELTA != 0.0:
dlogp = dlogp * (coef_1 <= DELTA)
elif LOSS_TYPE == 1: # CISPO
coef_2 = tl.minimum(coef_1, EPS_HIGH)
dlogp = -coef_2 * advantage
elif LOSS_TYPE == 2: # SAPO
temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG)
sigmoid_input = temperature * (coef_1 - 1.0)
sigmoid_val = tl.sigmoid(sigmoid_input)
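                # d/dr [4 * sigmoid(t * (r - 1)) / t] = 4 * sig * (1 - sig); the
                # chain rule through r = exp(logp - old_logp) adds the coef_1 factor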
d_sapo_d_coef1 = 4.0 * sigmoid_val * (1.0 - sigmoid_val)
dlogp = -advantage * d_sapo_d_coef1 * coef_1
if VLLM_IS_RATIO is not None:
vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
tl.float32
)
dlogp = dlogp * vllm_is_ratio
if BETA != 0.0:
REF_LOGP_local = REF_LOGP + token_idx
ref_logp = tl.load(REF_LOGP_local).to(tl.float32)
if USE_BIAS_CORRECTION_KL:
dlogp += BETA * coef_1 * (logp - ref_logp)
else:
dlogp += BETA * (1 - tl.exp(ref_logp - logp))
dlogp = dlogp * dloss / TEMPERATURE
tl.debug_barrier()
for start_n in tl.range(0, N, BLOCK_N):
cols = start_n + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
probs = tl.exp(logits - lse)
dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
tl.store(DLOGITS_local + cols, dlogits, mask=cols < N)
@torch.no_grad
def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
"""Compute log probabilities for specific token IDs with selective masking."""
assert logits.is_contiguous()
B, L_ADD_1, N = logits.shape
L = L_ADD_1 - 1
input_ids = input_ids[:, -L:]
if mask is not None:
mask = mask[:, -L:]
log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
block_n = compute_block_size_softmax(N)
num_cores = get_npu_core_count()
grid = calculate_tile_count_2d(B, L, num_cores)
_selective_log_softmax_kernel[grid](
logits,
input_ids,
log_p,
mask,
temperature,
input_ids.stride(0),
L,
N,
BLOCK_N=block_n,
)
return log_p
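

# Illustrative usage (a sketch; shapes follow the asserts above and are otherwise
# assumptions, not taken from this repo's tests):
#   logits: (B, L+1, N) contiguous model outputs; the trailing position is unused.
#   completion_ids: (B, >=L) token ids whose last L columns are the completion.
#   log_p = fused_selective_log_softmax(logits, completion_ids, temperature=1.0)
#   log_p: (B, L) float32 per-token log-probabilities log pi(y_t | y_<t).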
def compute_distribution_normalizer(completion_mask):
"""Calculate global active token count for distributed loss normalization."""
normalizer = completion_mask.to(torch.float32).sum()
world_size = 1
if torch.distributed.is_available() and torch.distributed.is_initialized():
normalizer = normalizer.clone()
torch.distributed.all_reduce(normalizer, op=torch.distributed.ReduceOp.SUM)
world_size = torch.distributed.get_world_size()
normalizer = normalizer / world_size
return torch.clamp(normalizer, min=1.0)
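

# Why divide by world_size: DDP averages gradients across ranks, so scaling each
# rank's loss by global_token_count / world_size makes the averaged gradient equal
# to sum(all local losses) / global_token_count, i.e. true global-batch
# normalization as prescribed by DAPO.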
def reduce_loss(per_token_loss, mask, loss_type, max_completion_length, batch_size, seq_len):
"""Apply reduction strategy based on specified loss type."""
if loss_type == "grpo" or loss_type == "sapo":
return ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
elif loss_type == "bnpo":
return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
elif loss_type == "dr_grpo":
max_len = max_completion_length if max_completion_length is not None else seq_len
return (per_token_loss * mask).sum() / (batch_size * max_len)
elif loss_type == "dapo" or loss_type == "cispo":
return (per_token_loss * mask).sum() / compute_distribution_normalizer(mask)
elif loss_type == "luspo":
return (per_token_loss * mask.sum(-1, keepdim=True)).mean()
raise ValueError(f"Unknown loss_type: {loss_type}. Expected one of: grpo, bnpo, dr_grpo, dapo, cispo, sapo, luspo")
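

# Reduction summary (matching the branches above):
#   grpo / sapo  : mean over sequences of per-sequence token means
#   bnpo         : global token mean within the local batch
#   dr_grpo      : sum / (B * max_completion_length)
#   dapo / cispo : sum / global active-token count (distributed-aware)
#   luspo        : length-weighted per-token loss, averaged over B * L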
def grpo_loss_forward_triton(
ctx,
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
temperature,
beta,
eps_low,
eps_high,
inplace,
loss_type="grpo",
max_completion_length=None,
reduce=True,
importance_sampling_level="token",
sapo_temperature_pos=1.0,
sapo_temperature_neg=1.05,
vllm_is_ratio=None,
delta=None,
use_bias_correction_kl=False,
):
"""Forward pass computation for GRPO loss."""
assert logits.is_contiguous() and completion_ids.is_contiguous()
assert old_logp is None or old_logp.is_contiguous()
    if beta != 0.0:
        assert ref_logp is not None and ref_logp.is_contiguous()
assert importance_sampling_level in ("token", "sequence"), (
f"importance_sampling_level must be 'token' or 'sequence', got {importance_sampling_level}"
)
if loss_type not in _str_to_loss_type:
raise ValueError(f"Unknown loss_type '{loss_type}'. Supported types: {list(_str_to_loss_type.keys())}")
if delta is not None and loss_type in ("cispo", "sapo"):
raise ValueError(f"delta (two-sided clipping) is not supported for loss_type='{loss_type}'.")
delta_val = 0.0 if delta is None else float(delta)
if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"):
raise ValueError(
f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. "
f"Use importance_sampling_level='token' instead."
)
if loss_type == "sapo":
if sapo_temperature_pos <= 0:
raise ValueError(f"sapo_temperature_pos must be positive, got {sapo_temperature_pos}")
if sapo_temperature_neg <= 0:
raise ValueError(f"sapo_temperature_neg must be positive, got {sapo_temperature_neg}")
loss_type_int = _str_to_loss_type[loss_type]
B, L_ADD_1, N = logits.shape
L = L_ADD_1 - 1
if completion_mask is not None:
assert completion_mask.is_contiguous()
mask = completion_mask.float() if completion_mask is not None else torch.ones(B, L, device=logits.device)
vllm_is_ratio_ptr = None
vllm_is_ratio_stride = L
if vllm_is_ratio is not None:
assert vllm_is_ratio.dim() in (1, 2), (
f"vllm_is_ratio must be 1D (B,) or 2D (B, L) / (B, 1), got {vllm_is_ratio.dim()}D"
)
if vllm_is_ratio.dim() == 2:
assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, L), (
f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {L}), got {tuple(vllm_is_ratio.shape)}"
)
else:
assert vllm_is_ratio.shape[0] == B, f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}"
vllm_is_ratio = vllm_is_ratio.contiguous()
vllm_is_ratio_ptr = vllm_is_ratio
vllm_is_ratio_stride = vllm_is_ratio.shape[1] if vllm_is_ratio.dim() > 1 else 1
loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
lse = torch.zeros_like(loss)
is_clipped = torch.zeros_like(loss)
kl = torch.zeros_like(loss) if beta != 0.0 else None
block_n = compute_block_size_forward(N)
num_cores = get_npu_core_count()
grid = calculate_tile_count_2d(B, L, num_cores)
if importance_sampling_level == "sequence":
per_token_logps = fused_selective_log_softmax(logits, completion_ids, temperature, completion_mask)
if old_logp is None:
log_ratio = torch.zeros_like(per_token_logps)
else:
log_ratio = per_token_logps - old_logp
seq_lens = mask.sum(-1).clamp(min=1.0)
seq_log_importance = (log_ratio * mask).sum(-1) / seq_lens
coef_1 = torch.exp(seq_log_importance)
coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
is_clipped_seq = ((coef_1 < 1 - eps_low) & (advantages < 0)) | ((coef_1 > 1 + eps_high) & (advantages > 0))
is_clipped_seq = is_clipped_seq.float()
if delta is not None:
coef_1_for_loss = torch.clamp(coef_1, max=delta)
else:
coef_1_for_loss = coef_1
_grpo_loss_fwd_kernel_seq[grid](
logits,
old_logp,
ref_logp,
completion_ids,
completion_mask,
advantages,
coef_1_for_loss.contiguous(),
coef_2.contiguous(),
is_clipped_seq.contiguous(),
vllm_is_ratio_ptr,
vllm_is_ratio_stride,
loss,
lse,
kl,
is_clipped,
temperature,
beta,
use_bias_correction_kl,
L,
N,
BLOCK_N=block_n,
)
ctx.save_for_backward(
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
lse,
mask,
coef_1,
seq_lens,
vllm_is_ratio_ptr,
)
else:
_grpo_loss_fwd_kernel[grid](
logits,
old_logp,
ref_logp,
completion_ids,
completion_mask,
advantages,
vllm_is_ratio_ptr,
vllm_is_ratio_stride,
loss,
lse,
kl,
is_clipped,
temperature,
beta,
eps_low,
eps_high,
loss_type_int,
sapo_temperature_pos,
sapo_temperature_neg,
delta_val,
use_bias_correction_kl,
L,
N,
BLOCK_N=block_n,
)
ctx.save_for_backward(
logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio_ptr
)
ctx.infos = (
temperature,
beta,
eps_low,
eps_high,
inplace,
loss_type,
loss_type_int,
sapo_temperature_pos,
sapo_temperature_neg,
max_completion_length,
B,
L,
importance_sampling_level,
vllm_is_ratio_stride,
reduce,
delta_val,
use_bias_correction_kl,
)
    if not reduce:
        loss_out = loss * mask
        kl_out = kl * mask if kl is not None else None
        is_clipped_out = is_clipped * mask
        return loss_out, kl_out, is_clipped_out
    mask_sum = mask.sum().clamp(min=1.0)
    kl_mean = (kl * mask).sum() / mask_sum if kl is not None else None
    clip_ratio = (is_clipped.float() * mask).sum() / mask_sum
    reduced_loss = reduce_loss(loss, mask, loss_type, max_completion_length, B, L)
    return reduced_loss, kl_mean, clip_ratio
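

# The backward entry point mirrors reduce_loss: each loss type's branch converts
# the incoming scalar gradient into the per-token dloss scaling the kernels expect.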
def grpo_loss_backward_triton(ctx, *args):
"""Backward pass computation for GRPO loss."""
dloss_input = args[0]
saved_tensors = ctx.saved_tensors
(
temperature,
beta,
eps_low,
eps_high,
inplace,
loss_type,
loss_type_int,
sapo_temperature_pos,
sapo_temperature_neg,
max_completion_length,
B,
L,
importance_sampling_level,
vllm_is_ratio_stride,
reduce,
delta_val,
use_bias_correction_kl,
) = ctx.infos
if importance_sampling_level == "sequence":
(
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
lse,
mask,
coef_1,
seq_lens,
vllm_is_ratio,
) = saved_tensors
else:
(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio) = (
saved_tensors
)
_, L_ADD_1, N = logits.shape
if not reduce:
dloss = dloss_input
elif loss_type == "grpo" or loss_type == "sapo":
seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0)
dloss = dloss_input * mask / (seq_lens_bwd * B)
elif loss_type == "bnpo":
dloss = dloss_input * mask / mask.sum().clamp(min=1.0)
elif loss_type == "dr_grpo":
max_len = max_completion_length if max_completion_length is not None else L
dloss = dloss_input * mask / (B * max_len)
elif loss_type == "dapo" or loss_type == "cispo":
dloss = dloss_input * mask / compute_distribution_normalizer(mask)
elif loss_type == "luspo":
seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0)
dloss = dloss_input * seq_lens_bwd / (B * L)
else:
raise ValueError(f"Unknown loss_type: {loss_type}")
dlogits = logits.data if inplace else torch.empty_like(logits)
block_n = compute_block_size_backward(N)
num_cores = get_npu_core_count()
grid = calculate_tile_count_2d(B, L, num_cores)
if importance_sampling_level == "sequence":
if vllm_is_ratio is None:
dloss_sum = dloss.sum(-1).contiguous()
else:
if vllm_is_ratio.dim() == 1:
ratio = vllm_is_ratio.unsqueeze(-1)
else:
ratio = vllm_is_ratio
dloss_sum = (dloss * ratio).sum(-1).contiguous()
_grpo_loss_bwd_kernel_seq[grid](
dloss,
dloss_sum,
dlogits,
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
lse,
coef_1,
seq_lens,
temperature,
beta,
use_bias_correction_kl,
eps_low,
eps_high,
delta_val,
*dloss.stride(),
L,
N,
BLOCK_N=block_n,
)
else:
_grpo_loss_bwd_kernel[grid](
dloss,
dlogits,
logits,
old_logp,
ref_logp,
completion_ids,
advantages,
completion_mask,
lse,
vllm_is_ratio,
vllm_is_ratio_stride,
temperature,
beta,
eps_low,
eps_high,
loss_type_int,
sapo_temperature_pos,
sapo_temperature_neg,
delta_val,
use_bias_correction_kl,
*dloss.stride(),
L,
N,
BLOCK_N=block_n,
)
dlogits[:, -1, :] = 0
    # one gradient slot per forward input; only the logits receive a gradient
    return (dlogits,) + (None,) * 20
class GrpoLossFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, *args):
return grpo_loss_forward_triton(ctx, *args)
@staticmethod
@ensure_contiguous
def backward(ctx, *args):
return grpo_loss_backward_triton(ctx, *args)
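

# A minimal eager-mode reference for the token-level GRPO objective (an
# illustrative sketch, not part of the fused path; the function name and defaults
# below are assumptions). Useful for checking the Triton kernels on small shapes.
def grpo_loss_reference(
    logits,
    completion_ids,
    advantages,
    completion_mask,
    old_logp=None,
    ref_logp=None,
    temperature=1.0,
    beta=0.04,
    eps_low=0.2,
    eps_high=0.2,
):
    import torch.nn.functional as F

    B, L_add_1, N = logits.shape
    L = L_add_1 - 1
    # per-token log-probabilities of the sampled completion tokens
    logp = (
        F.log_softmax(logits[:, :-1].float() / temperature, dim=-1)
        .gather(-1, completion_ids[:, -L:].unsqueeze(-1))
        .squeeze(-1)
    )  # (B, L)
    old = logp.detach() if old_logp is None else old_logp
    coef_1 = torch.exp(logp - old)
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
    adv = advantages.unsqueeze(-1)  # (B, 1), broadcast over tokens
    loss = -torch.minimum(coef_1 * adv, coef_2 * adv)
    if beta != 0.0 and ref_logp is not None:
        # k3 KL estimator, as in the fused kernels above
        loss = loss + beta * (torch.exp(ref_logp - logp) - (ref_logp - logp) - 1)
    mask = completion_mask.float()
    return ((loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()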