cross_entropy_parallel.py 5.33 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
import torch
import torch.nn as nn

import xentropy_cuda_lib

from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility

# Recent PyTorch releases expose `all_gather_into_tensor` and `reduce_scatter_tensor`
# as the new names for `_all_gather_base` and `_reduce_scatter_base`. For backward
# compatibility with an older PyTorch that lacks the new names, alias them to the
# old functions so the rest of this file can always use the new spelling.
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if not hasattr(torch.distributed, "reduce_scatter_tensor"):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
    """Per-sample softmax cross-entropy for logits sharded along the vocab dimension.

    Each tensor-model-parallel rank holds a (batch, vocab_size / world_size) slice of
    the logits. Every rank computes a local loss and a local log-sum-exp (LSE) with
    `xentropy_cuda_lib`; the global loss is then recovered by exchanging only the
    LSEs (all-gather) and the local losses (all-reduce).
    """

    @staticmethod
    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
                inplace_backward=False):
        """
        logits_parallel: (batch, vocab_size / world_size)
        labels: (batch,), global vocab ids; entries equal to `ignored_index` get loss 0.
        smoothing: label smoothing, only 0.0 is supported.
        ignored_index: label value marking samples that do not contribute to the loss.
        inplace_backward: forwarded unchanged to `xentropy_cuda_lib.backward`.

        Returns the per-sample losses, shape (batch,).
        """
        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
        batch, partition_vocab_size = logits_parallel.shape
        assert labels.shape == (batch,)
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
            partition_vocab_size, rank, world_size
        )

        # True where the label is not owned by this rank's vocab partition.
        labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
        ignored_mask = labels == ignored_index
        # Shift labels into this partition's local index space, leaving ignored entries as-is.
        labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
        masked_labels = labels_local.clone()
        masked_labels[labels_mask] = ignored_index

        losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
        assert lse_local.shape == (batch,)
        assert losses.shape == (batch,)
        # Zero the local loss for samples that are ignored or owned by another rank.
        # The boolean masks are used rather than comparing `masked_labels` with
        # `ignored_index`: a shifted local label may coincide with the sentinel value
        # (e.g. when `ignored_index` is non-negative) even though the sample is valid here.
        losses.masked_fill_(labels_mask | ignored_mask, 0)

        if world_size > 1:
            group = get_tensor_model_parallel_group()
            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                                     group=group)
            lse = torch.logsumexp(lse_allgather, dim=0)
            torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
                                         group=group)
            # The losses are currently lse_local - predicted_logit, we just have to subtract the
            # lse_local and add the lse (global).
            # Ignored labels are replaced by 0 before the division so that they always
            # map to a valid rank index (their loss is zeroed below anyway) and the
            # operands of `//` are never negative.
            rank_per_sample = labels.masked_fill(ignored_mask, 0) // partition_vocab_size
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]
            losses += lse - lse_local
            losses.masked_fill_(ignored_mask, 0)
        else:
            lse = lse_local

        # `ignored_mask` is saved explicitly: `labels_local` alone cannot tell an ignored
        # sample apart from a valid label whose shifted value happens to equal `ignored_index`.
        ctx.save_for_backward(logits_parallel, lse, labels_local, ignored_mask)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        """Return the gradient w.r.t. `logits_parallel`; the other forward inputs get None.

        grad_loss: (batch,) gradient of the objective w.r.t. the per-sample losses.
        """
        logits_parallel, lse, labels, ignored_mask = ctx.saved_tensors
        # Out-of-place on purpose: the incoming gradient may be shared with other
        # consumers and must not be modified in place.
        grad_loss = grad_loss.masked_fill(ignored_mask, 0).contiguous()
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        # One entry per forward input: logits_parallel, labels, smoothing, ignored_index,
        # inplace_backward.
        return grad_logits, None, None, None, None


class CrossEntropyLossParallel(nn.Module):
    """Module wrapper around `SoftmaxCrossEntropyLossParallelFn`.

    ignore_index: target value whose samples are excluded from the 'mean' denominator
        and passed to the underlying function as the ignored index.
    reduction: 'mean' (sum of losses divided by the number of non-ignored targets)
        or 'none' (per-sample losses). Anything else raises NotImplementedError.
    label_smoothing: forwarded to the underlying function.
    inplace_backward: forwarded to the underlying function.
    """

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        supported_reductions = ('mean', 'none')
        if reduction not in supported_reductions:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        """Compute the loss for CUDA tensors `input` (logits) and `target` (labels)."""
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        per_sample_loss = SoftmaxCrossEntropyLossParallelFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
        )
        if self.reduction != 'mean':
            return per_sample_loss
        total_loss = per_sample_loss.sum()
        num_counted = (target != self.ignore_index).sum()
        return total_loss / num_counted