As with all other losses in PyTorch, this function expects its first argument,
:attr:`_input`, to be the predictions (the output of the student model) in log-space,
and its second argument, :attr:`target`, to be the observations (the output of the teacher model), also in log-space.
This differs from the standard mathematical notation :math:`JSD(P || Q)` where
:math:`P` denotes the teacher model and :math:`Q` denotes the student model.
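As a reference for what is being computed, here is a minimal pure-Python sketch of the generalized JSD on probability vectors (`generalized_jsd` is an illustrative helper, not part of this module; the actual kernel works on log-space inputs and treats the beta = 0 and beta = 1 cases as the forward and reverse KL divergences, as noted above):

```python
import math

def generalized_jsd(p, q, beta=0.5):
    """Generalized JSD between teacher distribution p and student distribution q.

    beta = 0 reduces to forward KL(P || Q); beta = 1 to reverse KL(Q || P);
    otherwise JSD_beta = beta * KL(P || M) + (1 - beta) * KL(Q || M),
    with the mixture M = beta * P + (1 - beta) * Q.
    """
    if beta == 0.0:
        return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    if beta == 1.0:
        return sum(qi * math.log(qi / pi) for pi, qi in zip(p, q))
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    kl_pm = sum(pi * math.log(pi / mi) for pi, mi in zip(p, m))
    kl_qm = sum(qi * math.log(qi / mi) for qi, mi in zip(q, m))
    return beta * kl_pm + (1 - beta) * kl_qm
```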
"""
@staticmethod
@ensure_contiguous
def forward(
    ctx,
    _input: torch.Tensor,
    target: torch.Tensor,
    shift_labels: Optional[torch.Tensor] = None,
    beta: float = 0.5,
    ignore_index: int = -100,
) -> torch.Tensor:
"""
Args:
_input (torch.Tensor): predicted values with shape (BT, V) in log-space
target (torch.Tensor): ground truth values with shape (BT, V) in log-space
shift_labels (Optional[torch.LongTensor]): indices of the next predicted vocab entries, with shape (BT), where each value is in [0, V-1].
beta (float): coefficient beta of the generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
ignore_index (int): the index to ignore. Default: -100
Returns:
loss (torch.Tensor): generalized JSD
"""
has_label = False
if shift_labels is not None:
    assert shift_labels.shape == (_input.shape[0],), (
        f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
    )
Class implementing the forward and backward passes for the KL Divergence Loss using Triton, as defined by the following formula:
```python
if log_target:
loss = target.exp() * (target - input)
else:
loss = target * (target.log() - input)
```
then the loss is reduced according to the `reduction` parameter,
as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
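For illustration, the per-element formula and the reduction handling can be sketched eagerly in plain Python (`kl_div_reference` is a hypothetical helper, not part of this module):

```python
import math

def kl_div_reference(y_pred, y_true, reduction="batchmean", log_target=False):
    """Eager sketch: y_pred holds log-probabilities; y_true holds probabilities
    (or log-probabilities when log_target=True). Rows play the role of BT."""
    losses = []
    for pred_row, true_row in zip(y_pred, y_true):
        row = []
        for inp, tgt in zip(pred_row, true_row):
            if log_target:
                row.append(math.exp(tgt) * (tgt - inp))
            else:
                row.append(tgt * (math.log(tgt) - inp))
        losses.append(row)
    if reduction == "none":
        return losses  # elementwise losses, shape (BT, V)
    total = sum(sum(row) for row in losses)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / sum(len(row) for row in losses)
    return total / len(losses)  # "batchmean": divide by the batch size
```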
"""
@staticmethod
@ensure_contiguous
def forward(
    ctx,
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    reduction: REDUCTION_LITERAL = "batchmean",
    log_target: bool = False,
    eps: float = 1e-10,
) -> torch.Tensor:
"""A forward pass for the KL Divergence Loss.
Args:
ctx: Torch autograd context
y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
Returns:
torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
Returns:
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
reduction: tl.constexpr,  # set as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
RETURN_TOKEN_ACCURACY: tl.constexpr,
RETURN_PREDICTED_TOKENS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
HAS_GRADIENTS: tl.constexpr,
):
"""
This kernel computes both the cross entropy loss and the gradient of the input.
We only consider hard labels and mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
Parameters:
X_ptr: Pointer to input tensor.
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
weight_ptr: Pointer to weight tensor.
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
token_accuracy_stride (int): The stride of the token accuracy tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (float): The number of non-ignored elements in the batch.
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
weight_sum (float): The sum of weight tensor.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ** 2, added to the loss for training stability.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): Whether a rescaling weight is assigned to each class.
HAS_SOFTCAPPING (bool): Whether soft-capping is applied.
HAS_GRADIENTS (bool): Whether gradients are computed in the forward pass.
"""
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
    if HAS_WEIGHT:
        smooth_loss = scaled_x_sum + eps * lse * weight_sum
    else:
        smooth_loss = scaled_x_sum + label_smoothing * lse
    loss = loss * (1 - label_smoothing) + smooth_loss

# An auxiliary loss, z_loss
# Refer to the "Loss function" section on page 14 of the PaLM paper: https://www.jmlr.org/papers/v24/22-1144.html
z_loss = lse_square_scale * lse * lse

# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
    if HAS_WEIGHT:
        loss = loss / sum_non_ignore_weight
    else:
        loss = loss / n_non_ignore
    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
    z_loss = z_loss / n_non_ignore

loss += z_loss
tl.store(loss_ptr, loss)
if RETURN_Z_LOSS:
    tl.store(z_loss_ptr, z_loss)
if RETURN_TOKEN_ACCURACY:
    # Store 1.0 if the prediction is correct, 0.0 otherwise
    is_correct = 1.0 if argmax_idx == y else 0.0
    tl.store(token_accuracy_ptr, is_correct)
if RETURN_PREDICTED_TOKENS:
    tl.store(predicted_tokens_ptr, argmax_idx)
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576: https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting the limit to 65536, as in the LayerNorm tutorial, is faster because of less register spilling.
# The optimal maximum block size depends on your hardware, your kernel, and your dtype.
# The values below are the best sizes we found by manually tuning on xpu and npu.
if infer_device() == "xpu":
    MAX_FUSED_SIZE = 4096
elif infer_device() == "npu":
    MAX_FUSED_SIZE = 2048
else:
    MAX_FUSED_SIZE = 65536 // 2
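For context, a cap like this is typically combined with the vocabulary size to pick the actual block size each program uses. A hedged sketch (`pick_block_size` and the pure-Python `next_power_of_2` are illustrative stand-ins; real Triton code would use `triton.next_power_of_2`):

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (pure-Python stand-in for triton.next_power_of_2)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def pick_block_size(vocab_size: int, max_fused_size: int = 65536 // 2) -> int:
    # One program handles one row and loops over the vocabulary in chunks of
    # BLOCK_SIZE, so the block is the padded vocab size capped at the tuned limit.
    return min(max_fused_size, next_power_of_2(vocab_size))
```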
def cross_entropy_forward(
    _input,
    target,
    weight,
    ignore_index,
    lse_square_scale,
    label_smoothing,
    reduction,
    softcap,
    return_z_loss,
    return_token_accuracy=False,
    return_predicted_tokens=False,
):
    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
    assert isinstance(return_token_accuracy, bool), (
        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
    )
    assert isinstance(return_predicted_tokens, bool), (
        f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}"
    )
This class implements a custom autograd function for the Liger Cross Entropy loss.
It overrides the forward and backward methods of the torch.autograd.Function class.
"""
@staticmethod
def forward(
    ctx,
    _input: torch.Tensor,
    target: torch.Tensor,
    weight: Optional[torch.FloatTensor],
    ignore_index: int = -100,
    lse_square_scale: float = 0.0,
    label_smoothing: float = 0.0,
    reduction: str = "mean",
    softcap: Optional[float] = None,
    return_z_loss: bool = False,
    return_token_accuracy: bool = False,
    return_predicted_tokens: bool = False,
):
"""
The forward pass of the Liger Cross Entropy loss.
Parameters:
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
weight (Tensor, optional): A manual rescaling weight given to each class. If given, it has to be a Tensor of size V with a floating point dtype.
ignore_index (int): The index to ignore in the target.
lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ** 2, added to the loss for training stability.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy, predicted_tokens) instead of (loss, None, None, None). Default: `False`
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
Returns:
tuple: A tuple with the computed losses, accuracy, and predicted tokens: (loss, z_loss, token_accuracy, predicted_tokens). z_loss, token_accuracy, and predicted_tokens are None if not requested.
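For reference, soft-capping of the form described above can be sketched as follows (a minimal illustration; the kernel applies the same `tanh`-based squashing elementwise to the logits):

```python
import math

def softcap_logit(x: float, softcap: float) -> float:
    # Smoothly squash a logit into the open interval (-softcap, +softcap).
    return softcap * math.tanh(x / softcap)
```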
This part of the code generates pointers to the specific blocks of matrices A and B that the current thread block will process.
As described in the PyTorch documentation, a stride refers to the step size needed to move from one element to the next along a given dimension:
For matrix A: stride_am = A.stride(0) = K (stride along the rows), and stride_ak = A.stride(1) = 1 (stride along the columns).
For matrix B: stride_bk = B.stride(0) = N (stride along the rows), and stride_bn = B.stride(1) = 1 (stride along the columns).
Now, let's break down the pointer generation:
offs_am[:, None] creates a column of shape [BLOCK_SIZE_M, 1], which represents the row indices of matrix A that this block is processing. It is multiplied by K (the number of columns in matrix A) since A is stored in row-major order. So, the element at position (i, j) in A is located at index i*K + j in memory.
offs_k[None, :] creates a row vector of shape [1, BLOCK_SIZE_K] representing the column indices of the block, i.e., a range from 0 to BLOCK_SIZE_K - 1. This is used to compute the positions of the columns within the block.
When combined, the result has the shape [BLOCK_SIZE_M, BLOCK_SIZE_K], where each entry (i, j) points to the element in matrix A at position (i, j) for the current block.
The same logic is applied to matrix B, but the resulting shape is [BLOCK_SIZE_K, BLOCK_SIZE_N], representing the block of matrix B that the thread block will work on.
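The offset arithmetic above can be checked with a small NumPy sketch (illustrative sizes; `pid_m` stands for the program id along the M dimension):

```python
import numpy as np

M, K = 8, 16
BLOCK_SIZE_M, BLOCK_SIZE_K = 4, 8
pid_m = 1  # which block-row of A this program handles

offs_am = pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)  # row indices of the block
offs_k = np.arange(BLOCK_SIZE_K)                          # column indices of the block
stride_am, stride_ak = K, 1                               # row-major strides of A

# Broadcasting [BLOCK_SIZE_M, 1] against [1, BLOCK_SIZE_K] yields a
# [BLOCK_SIZE_M, BLOCK_SIZE_K] grid of flat offsets: i * K + j.
a_offsets = offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak

A = np.arange(M * K).reshape(M, K)  # row-major toy matrix
block = A.ravel()[a_offsets]        # gather the block this program would load
```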
Performs a fused operation that first adds a residual tensor to the hidden_states tensor (`X`), then applies RMSNorm (Root Mean Square Normalization) to the result using the weight tensor `W`, with optional offset and casting mode.
This class implements the following sequence, commonly used in transformer decoder layers:
1. hidden_states = residual + hidden_states
2. residual = hidden_states (after addition)
3. hidden_states = rmsnorm(hidden_states)
Both the normalized hidden_states and the updated residual are returned as outputs.
Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
`(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
In addition, different models cast their inputs at different places during RMSNorm computation. For
example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
support the following casting modes (they match HuggingFace Transformers' implementations):
- 'llama': matches the Llama implementation, where only the inverse RMS is computed in fp32.
- 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
- 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
The `in_place` option determines whether to modify dY in-place to store dX. This defaults to `True` to save memory.
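The sequence above can be sketched eagerly with NumPy (`fused_add_rmsnorm` is a hypothetical reference helper, not the Triton kernel; the casting behavior follows the modes described above):

```python
import numpy as np

def fused_add_rmsnorm(hidden, residual, weight, eps=1e-6, offset=0.0, casting_mode="llama"):
    """Eager sketch of fused residual-add + RMSNorm.
    'llama' computes only the inverse RMS in fp32; 'gemma' upcasts everything;
    'none' stays in the input dtype throughout."""
    hidden = residual + hidden  # step 1: add the residual
    new_residual = hidden       # step 2: save the sum as the new residual
    if casting_mode == "gemma":
        x = hidden.astype(np.float32)
        inv_rms = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
        out = (x * inv_rms * (weight.astype(np.float32) + offset)).astype(hidden.dtype)
    else:
        x = hidden.astype(np.float32) if casting_mode == "llama" else hidden
        inv_rms = (1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)).astype(hidden.dtype)
        out = hidden * inv_rms * (weight + offset)  # step 3: rmsnorm
    return out, new_residual
```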
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
for the backward pass.
_input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
target: (B*T) where each value is in [0, V-1]
weight: (V, H) where V is the number of classes
bias: (V) where V is the number of classes
ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
ignore_index: the index to ignore in the target
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction: reduction to apply
accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
Default: False.
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
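The core idea of never materializing the full (B*T, V) logits tensor can be sketched with NumPy (`chunked_linear_ce` is a hypothetical helper illustrating only the chunking strategy; the real implementation also computes gradients during this pass and supports the options listed above):

```python
import numpy as np

def chunked_linear_ce(inputs, weight, target, chunk_size=2, ignore_index=-100):
    """Mean-reduced cross entropy over chunked logits.
    inputs: (BT, H), weight: (V, H), target: (BT,). Only a (chunk, V) slice
    of the logits exists at any time."""
    total, count = 0.0, 0
    for start in range(0, inputs.shape[0], chunk_size):
        x = inputs[start:start + chunk_size]   # (chunk, H)
        y = target[start:start + chunk_size]
        logits = x @ weight.T                  # (chunk, V) materialized only per chunk
        # numerically stable log-sum-exp per row
        m = logits.max(axis=-1, keepdims=True)
        lse = m.squeeze(-1) + np.log(np.exp(logits - m).sum(axis=-1))
        for i, t in enumerate(y):
            if t == ignore_index:
                continue
            total += lse[i] - logits[i, t]     # -log softmax at the target class
            count += 1
    return total / max(count, 1)  # "mean" reduction over non-ignored tokens
```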