# cross_entropy.py
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# but made much faster: we compute the local loss and the local LSE (log-sum-exp), and by exchanging
# the LSEs and the losses once we get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.).
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn
import xentropy_cuda_lib

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the public names that replaced the
# private `_all_gather_base` and `_reduce_scatter_base` in recent PyTorch releases. Only the
# all-gather is used below, so only it is aliased for backward compatibility with older PyTorch.
# The private fallback is checked as well, so builds whose `torch.distributed` lacks it do not
# fail with an AttributeError at import time.
if not hasattr(torch.distributed, "all_gather_into_tensor") and hasattr(
    torch.distributed, "_all_gather_base"
):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    """Softmax cross entropy (optionally label-smoothed, optionally tensor-parallel) backed by the
    fused kernels in ``xentropy_cuda_lib``.

    Produces one loss per sample. Samples whose label equals ``ignored_index`` get a loss of 0
    and receive a zero upstream gradient in backward.
    """

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        smoothing: label-smoothing factor forwarded to the CUDA kernels.
        ignored_index: label value whose loss is forced to 0 (and whose gradient is zeroed).
        inplace_backward: passed through to ``xentropy_cuda_lib.backward``; presumably lets the
            kernel reuse the logits buffer for the gradient — TODO confirm against the extension.
        process_group: If not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab (rank r owns global ids [r * vocab_size, (r + 1) * vocab_size)),
            labels are global ids, and the loss is aggregated across processes.
        Returns: per-sample losses of shape (batch,).
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            vocab_start_index = rank * vocab_size

            ignored_mask = labels == ignored_index
            # Shift labels into this partition's local index space. Ignored labels keep their
            # sentinel value so backward can still recognize them via `ctx.ignored_index`.
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(
                logits, labels_local, smoothing, world_size * vocab_size
            )
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).

            # Gather every rank's local LSE: lse_allgather[r, i] is rank r's LSE for sample i.
            lse_allgather = torch.empty(
                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
            )
            torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(), group=process_group
            )
            # Sum the partial losses across ranks; overlapped with the LSE math below.
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).

            # Rank owning each sample's label. Ignored labels (e.g. -100) would otherwise map to a
            # negative, possibly out-of-range rank; route them to rank 0 instead. Their losses are
            # zeroed at the end, so only in-bounds indexing matters here.
            rank_per_sample = torch.div(
                labels.masked_fill(ignored_mask, 0), vocab_size, rounding_mode="floor"
            )
            lse_local = lse_allgather[
                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
            ]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
                    lse - lse_allgather.sum(dim=0)
                )
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        """Compute d(loss)/d(logits) with the fused kernel; other inputs get no gradient.

        grad_loss: upstream gradient, one value per sample.
        """
        logits, lse, labels = ctx.saved_tensors
        # Out-of-place masked_fill: `grad_loss` may alias a tensor the caller still owns (e.g. a
        # user-supplied gradient when reduction="none"); `.contiguous()` alone would not copy an
        # already-contiguous tensor, so an in-place fill could corrupt it.
        grad_loss = grad_loss.masked_fill(labels == ctx.ignored_index, 0).contiguous()
        grad_logits = xentropy_cuda_lib.backward(
            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
        )
        # Exactly one entry per forward input: (logits, labels, smoothing, ignored_index,
        # inplace_backward, process_group); only logits is differentiable.
        return grad_logits, None, None, None, None, None


class CrossEntropyLoss(nn.Module):
    """``nn.Module`` wrapper around ``SoftmaxCrossEntropyLossFn``.

    Supports an ignore index, label smoothing, in-place backward and tensor-parallel vocab
    sharding via ``process_group``. Only the "mean" and "none" reductions are implemented.
    """

    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        inplace_backward=False,
        process_group=None,
    ):
        super().__init__()
        supported_reductions = ("mean", "none")
        if reduction not in supported_reductions:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward
        self.process_group = process_group

    def forward(self, input, target):
        """
        input: logits, (batch, vocab_size); the local vocab shard when process_group is set.
        target: labels, (batch,); global class ids when process_group is set.
        Returns the sum of per-sample losses divided by the number of non-ignored targets when
        reduction == "mean", otherwise the per-sample losses unchanged.
        """
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        per_sample = SoftmaxCrossEntropyLossFn.apply(
            input,
            target,
            self.label_smoothing,
            self.ignore_index,
            self.inplace_backward,
            self.process_group,
        )
        if self.reduction != "mean":
            return per_sample
        num_counted = (target != self.ignore_index).sum()
        return per_sample.sum() / num_counted