cross_entropy_parallel.py
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
import torch
import torch.nn as nn

import xentropy_cuda_lib

from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
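

# Illustrative single-process sketch (not used by the module below) of the identity this
# file exploits: with the vocabulary split into contiguous shards, the global LSE is the
# logsumexp of the per-shard LSEs, so the cross-entropy can be rebuilt from per-shard
# losses and LSEs alone. The function name and shapes here are only for illustration.
def _reference_loss_from_shards(logits, labels, world_size):
    """Compute cross-entropy using only per-shard losses and LSEs (reference sketch)."""
    batch, vocab_size = logits.shape
    shard = vocab_size // world_size
    # (world_size, batch): what the all_gather_into_tensor of the local LSEs produces below.
    lse_per_shard = torch.stack(
        [torch.logsumexp(logits[:, r * shard:(r + 1) * shard], dim=-1)
         for r in range(world_size)]
    )
    lse_global = torch.logsumexp(lse_per_shard, dim=0)
    owner = torch.div(labels, shard, rounding_mode='floor')
    arange = torch.arange(batch, device=logits.device)
    lse_owner = lse_per_shard[owner, arange]
    # Loss as computed on the owning shard: its local LSE minus the predicted logit.
    local_loss = lse_owner - logits[arange, labels]
    # Swapping the shard LSE for the global LSE gives the true cross-entropy.
    return local_loss + lse_global - lse_owner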


class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
                inplace_backward=False):
        """
        logits_parallel: (batch, vocab_size / world_size)
        labels: (batch,)
        """
        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
        batch, partition_vocab_size = logits_parallel.shape
        assert labels.shape == (batch,)
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
                partition_vocab_size, get_tensor_model_parallel_rank(),
                get_tensor_model_parallel_world_size()
            )
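            # These bound this rank's contiguous slice of the vocabulary:
            # [rank * partition_vocab_size, (rank + 1) * partition_vocab_size).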

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
            masked_labels = labels_local.clone()
            masked_labels[labels_mask] = ignored_index

            losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(masked_labels == ignored_index, 0)

            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
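            # Launch both collectives asynchronously so the LSE all-gather and the loss
            # all-reduce can overlap. After masking, only the rank that owns a sample's
            # label contributes a nonzero loss, so the SUM reduction recovers that loss.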
            handle_lse = torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(),
                group=get_tensor_model_parallel_group(), async_op=True
            )
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM,
                group=get_tensor_model_parallel_group(), async_op=True
            )
            handle_lse.wait()
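            # Since each rank holds a disjoint slice of the vocabulary, the logsumexp of
            # the per-rank LSEs equals the LSE over the full vocabulary.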
            lse = torch.logsumexp(lse_allgather, dim=0)
            # Each loss is lse_local (on the rank owning the label) minus the predicted logit;
            # we just have to subtract that lse_local and add the global lse.
            rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor')
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
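            # At this point losses[i] = lse_local[i] - logit_i[label_i] (from the owner rank);
            # adding lse[i] - lse_local[i] yields lse[i] - logit_i[label_i], i.e. the
            # cross-entropy with the softmax taken over the full vocabulary.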
            losses += lse - lse_local
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits_parallel, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits_parallel, lse, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None, None


class CrossEntropyLossParallel(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossParallelFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
        )
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
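

# Example usage (a sketch, assuming an apex tensor-model-parallel group has already been
# initialized, e.g. via apex.transformer.parallel_state.initialize_model_parallel, and that
# each rank holds its contiguous vocab shard of the logits):
#
#     # logits_parallel: (batch, vocab_size // world_size), labels: (batch,)
#     loss_fn = CrossEntropyLossParallel(reduction='mean', inplace_backward=True)
#     loss = loss_fn(logits_parallel, labels)
#     loss.backward()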