# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import unicore_fused_softmax_dropout
import torch.nn.functional as F

class SoftmaxDropoutFast(torch.autograd.Function):
    """Fused (input + mask + bias) -> softmax -> dropout via the
    ``unicore_fused_softmax_dropout`` CUDA extension.

    The backward pass also reduces a gradient for the optional additive
    bias when that bias requires grad.
    """

    @staticmethod
    def forward(ctx, is_training, inputs, mask, bias, dropout_prob):
        # The extension returns (dropped output, dropout keep-mask,
        # softmax probabilities); the last two are only needed for backward.
        outputs = unicore_fused_softmax_dropout.forward(
            is_training, inputs, mask, bias, dropout_prob, None
        )
        dropped, keep_mask, probs = outputs
        if is_training:
            ctx.dropout_prob = dropout_prob
            ctx.save_for_backward(probs, keep_mask)
            # Only produce a bias gradient when the bias actually needs one.
            ctx.has_bias = bias is not None and bias.requires_grad
            if ctx.has_bias:
                ctx.bias_batch_dim = bias.shape[0]
        return dropped

    @staticmethod
    def backward(ctx, grad_output):
        probs, keep_mask = ctx.saved_tensors
        grad_input = unicore_fused_softmax_dropout.backward(
            grad_output.contiguous(), probs, keep_mask, ctx.dropout_prob
        )
        if ctx.has_bias:
            # The bias was broadcast over the flattened leading dimension;
            # fold its gradient back by summing over that dimension.
            grad_bias = grad_input.view(
                -1, ctx.bias_batch_dim, grad_input.shape[-2], grad_input.shape[-1]
            ).sum(dim=0)
        else:
            grad_bias = None
        # One gradient slot per forward argument:
        # (is_training, inputs, mask, bias, dropout_prob).
        return None, grad_input, None, grad_bias, None


def _check_mask(mask, input):
    assert mask.dtype == input.dtype, "mask and input must have the same dtype"
    assert len(mask.shape) == len(input.shape), "wrong length of mask.shape"
    assert (
        mask.shape[-3] == 1 or mask.shape[-3] == input.shape[-3]
    ), "mask.shape[-3] must be 1 or input.shape[-3]"
    if mask.shape[-3] == 1:
        assert mask.shape[-2] == 1, "when mask.shape[-3] == 1, mask.shape[-2] must be 1"
    else:
        assert (
            mask.shape[-2] == 1 or mask.shape[-2] == input.shape[-2]
        ), "mask.shape[-2] must be 1 or input.shape[-2]"


def _check_bias(bias, input):
    assert bias.dtype == input.dtype, "bias and input must have the same dtype"
    assert len(bias.shape) == len(input.shape), "wrong length of bias.shape"
    assert bias.shape[-1] == input.shape[-1], "bias.shape[-1] must be input.shape[-1]"
    assert bias.shape[-2] == input.shape[-2], "bias.shape[-2] must be input.shape[-2]"
    len_shape = len(input.shape)
    if len_shape > 3:
        # head dim should be the same
        assert (
            bias.shape[-3] == input.shape[-3]
        ), "bias.shape[-3] must be input.shape[-3]"
        offset = 3
    else:
        offset = 2
    prev_non_one = True
    for i in range(len_shape - offset - 1, -1, -1):
        if prev_non_one:
            assert (
                bias.shape[i] == input.shape[i] or bias.shape[i] == 1
            ), "bias.shape[{}] must be input.shape[{}] or 1".format(i, i)
        else:
            assert bias.shape[i] == 1, "bias.shape[{}] must be 1".format(i)
        prev_non_one = bias.shape[i] != 1


def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None):
    """Softmax followed by dropout; mask and bias are optional additive terms.

    Args:
        input (torch.Tensor): input tensor
        dropout_prob (float): dropout probability
        is_training (bool, optional): is in training or not. Defaults to True.
        mask (torch.Tensor, optional): the mask tensor, used as input + mask. Defaults to None.
        bias (torch.Tensor, optional): the bias tensor, used as input + bias. Defaults to None.

    Returns:
        torch.Tensor: the result after softmax and dropout
    """
    input = input.contiguous()
    # The fused CUDA kernel handles rows of at most 2048 elements.
    if input.is_cuda and input.shape[-1] <= 2048:
        input_size = input.size()
        # Flatten all leading dims to 3-D, as the kernel expects.
        if mask is not None:
            _check_mask(mask, input)
            mask = mask.contiguous().view(-1, mask.shape[-2], mask.shape[-1])
        if bias is not None:
            _check_bias(bias, input)
            bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
        input = input.view(-1, input_size[-2], input_size[-1])
        return SoftmaxDropoutFast.apply(
            is_training, input, mask, bias, dropout_prob
        ).view(*input_size)
    else:
        # Fallback path (CPU, or rows wider than the kernel supports).
        # Bug fix: the original used `input += mask` / `input += bias`.
        # Since .contiguous() returns the same tensor when it is already
        # contiguous, that mutated the CALLER's tensor in place, and it
        # raises a RuntimeError on leaf tensors that require grad.
        # Out-of-place addition preserves the same result without the
        # side effect.
        if mask is not None:
            input = input + mask
        if bias is not None:
            input = input + bias
        return F.dropout(F.softmax(input, dim=-1), p=dropout_prob, training=is_training)