cross_entropy.py 3.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Cross Entropy Loss API"""

import torch

import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy

__all__ = [
    "parallel_cross_entropy",
]


class CrossEntropyFunction(torch.autograd.Function):
    """
    This class implements a custom autograd function for the Cross Entropy loss. The input tensor
    can be in BF16/FP32, the loss and gradient calculation happens in FP32 only. The returned loss
    is always in FP32, the input gradients are upcasted to the datatype of the input.
    """

    @staticmethod
    def forward(
        ctx,
        _input,
        target,
        label_smoothing=0.0,
        reduce_loss=False,
        dist_process_group=None,
        ignore_idx=-100,
        is_cg_capturable=False,
    ):
        """
        The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed
        loss calculation, the input to each distributed rank should be (*,V/world_size). Note that
        each of the ranks should get equal shards along the V dimension.

        Parameters:
        ctx : The context object.
        _input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size.
        target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1].
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
        dist_process_group (torch.dist.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device.
        ignore_idx (int): The index for which loss and gradients are made to zero.
        is_cg_capturable (bool): If true, the backward kernel avoids operations that would break CUDA graph capture.

        Returns:
        tensor: The computed loss.
        """
        # The Triton kernel returns the loss and a tensor that replaces `_input`;
        # the replacement is saved for backward — presumably the precomputed
        # (unscaled) input gradient, reusing the input's storage. TODO confirm
        # against triton_cross_entropy.cross_entropy_forward.
        loss, _input = triton_cross_entropy.cross_entropy_forward(
            _input,
            target,
            label_smoothing,
            reduce_loss,
            dist_process_group,
            ignore_idx,
        )

        # Detach before saving: the saved tensor does not need to be tracked by
        # autograd itself, it is only consumed by our backward below.
        ctx.save_for_backward(_input.detach())
        ctx.is_cg_capturable = is_cg_capturable

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        The backward pass of the Cross Entropy loss.

        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.

        Returns:
        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        (_input,) = ctx.saved_tensors
        # Combine the saved tensor with grad_output to produce the input gradient.
        _input = triton_cross_entropy.cross_entropy_backward(
            _input, grad_output, ctx.is_cg_capturable
        )
        # One slot per forward argument after ctx; only `_input` gets a gradient.
        return (
            _input,
            None,  # target
            None,  # label_smoothing
            None,  # reduce_loss
            None,  # dist_process_group
            None,  # ignore_idx
            None,  # is_cg_capturable
        )


# Public functional entry point (see __all__): calling
# parallel_cross_entropy(_input, target, ...) dispatches through autograd to
# CrossEntropyFunction.forward/backward.
parallel_cross_entropy = CrossEntropyFunction.apply