# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply
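# A minimal usage sketch (illustrative shapes): index_first_axis(x, idx) selects rows along the
# first axis with autograd support, e.g.
#   x = torch.randn(6, 3, 4, requires_grad=True)
#   idx = torch.tensor([0, 2, 5])
#   out = index_first_axis(x, idx)          # shape (3, 3, 4)
#   out.sum().backward()                    # gradient is scattered back only to rows 0, 2, 5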


class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply
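# A minimal usage sketch (illustrative shapes): index_put_first_axis is the inverse operation --
# it scatters rows back into a zero tensor of the given first-axis size, e.g.
#   values = torch.randn(3, 4)
#   idx = torch.tensor([0, 2, 5])
#   out = index_put_first_axis(values, idx, 6)   # shape (6, 4); rows 1, 3, 4 stay zero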


class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
        # memory format to channel_first. In other words, input might not be contiguous.
        # If we don't detach, PyTorch complains that output is a view that is being modified in place.
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        grad_input = grad_residual
        # grad_input[indices] += grad_output
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply
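# A minimal usage sketch (illustrative shapes): index_first_axis_residual returns both the selected
# rows and the detached full input, so a residual copy is kept without a second gather, e.g.
#   x = torch.randn(6, 4, requires_grad=True)
#   idx = torch.tensor([0, 2, 5])
#   out, residual = index_first_axis_residual(x, idx)   # out: (3, 4), residual: (6, 4)
# In backward, the gradient w.r.t. `out` is scatter-added onto the gradient flowing through
# `residual`.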


def unpad_input(hidden_states, attention_mask):
    """
    Arguments:
101
        hidden_states: (batch, seqlen, ...)
Tri Dao's avatar
Tri Dao committed
102
103
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
104
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
Tri Dao's avatar
Tri Dao committed
105
106
107
108
109
110
111
112
113
114
115
116
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because PyTorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
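
# A worked example of unpad_input's outputs (illustrative numbers): for
#   attention_mask = [[1, 1, 1, 0],
#                     [1, 1, 0, 0]]
# the selected token count is total_nnz = 5, indices = [0, 1, 2, 4, 5] (positions in the flattened
# (batch * seqlen) axis), cu_seqlens = [0, 3, 5], and max_seqlen_in_batch = 3.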


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
128
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
Tri Dao's avatar
Tri Dao committed
129
130
        indices: (total_nnz)
    Return:
131
        hidden_states: (batch, seqlen, ...)
Tri Dao's avatar
Tri Dao committed
132
133
134
135
136
    """
    dim = hidden_states.shape[-1]
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
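

# A minimal round-trip smoke test (an illustrative sketch; the shapes and mask below are arbitrary
# examples).
if __name__ == "__main__":
    batch, seqlen, dim = 2, 4, 8
    hidden = torch.randn(batch, seqlen, dim)
    # Sequence lengths 3 and 2: zeros mark the padding positions at the end of each sequence.
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32)
    unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden, attention_mask)
    assert unpadded.shape == (5, dim)
    assert cu_seqlens.tolist() == [0, 3, 5]
    assert max_seqlen == 3
    repadded = pad_input(unpadded, indices, batch, seqlen)
    # The round trip restores the valid tokens and leaves the padded positions as zeros.
    assert torch.equal(repadded[0, :3], hidden[0, :3])
    assert torch.equal(repadded[1, 2:], torch.zeros(2, dim))
    print("unpad_input / pad_input round trip OK")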