# fused_cross_entropy.py
# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn
import xentropy_cuda_lib

# `all_gather_into_tensor` is the new name for `_all_gather_base` and requires a recent
# version of PyTorch. The following lines alias the old private function so the code
# keeps working with older PyTorch releases.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size

            # Mask labels that fall outside this rank's vocab partition
            # (True means the label belongs to another rank's partition and must be masked).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(
                logits, labels_local, smoothing, world_size * vocab_size
            )
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # With no smoothing, losses contains lse_local - predicted logit if the label is in
            # this partition's vocab, and 0 otherwise.
            # With smoothing=0.1, for labels in this partition's vocab, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes);
            # for labels not in this partition's vocab, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).
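            # Equivalently (restating the above), writing eps for smoothing, z for this rank's
            # logits row, z_y for the predicted logit, and C for total_classes, each rank contributes
            #   (1 - eps) * (lse_local - z_y) * [label in partition] + eps * (lse_local - z.sum() / C),
            # so after summing over ranks only the lse terms need the correction applied below.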

            lse_allgather = torch.empty(
                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
            )
            torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(), group=process_group
            )
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
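            # The global lse is exact here: logsumexp over the full vocab decomposes over the
            # vocab partitions, i.e. lse = log(sum_v exp(logit_v)) = logsumexp_r(lse_local_r).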
            # With no smoothing, the summed losses are lse_local - predicted_logit, so we just
            # subtract lse_local and add the global lse.
            # With smoothing=0.1, the summed losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes),
            # whereas we want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
            lse_local = lse_allgather[
                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
            ]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
                    lse - lse_allgather.sum(dim=0)
                )
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(
            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
        )
        return grad_logits, None, None, None, None, None, None


class FusedCrossEntropyLoss(nn.Module):
    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        inplace_backward=True,
        process_group=None,
    ):
        super().__init__()
        if reduction not in ["mean", "none"]:
            raise NotImplementedError("Only reduction='mean' and reduction='none' are supported")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward
        self.process_group = process_group

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        if len(input.shape) == 3:
            input = input.view(-1, input.size(-1))
            target = target.view(-1)
        loss = SoftmaxCrossEntropyLossFn.apply(
            input,
            target,
            self.label_smoothing,
            self.ignore_index,
            self.inplace_backward,
            self.process_group,
        )
        if self.reduction == "mean":
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
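

# --- Minimal usage sketch (not part of the original file) -----------------------------------
# This assumes a CUDA device and the `xentropy_cuda_lib` extension (apex's xentropy kernel)
# are available; it compares the fused loss against torch.nn.CrossEntropyLoss in the
# single-process (no tensor-parallel) case.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, vocab_size = 4, 32
    logits = torch.randn(batch, vocab_size, device="cuda", dtype=torch.float16, requires_grad=True)
    labels = torch.randint(0, vocab_size, (batch,), device="cuda")
    labels[0] = -100  # exercise the ignore_index path

    fused = FusedCrossEntropyLoss(label_smoothing=0.1)
    reference = nn.CrossEntropyLoss(label_smoothing=0.1)

    loss_fused = fused(logits, labels)
    loss_ref = reference(logits.float(), labels)
    print(f"fused: {loss_fused.item():.6f}  reference: {loss_ref.item():.6f}")

    # inplace_backward=True (the default here) lets the kernel reuse the logits buffer
    # for the gradient, saving memory.
    loss_fused.backward()
    print("grad dtype / shape:", logits.grad.dtype, tuple(logits.grad.shape))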