from dataclasses import dataclass
from typing import Optional

import torch

from liger_kernel.ops import LigerCrossEntropyFunction
from liger_kernel.ops import LigerDyTFunction
from liger_kernel.ops import LigerFusedAddRMSNormFunction
from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
from liger_kernel.ops import LigerFusedLinearJSDFunction
from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
from liger_kernel.ops import LigerGELUMulFunction
from liger_kernel.ops import LigerGroupNormFunction
from liger_kernel.ops import LigerJSDFunction
from liger_kernel.ops import LigerKLDivLossFunction
from liger_kernel.ops import LigerLayerNormFunction
from liger_kernel.ops import LigerMHCCoeffsFunction
from liger_kernel.ops import LigerMHCPostResFunction
from liger_kernel.ops import LigerMHCPreFunction
from liger_kernel.ops import LigerMultiTokenAttentionFunction
from liger_kernel.ops import LigerPolyNormFunction
from liger_kernel.ops import LigerQwen2VLMRopeFunction
from liger_kernel.ops import LigerRMSNormFunction
from liger_kernel.ops import LigerRopeFunction
from liger_kernel.ops import LigerSiLUMulFunction
from liger_kernel.ops import LigerSoftmaxFunction
from liger_kernel.ops import LigerSparsemaxFunction
from liger_kernel.ops import LigerTVDLossFunction


@dataclass
class CrossEntropyOutput:
    loss: torch.Tensor
    z_loss: Optional[torch.Tensor] = None
    token_accuracy: Optional[torch.Tensor] = None
    predicted_tokens: Optional[torch.Tensor] = None


# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
# `size_average` and `reduce` are deprecated-API placeholders and are ignored
def liger_cross_entropy(
    input,
    target,
    weight=None,
    size_average=None,
    ignore_index: int = -100,
    reduce=None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
    lse_square_scale: float = 0.0,
    softcap: Optional[float] = None,
    return_z_loss: bool = False,
    return_token_accuracy: bool = False,
    return_predicted_tokens: bool = False,
):
    loss, z_loss, token_accuracy, predicted_tokens = LigerCrossEntropyFunction.apply(
        input,
        target,
        weight,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
        return_z_loss,
        return_token_accuracy,
        return_predicted_tokens,
    )
    if not return_z_loss and not return_token_accuracy and not return_predicted_tokens:
        return loss
    return CrossEntropyOutput(
        loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens
    )


def liger_fused_linear_cross_entropy(
    input,
    weight,
    target,
    bias=None,
    ce_weight=None,
    ignore_index: int = -100,
    lse_square_scale: float = 0.0,
    label_smoothing: float = 0.0,
    reduction: str = "mean",
    softcap: Optional[float] = None,
    return_z_loss: bool = False,
    accum_dtype=None,
    use_token_scaling: bool = False,
    return_token_accuracy: bool = False,
    return_predicted_tokens: bool = False,
):
    loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply(
        input,
        weight,
        target,
        bias,
        ce_weight,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
        return_z_loss,
        accum_dtype,
        use_token_scaling,
        return_token_accuracy,
        return_predicted_tokens,
    )
    if not return_z_loss and not return_token_accuracy and not return_predicted_tokens:
        return loss
    return CrossEntropyOutput(
        loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens
    )
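
# --- Usage sketch (illustrative helper, not part of the public API) --------
# A minimal example of `liger_fused_linear_cross_entropy`, which fuses the
# lm_head projection with the cross-entropy loss so the full logits matrix
# need not be materialized. The shapes are an assumption based on the
# signature above: `input` is (num_tokens, hidden), `weight` is
# (vocab_size, hidden), `target` is (num_tokens,). A CUDA device is assumed,
# since Liger kernels are Triton-based.
def _demo_fused_linear_cross_entropy():
    num_tokens, hidden, vocab_size = 8, 32, 128
    hidden_states = torch.randn(num_tokens, hidden, device="cuda", requires_grad=True)
    lm_head_weight = torch.randn(vocab_size, hidden, device="cuda", requires_grad=True)
    target = torch.randint(0, vocab_size, (num_tokens,), device="cuda")
    # With all `return_*` flags left False, a plain loss tensor is returned.
    loss = liger_fused_linear_cross_entropy(hidden_states, lm_head_weight, target)
    loss.backward()
    return loss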
def liger_fused_linear_jsd(
    student_input,
    student_weight,
    teacher_input,
    teacher_weight,
    shift_labels=None,
    jsd_beta: float = 0.5,
    ignore_index: int = -100,
    temperature: float = 1.0,
):
    return LigerFusedLinearJSDFunction.apply(
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        shift_labels,
        jsd_beta,
        ignore_index,
        temperature,
    )


def liger_geglu(a, b):
    return LigerGELUMulFunction.apply(a, b)


def liger_group_norm(
    X,
    affine_scaling_weight,
    affine_shifting_bias,
    num_channels,
    num_groups,
    eps,
):
    return LigerGroupNormFunction.apply(
        X,
        affine_scaling_weight,
        affine_shifting_bias,
        num_channels,
        num_groups,
        eps,
    )


def liger_jsd(
    input,
    target,
    shift_labels=None,
    beta: float = 0.5,
    ignore_index: int = -100,
):
    return LigerJSDFunction.apply(
        input,
        target,
        shift_labels,
        beta,
        ignore_index,
    )


# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
# `size_average` and `reduce` are being deprecated in the torch API and are placeholders here
def liger_kl_div(
    input,
    target,
    size_average: bool = True,
    reduce: bool = True,
    reduction: str = "mean",
    log_target: bool = False,
    eps: float = 1e-10,
):
    # Note: the default reduction in torch is `mean`, but it is `batchmean` in Liger
    return LigerKLDivLossFunction.apply(
        input,
        target,
        reduction,
        log_target,
        eps,
    )


def liger_sparsemax(
    input,
    dim: int = -1,
):
    return LigerSparsemaxFunction.apply(input, dim)


def liger_multi_token_attention(
    scores,
    weight,
    bias=None,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    sparse: bool = False,
):
    """
    Functional interface for multi-token attention.

    Args:
        scores: Input tensor of shape (B, C_in, L, L)
        weight: Convolution weight tensor of shape (C_out, C_in // groups, K, K)
        bias: Optional bias tensor of shape (C_out,)
        stride: Stride for the convolution (default: 1)
        padding: Padding for the convolution (default: 0)
        dilation: Dilation factor for the convolution (default: 1)
        groups: Number of groups for the convolution (default: 1)
        sparse: Specifies if input tensors are expected to be sparse (default: False)

    Returns:
        Output tensor after applying multi-token attention.
    """
    return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)
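
# --- Usage sketch (illustrative helper, not part of the public API) --------
# Exercises `liger_multi_token_attention` with the shapes documented in its
# docstring: `scores` (B, C_in, L, L), `weight` (C_out, C_in // groups, K, K),
# `bias` (C_out,). With `padding=K // 2`, stride 1, and dilation 1, the
# (L, L) attention-map size is preserved. The concrete sizes are arbitrary;
# a CUDA device is assumed, since Liger kernels are Triton-based.
def _demo_multi_token_attention():
    B, C, L, K = 2, 4, 16, 3
    scores = torch.randn(B, C, L, L, device="cuda", requires_grad=True)
    weight = torch.randn(C, C, K, K, device="cuda", requires_grad=True)
    bias = torch.zeros(C, device="cuda", requires_grad=True)
    out = liger_multi_token_attention(scores, weight, bias, padding=K // 2)
    out.sum().backward()
    return out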
def liger_fused_neighborhood_attention(
    query,
    key,
    value,
    kernel_size: int = 7,
    dilation: int = 1,
    scale: Optional[float] = None,
):
    """
    Liger fused neighborhood attention.

    Paper: https://arxiv.org/pdf/2504.16922

    Args:
        query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
        key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
        value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
        kernel_size: Size of the neighborhood window (default: 7)
        dilation: Dilation factor for the neighborhood (default: 1)
        scale: Scaling factor for attention scores (default: rsqrt(head_dim))

    Returns:
        Output tensor of shape [batch_size, num_heads, seq_len, head_dim]
    """
    return LigerFusedNeighborhoodAttentionFunction.apply(query, key, value, kernel_size, dilation, scale)


def liger_tvd(
    input,
    target,
    shift_labels=None,
    reduction: str = "mean",
    ignore_index: int = -100,
):
    return LigerTVDLossFunction.apply(
        input,
        target,
        shift_labels,
        reduction,
        ignore_index,
    )


def liger_layer_norm(X, W, B, eps):
    return LigerLayerNormFunction.apply(X, W, B, eps)


def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)


def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
    return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)


def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
    return LigerPolyNormFunction.apply(X, W, B, eps, in_place)


def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
    return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)


def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)


def liger_swiglu(a, b):
    return LigerSiLUMulFunction.apply(a, b)


def liger_softmax(x):
    return LigerSoftmaxFunction.apply(x)


def liger_dyt(x, alpha, gamma, beta):
    return LigerDyTFunction.apply(x, alpha, gamma, beta)
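
# --- Usage sketch (illustrative helper, not part of the public API) --------
# `liger_swiglu(a, b)` computes the fused elementwise silu(a) * b, the gating
# half of a LLaMA-style SwiGLU MLP. The three `nn.Linear` projections are an
# illustrative assumption, not part of this module; a CUDA device is assumed,
# since Liger kernels are Triton-based.
def _demo_swiglu_mlp():
    hidden, intermediate = 32, 64
    gate_proj = torch.nn.Linear(hidden, intermediate, bias=False, device="cuda")
    up_proj = torch.nn.Linear(hidden, intermediate, bias=False, device="cuda")
    down_proj = torch.nn.Linear(intermediate, hidden, bias=False, device="cuda")
    x = torch.randn(4, hidden, device="cuda")
    return down_proj(liger_swiglu(gate_proj(x), up_proj(x)))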
def liger_mhc_coeffs(
    x,
    phi,
    b,
    alpha_pre,
    alpha_post,
    alpha_res,
    *,
    allow_fp32: bool = False,
    tmax: int = 20,
    rms_eps: float = 1e-6,
    pre_eps: float = 0.0,
    sinkhorn_eps: float = 1e-6,
    post_mult: float = 2.0,
):
    # Convert config scalars to Python types so they are not included in the
    # autograd computation graph (they are not learnable parameters).
    return LigerMHCCoeffsFunction.apply(
        x,
        phi,
        b,
        alpha_pre,
        alpha_post,
        alpha_res,
        allow_fp32,
        int(tmax),
        float(rms_eps),
        float(pre_eps),
        float(sinkhorn_eps),
        float(post_mult),
    )


def liger_mhc_pre(x, h_pre):
    return LigerMHCPreFunction.apply(x, h_pre)


def liger_mhc_post_res(x, f_out, h_post, h_res):
    return LigerMHCPostResFunction.apply(x, f_out, h_post, h_res)


def liger_mhc_apply(x, f_out, h_pre, h_post, h_res, *, return_x_in: bool = False):
    x_in = liger_mhc_pre(x, h_pre)
    x_out = liger_mhc_post_res(x, f_out, h_post, h_res)
    if return_x_in:
        return x_out, x_in
    return x_out


def liger_mhc_forward(
    x,
    layer,
    phi,
    b,
    alpha_pre,
    alpha_post,
    alpha_res,
    *,
    allow_fp32=False,
    tmax=20,
    rms_eps=1e-6,
    pre_eps=0.0,
    sinkhorn_eps=1e-6,
    post_mult=2.0,
    return_coeffs=False,
):
    """High-level helper: compute coeffs, apply pre, run layer, then apply post + res."""
    h_pre, h_post, h_res = liger_mhc_coeffs(
        x,
        phi,
        b,
        alpha_pre,
        alpha_post,
        alpha_res,
        allow_fp32=allow_fp32,
        tmax=tmax,
        rms_eps=rms_eps,
        pre_eps=pre_eps,
        sinkhorn_eps=sinkhorn_eps,
        post_mult=post_mult,
    )
    x_in = liger_mhc_pre(x, h_pre)
    # Cast the pre-scaled input to the layer's parameter dtype so
    # mixed-precision modules receive inputs they can consume.
    layer_dtype = x_in.dtype
    if hasattr(layer, "parameters"):
        try:
            layer_dtype = next(layer.parameters()).dtype
        except StopIteration:
            layer_dtype = x_in.dtype
    if x_in.dtype != layer_dtype:
        x_in = x_in.to(layer_dtype)
    f_out = layer(x_in)
    x_out = liger_mhc_post_res(x, f_out, h_post, h_res)
    if return_coeffs:
        return x_out, (h_pre, h_post, h_res)
    return x_out
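
# --- Consistency sketch (illustrative helper, not part of the public API) --
# By construction above, `liger_mhc_apply` is exactly `liger_mhc_pre` followed
# by `liger_mhc_post_res`; this helper checks that decomposition. The shapes
# of `x`, `f_out`, and the coefficient tensors depend on the MHC configuration
# and are deliberately not assumed here; callers supply tensors valid for
# their setup.
def _check_mhc_apply_composition(x, f_out, h_pre, h_post, h_res):
    x_out, x_in = liger_mhc_apply(x, f_out, h_pre, h_post, h_res, return_x_in=True)
    assert torch.allclose(x_in, liger_mhc_pre(x, h_pre))
    assert torch.allclose(x_out, liger_mhc_post_res(x, f_out, h_post, h_res))
    return x_out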