# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Cross Entropy Loss API"""

from typing import Optional
import warnings

import torch

import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy

__all__ = [
    "parallel_cross_entropy",
]


class CrossEntropyFunction(torch.autograd.Function):
    """
    Custom autograd function computing the Cross Entropy loss.

    The input tensor may be BF16 or FP32; the loss and gradient arithmetic is
    performed in FP32 only. The returned loss is always FP32, and the input
    gradients are upcast back to the dtype of the input.
    """

    @staticmethod
    def forward(
        ctx,
        inp,
        target,
        label_smoothing=0.0,
        reduce_loss=False,
        dist_process_group=None,
        ignore_idx=-100,
        is_cg_capturable=False,
    ):
        """
        Forward pass of the Cross Entropy loss.

        If ``dist_process_group`` is passed for distributed loss calculation, the
        input to each distributed rank should be ``(*, V/world_size)``. Note that
        each of the ranks should get equal shards along the V dimension.

        Parameters:
        ctx : The context object.
        inp (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size.
        target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1].
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
        dist_process_group (torch.dist.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device.
        ignore_idx (int): The index for which loss and gradients are made to zero.
        is_cg_capturable (bool): Whether the operation is CUDA graph capturable.

        Returns:
        tensor: The computed loss.
        """
        # The Triton kernel returns the loss together with a second tensor that
        # carries everything the backward pass needs; stash the latter.
        loss, grad_source = triton_cross_entropy.cross_entropy_forward(
            inp,
            target,
            label_smoothing,
            reduce_loss,
            dist_process_group,
            ignore_idx,
        )

        ctx.save_for_backward(grad_source.detach())
        ctx.is_cg_capturable = is_cg_capturable
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of the Cross Entropy loss.

        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.

        Returns:
        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        (saved,) = ctx.saved_tensors
        grad_inp = triton_cross_entropy.cross_entropy_backward(
            saved, grad_output, ctx.is_cg_capturable
        )
        # One slot per non-ctx forward argument; only `inp` receives a gradient.
        return (grad_inp, None, None, None, None, None, None)


def parallel_cross_entropy(
    inp: torch.Tensor,
    target: torch.Tensor,
    label_smoothing: float = 0.0,
    reduce_loss: bool = False,
    dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
    ignore_idx: int = -100,
    is_cg_capturable: bool = False,
    *,
    _input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Cross Entropy loss with optional distributed reduction.

    The input may be BF16/FP32; loss and gradient computation happens in FP32
    only. The returned loss is always FP32, and input gradients are upcast to
    the input's datatype.

    When ``dist_process_group`` is supplied for distributed loss calculation,
    every rank must receive an equally sized ``(*, V/world_size)`` shard of the
    vocabulary dimension.

    Parameters
    ----------
    inp : torch.Tensor
        Input of shape ``(B, SQ, V)`` or ``(SQ, B, V)`` — batch size B,
        sequence length SQ, vocab size V.
    target : torch.Tensor
        Targets of shape ``(B, SQ)`` or ``(SQ, B)``, each value in ``[0, V-1]``.
    label_smoothing : float, default = 0.0
        Amount of smoothing applied to the loss; 0.0 disables smoothing.
    reduce_loss : bool, default = False
        When True, the loss is averaged over the B*SQ dimension.
    dist_process_group : torch.distributed.ProcessGroup, default = None
        Process group the loss computation is sharded across; None on 1 device.
    ignore_idx : int, default = -100
        Target index whose loss and gradients are forced to zero.
    is_cg_capturable : bool, default = False
        Whether the operation is CUDA graph capturable.

    Returns
    -------
    torch.Tensor
        The computed loss.
    """
    if _input is not None:
        # Deprecated alias for `inp`: honor it, but push callers to migrate.
        warnings.warn(
            "The '_input' parameter is deprecated. Please use 'inp' instead.",
            FutureWarning,
        )
        inp = _input

    return CrossEntropyFunction.apply(
        inp,
        target,
        label_smoothing,
        reduce_loss,
        dist_process_group,
        ignore_idx,
        is_cg_capturable,
    )