This directory contains vendor-specific operator implementations that automatically replace the default (CUDA) implementations when running on the corresponding device.
In `backends/_<vendor>/ops/__init__.py`, export your implementations:
```python
"""
<Vendor>-specific operator implementations.
"""
from .<module> import (
    LigerGELUMulFunction,
    geglu_forward_vendor as geglu_forward,  # Rename to match default API
    geglu_backward_vendor as geglu_backward,
)

# Explicitly declare what to export (recommended)
__all__ = [
    "LigerGELUMulFunction",
    "geglu_forward",
    "geglu_backward",
]
```
## Key Points
### Incremental Override
You **don't need to implement all operators**. Only implement the ones that require vendor-specific adaptations. Unimplemented operators will automatically fall back to the default (CUDA) implementation.
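To picture the mechanism, here is a minimal sketch of how such a fallback merge could work. It is illustrative only — the module paths (`liger_kernel.ops`, `liger_kernel.ops.backends`) are assumptions based on the layout described above, not the actual loader code:
```python
# Illustrative sketch only: vendor exports are merged over the defaults,
# so any operator the vendor module does not define falls back to the
# default (CUDA) implementation. Module paths are assumptions.
import importlib


def resolve_ops(vendor: str) -> dict:
    defaults = importlib.import_module("liger_kernel.ops")
    ops = {name: getattr(defaults, name) for name in dir(defaults) if not name.startswith("_")}
    vendor_mod = importlib.import_module(f"liger_kernel.ops.backends._{vendor}.ops")
    for name in getattr(vendor_mod, "__all__", []):
        ops[name] = getattr(vendor_mod, name)  # vendor override wins; the rest fall back
    return ops
```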
### Vendor-Specific Additions
Vendors can also **add new operators** that don't exist in the default implementation. These will be exported to the `liger_kernel.ops` namespace for users to import.
### Naming Convention
- Use the **same class/function names** as the default implementations for overrides
- This allows seamless replacement without changing user code
- Use `as` imports to rename if your internal naming differs
## Example: Ascend NPU Backend
See `_ascend/` directory for a complete example of the Ascend NPU backend implementation.
The UB Manager (Unified Buffer Manager) is a core component in **Liger-Kernel** responsible for managing the Unified Buffer (UB) capacity on Ascend NPUs. By automatically detecting UB capacity and providing unified tiling strategy computation, it helps Triton kernels avoid UB overflow errors while maintaining high performance.
The `_default_strategy` function:
- For each shape, identifies which dimensions can be tiled (from `tiling_dims`)
- Calculates `unit_param` as the product of fixed (non-tiling) dimensions
- Calculates the maximum safe block size that fits within UB capacity
- Returns a tuple of max_safe_block_size values (one for each shape)
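The capacity arithmetic can be sketched as follows. This is an illustration of how the documented inputs (`ub_capacity_bits`, `safety_margin`, `dtype_size`, `memory_multiplier`, `unit_param`) could combine — an assumption, not the exact formula used in `_default_strategy`:
```python
import math


# Sketch (assumption): take the usable UB budget after the safety margin,
# then divide by the bits each tiled element effectively occupies.
def sketch_max_safe_block_size(
    ub_capacity_bits: int,
    safety_margin: float,      # e.g., 0.80
    dtype_size: int,           # bytes per element, e.g., 2 for float16
    memory_multiplier: float,  # peak-memory estimate factor, e.g., 3.0 for ROPE
    unit_param: int,           # product of the fixed (non-tiling) dimensions
) -> int:
    budget_bits = ub_capacity_bits * safety_margin
    elems = budget_bits / (dtype_size * 8 * memory_multiplier * unit_param)
    # Round down to a power of 2 so the block tiles cleanly.
    return 1 << max(0, int(math.log2(elems)))
```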
The `compute_default_tiling_strategy` function:
- Calls `_default_strategy` to get max_safe_block_size for each shape
- For each tiling dimension, computes desired block size using `triton.next_power_of_2(original_dim)`
- Returns a result with the same structure as the input shapes: tiling dimensions are replaced with computed block sizes, and non-tiling dimensions are padded to the next power of 2 (illustrated below)
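For example, a hypothetical ROPE-style input and result (the numbers are invented; the tiled block size is presumably also capped by the UB budget computed above):
```python
shapes      = ((1000, 128), (8, 128))  # (n_q_head, hd), (n_kv_head, hd)
tiling_dims = (0, 0)                   # dim 0 of each shape may be tiled

# Assuming the UB budget allows at most 512 rows for the first shape:
#   shape 0, dim 0: next_power_of_2(1000) = 1024, capped to 512
#   shape 1, dim 0: next_power_of_2(8) = 8 (fits within the budget)
#   dim 1 of both:  padded to the next power of 2 (128 already is one)
# result = ((512, 128), (8, 128))
```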
### 3. Parameter Structure
The unified strategy uses the following parameters:
- **`safety_margin`**: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
- **`dtype_size`**: Size of the data type in bytes (e.g., 2 for float16, 4 for float32)
- **`memory_multiplier`**: Memory multiplier for estimating peak memory usage
  - For GEGLU: typically 10.0 for backward, 7.0 for forward
  - For ROPE: typically 3.0
- **`shapes`**: Tuple of full shapes. Each shape is a tuple of dimension sizes.
  - For ROPE: `((n_q_head, hd), (n_kv_head, hd))`
  - For GEGLU: `((n_cols,),)`
  - Can pass original shapes (padding is handled internally) or padded shapes
- **`tiling_dims`**: Tuple specifying which dimensions can be tiled for each shape.
  - Each element can be:
    - `int`: a single dimension index (e.g., `0` for the first dimension)
    - `tuple of ints`: multiple dimensions that can be tiled together (non-empty)
  - For ROPE: `(0, 0)` means the first dimension of each shape can be tiled
  - For GEGLU: `(0,)` means the first dimension of the shape can be tiled
  - Length must match `len(shapes)`
  - Fixed (non-tiling) dimensions are automatically extracted from the shapes and multiplied together to get `unit_param`
- **Validation**: Raises `ValueError` if:
  - Any `tiling_dim` is empty or invalid (e.g., an empty tuple)
  - Any dimension index is out of bounds (negative or >= the shape length)
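Putting these parameters together, here is a usage sketch for a ROPE-style workload. The import path is an assumption and may differ from the actual module layout:
```python
# Hypothetical import path; adjust to wherever the UB manager actually lives.
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy

n_q_head, n_kv_head, hd = 32, 8, 128

block_shapes = compute_default_tiling_strategy(
    safety_margin=0.80,     # reserve 20% of UB capacity as headroom
    dtype_size=2,           # float16
    memory_multiplier=3.0,  # typical value for ROPE
    shapes=((n_q_head, hd), (n_kv_head, hd)),
    tiling_dims=(0, 0),     # tile the head dimension of each shape
)
# The result mirrors `shapes`: tiling dims become safe block sizes,
# fixed dims are padded to the next power of 2.
```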
### 4. Strategy Computation Flow
```
User calls compute_default_tiling_strategy()
│
▼
Get UB manager instance
│
▼
Validate shapes and tiling_dims (lengths must match)
│
▼
Set defaults for dtype_size (4) and memory_multiplier (10.0)
│
▼
Call _default_strategy() with:
- ub_capacity_bits
- safety_margin
- dtype_size
- memory_multiplier
- shapes
- tiling_dims
│
▼
For each (shape, tiling_dim) pair:
    Normalize tiling_dim to a set of dimension indices
    Validate tiling dimensions are within shape bounds
```
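The final normalization and validation step of the flow might look like the following sketch (the function name is illustrative, not the actual implementation):
```python
def normalize_tiling_dim(tiling_dim, shape):
    # Accept a single index or a non-empty tuple of indices.
    dims = {tiling_dim} if isinstance(tiling_dim, int) else set(tiling_dim)
    if not dims:
        raise ValueError("tiling_dim must be a non-empty int or tuple of ints")
    for d in dims:
        if d < 0 or d >= len(shape):
            raise ValueError(f"tiling dim {d} is out of bounds for shape {shape}")
    return dims
```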
    reduction: tl.constexpr,  # constexpr since the reduction mode is always known at compile time
    softcap,
    RETURN_Z_LOSS: tl.constexpr,
    RETURN_TOKEN_ACCURACY: tl.constexpr,
    RETURN_PREDICTED_TOKENS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_SOFTCAPPING: tl.constexpr,
    HAS_GRADIENTS: tl.constexpr,
):
"""
This kernel computes both cross entropy loss and the gradient of the input.
We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
Parameters:
X_ptr: Pointer to input tensor.
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
weight_ptr: Pointer to weight tensor.
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
token_accuracy_stride (int): The stride of the token accuracy tensor.
n_cols (int): The number of columns in the input tensor.
n_rows (int): The total number of rows to process.
n_non_ignore (float): The number of non-ignored elements in the batch.
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
weight_sum (float): The sum of weight tensor.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
"""
    # Grid-stride loop: each program processes multiple rows
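    # Illustrative sketch of the pattern (not the actual kernel body):
    #     program_id = tl.program_id(0).to(tl.int64)
    #     num_programs = tl.num_programs(0)
    #     for row_idx in range(program_id, n_rows, num_programs):
    #         row_ptr = X_ptr + row_idx * X_stride
    #         ...  # process the row in BLOCK_SIZE-wide chunks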

class LigerCrossEntropyFunction(torch.autograd.Function):
    """
    This class implements a custom autograd function for the Liger Cross Entropy loss.
    It overrides the forward and backward methods of the torch.autograd.Function class.
    """

    @staticmethod
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.FloatTensor],
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
        return_z_loss: bool = False,
        return_token_accuracy: bool = False,
        return_predicted_tokens: bool = False,
    ):
"""
The forward pass of the Liger Cross Entropy loss.
Parameters:
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
ignore_index (int): The index to ignore in the target.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy, predicted_tokens) instead of (loss, None, None, None). Default: `False`
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
Returns:
tuple: A tuple with the computed losses, accuracy, and predicted tokens: (loss, z_loss, token_accuracy, predicted_tokens). z_loss, token_accuracy, and predicted_tokens are None if not requested.
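
# Usage sketch (illustrative): torch.autograd.Function takes its arguments
# positionally via .apply(), in the order of the forward() signature above:
#     loss, z_loss, token_accuracy, predicted_tokens = LigerCrossEntropyFunction.apply(
#         logits, target, None, -100, 0.0, 0.0, "mean", None, True, True, True
#     )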


class LigerFusedLinearJSDFunction(torch.autograd.Function):
    """
    Handle the forward and backward pass of the final linear layer via JSD by avoiding
    the materialization of the large logits tensor. Since JSD is the last layer, we can
    compute the gradient at the forward pass.
    """

    @staticmethod
    @amp_custom_fwd
    def forward(
        ctx,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        jsd_beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
"""
Args:
student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size
teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
ignore_index (int): the index to ignore. Default: -100
temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`