Unverified Commit 8ca2caf8 authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

Parallel Cross Entropy using online softmax (#1456)



* Added parallel cross entropy loss implementation using online softmax
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added tests
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added reshape of loss output
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added to test list
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* Added Triton dependency
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* Added copyright
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* Fixed lint errors
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update setup.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSelvaraj Anandaraj <anandaraj@wisc.edu>

* Fixed lint and triton failure
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* Removed flattening for scalars
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* Skip tests on Blackwell due to TE CI caveat
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added reason arg
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Do not register Triton dependency with setuptools
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>
Signed-off-by: default avatarSelvaraj Anandaraj <anandaraj@wisc.edu>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 94c92919
......@@ -54,6 +54,8 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index
.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy
.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs
.. autoapifunction:: transformer_engine.pytorch.initialize_ub
......
......@@ -22,6 +22,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
exit $FAIL
......@@ -104,6 +104,8 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch"])
# Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import random
import pytest
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
class TestParallelCrossEntropy:
    """Checks parallel_cross_entropy against torch.nn.CrossEntropyLoss on randomized inputs."""

    def generate_iters(self, iters: int):
        """Store how many randomized rounds each test should execute."""
        self.iters = iters

    def generate_infra(self, reduce_loss: bool, label_smoothing: float):
        """Instantiate the implementation under test and the PyTorch reference loss."""
        self.test_loss_func = parallel_cross_entropy
        reduction = "mean" if reduce_loss else "none"
        self.ref_loss_func = torch.nn.CrossEntropyLoss(
            label_smoothing=label_smoothing, reduction=reduction
        )

    def generate_input(self, dtype: torch.dtype, swap_dim: bool):
        """Draw random logits/targets; keep a flattened copy for the reference loss."""
        SQ = random.choice([64, 128])
        batch = random.choice([1, 2])
        vocab = random.choice([64000, 128000])

        # Either (SQ, B, V) or (B, SQ, V) layout for the implementation under test.
        lead_dims = (SQ, batch) if swap_dim else (batch, SQ)
        self.input_test = torch.rand((*lead_dims, vocab), dtype=dtype).cuda()
        self.tar_test = torch.randint(0, vocab, lead_dims).cuda()

        # torch.nn.CrossEntropyLoss expects (N, V) logits and (N,) targets.
        self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
        self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

    def one_iteration_test(
        self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool
    ):
        """Run a single forward (and, for reduced loss, backward) comparison."""
        self.generate_input(dtype, swap_dim)
        self.input_test.requires_grad_(True)
        self.input_ref.requires_grad_(True)

        test_loss = self.test_loss_func(
            self.input_test, self.tar_test, label_smoothing, reduce_loss, None
        )
        if reduce_loss:
            test_loss.backward()

        ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
        if reduce_loss:
            ref_loss.backward()

        # The unreduced test loss is (B, SQ); flatten to match the reference's (N,).
        if not reduce_loss:
            test_loss = torch.flatten(test_loss)
        torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)

        if reduce_loss:
            torch.testing.assert_close(
                torch.flatten(self.input_test.grad, start_dim=0, end_dim=1),
                self.input_ref.grad,
            )

        # Release the large tensors before the next randomized round.
        self.input_test = None
        self.input_ref = None
        self.tar_test = None
        self.tar_ref = None

    def test_float32_input(self):
        """FP32 logits, (B, SQ, V) layout, mean-reduced loss."""
        self.generate_iters(5)
        self.generate_infra(True, 0)
        for _ in range(self.iters):
            self.one_iteration_test(
                dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=True
            )

    def test_bfloat16_input(self):
        """BF16 logits, (B, SQ, V) layout, mean-reduced loss."""
        self.generate_iters(5)
        self.generate_infra(True, 0)
        for _ in range(self.iters):
            self.one_iteration_test(
                dtype=torch.bfloat16, swap_dim=False, label_smoothing=0, reduce_loss=True
            )

    def test_swapped_input(self):
        """FP32 logits with the (SQ, B, V) layout, mean-reduced loss."""
        self.generate_iters(5)
        self.generate_infra(True, 0)
        for _ in range(self.iters):
            self.one_iteration_test(
                dtype=torch.float32, swap_dim=True, label_smoothing=0, reduce_loss=True
            )

    def test_label_smoothing(self):
        """FP32 logits with label smoothing enabled."""
        self.generate_iters(3)
        self.generate_infra(True, 0.1)
        for _ in range(self.iters):
            self.one_iteration_test(
                dtype=torch.float32, swap_dim=False, label_smoothing=0.1, reduce_loss=True
            )

    def test_non_reduced_loss(self):
        """FP32 logits with per-token (unreduced) loss."""
        self.generate_iters(1)
        self.generate_infra(False, 0)
        for _ in range(self.iters):
            self.one_iteration_test(
                dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False
            )
......@@ -89,6 +89,7 @@ from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
try:
torch._dynamo.config.error_on_nested_jit_trace = False
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Cross Entropy Loss API"""
import torch
import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy
__all__ = [
"parallel_cross_entropy",
]
class CrossEntropyFunction(torch.autograd.Function):
    """
    Custom autograd function for the Cross Entropy loss. Accepts BF16/FP32
    inputs; the loss and gradient math runs entirely in FP32. The loss is
    returned in FP32 and the input gradient is cast back to the input dtype.
    """

    @staticmethod
    def forward(
        ctx, _input, target, label_smoothing=0.0, reduce_loss=False, dist_process_group=None
    ):
        """
        Forward pass of the Cross Entropy loss.

        When dist_process_group is given, the loss is computed in a
        distributed fashion and each rank must hold an equal (*, V/world_size)
        shard of the vocab dimension.

        Parameters:
            ctx : The context object.
            _input (tensor): Input of shape (B, SQ, V) or (SQ, B, V), with B the batch size, SQ the sequence length, V the vocab size.
            target (tensor): Targets of shape (B, SQ) or (SQ, B) with each value in [0, V-1].
            label_smoothing (float): Amount of smoothing when computing the loss; 0.0 disables it.
            reduce_loss (bool): If true, return the loss averaged across the B*SQ dimension.
            dist_process_group (torch.dist.ProcessGroup): Process group the loss computation is split across, or None on a single device.

        Returns:
            tensor: The computed loss.
        """
        loss, grad_buffer = triton_cross_entropy.cross_entropy_forward(
            _input, target, label_smoothing, reduce_loss, dist_process_group
        )
        # The forward kernel overwrites the input with the (unscaled) gradient;
        # stash it for the backward pass.
        ctx.save_for_backward(grad_buffer.detach())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of the Cross Entropy loss.

        Parameters:
            ctx : The context object with saved tensors.
            grad_output (tensor): Gradient of the loss with respect to the output.

        Returns:
            tuple: Gradients with respect to the forward inputs; elements are tensors or None.
        """
        (grad_buffer,) = ctx.saved_tensors
        grad_input = triton_cross_entropy.cross_entropy_backward(grad_buffer, grad_output)
        return grad_input, None, None, None, None


parallel_cross_entropy = CrossEntropyFunction.apply
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Efficient Cross Entropy kernels written with OpenAI Triton."""
from typing import Union
from functools import reduce
from operator import mul
import torch
import torch.distributed as dist
import triton
import triton.language as tl
@triton.jit
def online_softmax_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    m_d_X_y_ptr,
    m_d_X_y_stride,
    rank,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute this TP rank's online-softmax partial values (m, d, X_y), one triple per row.

    For each row (token) the kernel produces:
      m   - the maximum logit within this rank's vocab shard,
      d   - the sum of exp(x - m) over the shard,
      X_y - the target logit if the target index falls in this shard, else -inf.
    The triples are later reduced across ranks by cross_entropy_kernel.

    Parameters:
        X_ptr: Pointer to input tensor.
        X_stride (int): The stride of the input tensor.
        Y_ptr: Pointer to target tensor.
        Y_stride (int): The stride of the target tensor.
        m_d_X_y_ptr: Pointer to m/d/X_y tensor.
        m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
        rank (int): The rank of this device in the TP group.
        n_cols (int): The number of columns (vocab shard size) in the input tensor.
        BLOCK_SIZE (int): The block size for Triton operations.
    """
    # int64 program id so row * stride cannot overflow for large vocab tensors
    program_id = tl.program_id(0).to(tl.int64)

    # locate this row of the input
    X_ptr += program_id * X_stride

    # load the target class index for this row
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)

    # this rank owns global vocab columns [vocab_start_idx, vocab_end_idx)
    vocab_start_idx = rank * n_cols
    vocab_end_idx = (rank + 1) * n_cols

    # fetch the target logit only when it lives in this rank's shard
    if y >= vocab_start_idx:
        if y < vocab_end_idx:
            X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32)
        else:
            X_y = float("-inf")
    else:
        X_y = float("-inf")

    # each row stores a triple (m, d, X_y)
    m_d_X_y_ptr += program_id * m_d_X_y_stride * 3

    # Online softmax first pass: running max and rescaled running sum
    m = float("-inf")  # m is the max value. use the notation from the paper
    d = 0.0  # d is the sum. use the notation from the paper

    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to(
            tl.float32
        )
        block_max = tl.max(X_block)
        m_new = tl.maximum(m, block_max)
        # rescale the old sum to the new max before adding this block's contribution
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
        m = m_new

    tl.store(m_d_X_y_ptr, m)
    tl.store(m_d_X_y_ptr + m_d_X_y_stride, d)
    tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y)
@triton.jit
def cross_entropy_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    loss_ptr,
    loss_stride,
    m_d_X_y_ptr,
    m_d_X_y_stride,
    rank,
    world_size,
    n_cols,
    n_non_ignore,
    label_smoothing: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute the cross entropy loss and, in place of X, the gradient of the loss.

    The per-rank (m, d, X_y) triples produced by online_softmax_kernel (and
    all-gathered across ranks) are reduced here into the global softmax
    statistics; the loss per row is written to loss_ptr and the input row is
    overwritten with its gradient.

    Parameters:
        X_ptr: Pointer to input tensor; overwritten with the gradient.
        X_stride (int): The stride of the input tensor.
        Y_ptr: Pointer to target tensor.
        Y_stride (int): The stride of the target tensor.
        loss_ptr: Pointer to tensor to store the loss.
        loss_stride (int): The stride of the loss tensor.
        m_d_X_y_ptr: Pointer to the gathered m/d/X_y tensor.
        m_d_X_y_stride: The stride of m/d/X_y tensor.
        rank (int): The rank of this device in the TP group.
        world_size (int): The size of world involved in this distributed loss calculation.
        n_cols (int): The number of columns (vocab shard size) in the input tensor.
        n_non_ignore (int): The number of non-ignored elements in the batch.
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        BLOCK_SIZE (int): The block size for Triton operations.
    """
    # int64 program id so row * stride cannot overflow for large vocab tensors
    program_id = tl.program_id(0).to(tl.int64)

    # locate this row of the input
    X_ptr += program_id * X_stride

    # load the target class index for this row
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)

    loss_ptr += program_id * loss_stride

    # each row stores a triple (m, d, X_y)
    m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride

    # Reduce the m/d/X_y values across TP ranks. Rank i's triples start at
    # offset i * 3 * n_non_ignore in the gathered buffer.
    m = tl.load(m_d_X_y_ptr)
    d = tl.load(m_d_X_y_ptr + m_d_X_y_stride)
    ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride))

    for i in range(1, world_size):
        offset = i * 3 * n_non_ignore * m_d_X_y_stride
        access_ptr = m_d_X_y_ptr + offset
        m_new = tl.load(access_ptr)
        d_new = tl.load(access_ptr + m_d_X_y_stride)
        X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride))
        # merge two online-softmax partials by rescaling both sums to the joint max
        d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new))
        m = tl.maximum(m, m_new)
        # only the rank owning the target column stored a finite X_y; the rest hold -inf
        ori_X_y = tl.maximum(ori_X_y, X_y_new)

    # Label smoothing is a general case of normal cross entropy
    scaled_x_sum = 0.0
    eps = label_smoothing / (n_cols * world_size)

    # Online softmax second pass: calculate the gradients
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # N is the number of non ignored elements in the batch
    # For label smoothing:
    # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
    #      = dx_i - (1 - label_smoothing) / N
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
        grad_dtype = X_block.dtype
        X_block = X_block.to(tl.float32)
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
        X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
        # overwrite the input row with the gradient, cast back to the input dtype
        tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)

    # We need tl.debug_barrier() to ensure the new result of X_ptr is written
    tl.debug_barrier()

    # Calculate the loss
    # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
    #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
    loss = -(ori_X_y - m - tl.log(d))

    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
    if label_smoothing > 0:
        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
        loss = loss * (1 - label_smoothing) + smooth_loss

    # Specially handle the i==y case where dx_y = softmax(x_y) - (1 - label_smoothing) / N
    vocab_start_idx = rank * n_cols
    vocab_end_idx = (rank + 1) * n_cols
    if y >= vocab_start_idx:
        if y < vocab_end_idx:
            X_y = tl.load(X_ptr + y - vocab_start_idx)
            X_y += -(1 - label_smoothing) / (n_non_ignore)
            tl.store(X_ptr + y - vocab_start_idx, X_y)

    tl.store(loss_ptr, loss)
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
# NOTE(review): 65536 // 2 presumably balances occupancy vs. per-program resources — confirm tuning
MAX_FUSED_SIZE = 65536 // 2
@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
    The multiplication is performed in-place on the tensor pointed by X_ptr.

    Note that a single scalar is loaded from grad_output_ptr and applied to
    every row — the gradient output is expected to be a scalar.

    Parameters:
        X_ptr: Pointer to the input tensor.
        X_stride (int): The stride of the input tensor.
        grad_output_ptr: Pointer to the gradient output value.
        n_cols (int): The number of columns in the input tensor.
        BLOCK_SIZE (int): The block size for Triton operations.
    """
    # Get the program ID and convert it to int64 to avoid overflow
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start index of this row
    X_ptr += program_id * X_stride

    # Load the gradient output value (one scalar shared by all rows)
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication in place
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
def cross_entropy_forward(
    _input: torch.Tensor,
    target: torch.Tensor,
    label_smoothing: float,
    reduce_loss: bool,
    dist_process_group: Union[dist.ProcessGroup, None],
):
    """Forward implementation of the (optionally distributed) cross entropy loss.

    Parameters:
        _input: Logits of shape (B, SQ, V). With a distributed process group,
            V is this rank's shard of the vocab dimension. Overwritten in
            place with the gradient of the loss w.r.t. the logits.
        target: Target class indices with exactly B * SQ elements.
        label_smoothing: Amount of label smoothing; 0.0 disables it.
        reduce_loss: If True, return the mean loss over all B * SQ tokens;
            otherwise return a (B, SQ) tensor of per-token losses.
        dist_process_group: Process group the vocab dimension is sharded
            across, or None for single-device execution.

    Returns:
        tuple: (loss, _input) where _input now holds the input gradient buffer.
    """
    B, SQ, V = _input.shape
    n_rows = B * SQ

    # Every token needs exactly one target ID.
    assert target.numel() == n_rows, "Each token needs a target token ID."

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    # unreduced per-token loss
    loss_1d = torch.zeros(n_rows, dtype=torch.float32, device=_input.device)

    # buffer for this rank's (m, d, X_y) triple per token (online-softmax partials)
    m_d_X_y = torch.zeros(n_rows * 3, dtype=torch.float32, device=_input.device)

    # The kernels index rows via stride(-2), so the last dim must be contiguous.
    if _input.stride(-1) != 1:
        _input = _input.contiguous()
    if target.stride(-1) != 1:
        target = target.contiguous()

    rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group)

    # Pass 1: per-rank online-softmax statistics.
    online_softmax_kernel[(n_rows,)](
        X_ptr=_input,
        X_stride=_input.stride(-2),
        Y_ptr=target,
        Y_stride=target.stride(-1),  # always 1
        m_d_X_y_ptr=m_d_X_y,
        m_d_X_y_stride=m_d_X_y.stride(-1),
        rank=rank,
        n_cols=V,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,
    )

    world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)
    if world_size > 1:
        # Gather every rank's partials so each rank can reduce them locally.
        m_d_X_y_gathered = torch.zeros(
            n_rows * 3 * world_size, dtype=torch.float32, device=_input.device
        )
        dist.all_gather_into_tensor(m_d_X_y_gathered, m_d_X_y, group=dist_process_group)
    else:
        m_d_X_y_gathered = m_d_X_y

    # Pass 2: reduce partials, compute loss, overwrite _input with gradients.
    cross_entropy_kernel[(n_rows,)](
        X_ptr=_input,
        X_stride=_input.stride(-2),
        Y_ptr=target,
        Y_stride=target.stride(-1),
        loss_ptr=loss_1d,
        loss_stride=loss_1d.stride(-1),
        m_d_X_y_ptr=m_d_X_y_gathered,
        m_d_X_y_stride=m_d_X_y_gathered.stride(-1),
        rank=rank,
        world_size=world_size,
        n_cols=V,
        n_non_ignore=n_rows,
        label_smoothing=label_smoothing,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,
    )

    loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
    return loss, _input
def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
    """Backward implementation of cross entropy loss kernel.

    `_input` already holds the gradient of the (unscaled) loss, written in
    place by the forward kernels; this function only scales it by
    `grad_output` when that scaling is not a no-op.
    """
    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        pass
    else:
        B, SQ, V = _input.shape
        n_rows = B * SQ
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
        # NOTE(review): element_mul_kernel loads a single scalar from grad_output,
        # so this path assumes a reduced (scalar) loss gradient — confirm an
        # unreduced (B, SQ) grad_output never reaches this multiply.
        element_mul_kernel[(n_rows,)](
            _input,
            _input.stride(-2),
            grad_output,
            V,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32,
        )
    return _input
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment