sparse_self_attention.py
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""

import torch.nn as nn
from torch.nn.functional import *
import torch
from torch import distributed as dist
from collections import namedtuple
from deepspeed.ops.sparse_attention import SparsityConfig


class SparseSelfAttention(nn.Module):
    """Implements an efficient Sparse Self Attention of Transformer layer based on `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509

    For more information please see, TODO DeepSpeed Sparse Transformer.

    For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
    """
    def __init__(
            self,
            # SparsityConfig parameters need to be set accordingly
            sparsity_config=SparsityConfig(num_heads=4),
            key_padding_mask_mode='add',
            attn_mask_mode='mul',
            max_seq_length=2048):
        """Initialize the sparse self attention layer.
        Arguments:
            sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class.
            key_padding_mask_mode: optional: a string determining whether the key padding mask is added, `add`, or multiplied, `mul`.
            attn_mask_mode: optional: a string determining whether the attention mask is added, `add`, or multiplied, `mul`.
            max_seq_length: optional: the maximum sequence length this sparse attention module will be applied to; it controls the size of the master_layout.
        """
        super().__init__()

        # sparsity information
        self.sparsity_config = sparsity_config

        # initialize sparse layout and register as buffer
        master_layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("master_layout", master_layout)
        self._need_layout_synchronization = True

        # mask modes
        self.key_padding_mask_mode = key_padding_mask_mode
        self.attn_mask_mode = attn_mask_mode

    # class-level cache of block-sparse matmul/softmax kernels, keyed by sequence
    # length and shared across all SparseSelfAttention instances
    ops = dict()

    def get_layout(self, L):
        # if the layout has not yet been synchronized across GPUs, broadcast it from global rank 0
        if self._need_layout_synchronization and dist.is_initialized():
            dist.broadcast(self.master_layout, src=0)
            self._need_layout_synchronization = False

        if (L % self.sparsity_config.block != 0):
            raise ValueError(
                f'Sequence length, {L}, needs to be divisible by block size {self.sparsity_config.block}!'
            )

        num_blocks = L // self.sparsity_config.block
        return self.master_layout[..., :num_blocks, :num_blocks].cpu()  # layout needs to be a CPU tensor
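
    # Sizing note (editorial illustration, assuming the default block size of 16):
    # with max_seq_length=2048 the master_layout covers 2048 / 16 = 128 block rows
    # and columns, and get_layout(512) returns its top-left 512 / 16 = 32 x 32 slice.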

    # build the block-sparse kernels for sequence length L, caching them in SparseSelfAttention.ops
    def get_ops(self, H, L):
        from deepspeed.ops.sparse_attention.matmul import MatMul
        from deepspeed.ops.sparse_attention.softmax import Softmax
        if L not in SparseSelfAttention.ops:
            sparsity_layout = self.get_layout(L)
            sparse_dot_sdd_nt = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'sdd',
                                       trans_a=False,
                                       trans_b=True)

            sparse_dot_dsd_nn = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'dsd',
                                       trans_a=False,
                                       trans_b=False)

            sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)

            SparseSelfAttention.ops[L] = (sparse_dot_sdd_nt,
                                          sparse_dot_dsd_nn,
                                          sparse_softmax)
        return SparseSelfAttention.ops[L]

    def transpose_key_for_scores(self, x, L):
        # if the caller has already transposed the key to (bsz, heads, head_dim, L),
        # permute it back to (bsz, heads, L, head_dim); the sparse matmul performs
        # its own transpose via trans_b=True
        bsz, num_heads, seq_len, head_dim = x.size()
        if seq_len != L:
            return x.permute(0, 1, 3, 2)
        return x

    def transpose_mask_for_sparse(self, qtype, x, is_key_padding_mask=False):
        # cast the mask to the query dtype and squeeze away extra singleton dimensions,
        # leaving the key padding mask as (bsz, seq_len) and the attention mask as 2D
        x = x.type(qtype)
        if is_key_padding_mask:
            xdim = x.dim()
            for d in range(xdim - 1, 0, -1):
                x = x.squeeze(dim=d)
            return x
        return x.squeeze()

    # forward pass
    def forward(self,
                query,
                key,
                value,
                rpe=None,
                key_padding_mask=None,
                attn_mask=None):
        """Applies forward phase of sparse self attention

        Arguments:
            query: required: query tensor
            key: required: key tensor
            value: required: value tensor
            rpe: optional: a tensor same dimension as x that is used as relative position embedding
            key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
            attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
            key_padding_mask_mode: optional: a boolean determining if key_padding_mask needs to be added or multiplied
            attn_mask_mode: optional: a boolean determining if attn_mask needs to be added or multiplied

        Return:
aiss's avatar
aiss committed
125
             attn_output: a dense tensor containing attention context
126
        """
aiss's avatar
aiss committed
127
        assert query.dtype == torch.half, "sparse attention only supports training in fp16 currently, please file a github issue if you need fp32 support"
        bsz, num_heads, tgt_len, head_dim = query.size()

        # transpose back key if it is already transposed
        key = self.transpose_key_for_scores(key, tgt_len)

        # check that operation is supported
        if query.shape != key.shape or key.shape != value.shape:
            raise NotImplementedError('only self-attention is supported for now')

        # squeeze key_padding_mask if it is given
        if key_padding_mask is not None:
            key_padding_mask = self.transpose_mask_for_sparse(query.dtype,
                                                              key_padding_mask,
                                                              is_key_padding_mask=True)

        # squeeze attn_mask if it is given
        if attn_mask is not None:
            attn_mask = self.transpose_mask_for_sparse(query.dtype, attn_mask)

        # fetch (or build and cache) the block-sparse kernels for this sequence length
        sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops(num_heads, tgt_len)

        scaling = float(head_dim)**-0.5

        # attention scores: block-sparse softmax of (Q @ K^T) * head_dim**-0.5 (with masks/rpe applied),
        # computed only on the blocks selected by the sparsity layout
        attn_output_weights = sparse_dot_sdd_nt(query, key)
        attn_output_weights = sparse_softmax(
            attn_output_weights,
            scale=scaling,
            rpe=rpe,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            key_padding_mask_mode=self.key_padding_mask_mode,
            attn_mask_mode=self.attn_mask_mode)

        # outputs
        attn_output = sparse_dot_dsd_nn(attn_output_weights, value)
        return attn_output
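

# ---------------------------------------------------------------------------
# Usage sketch (editorial illustration, not part of the original module). It
# assumes a CUDA GPU with fp16 support, the Triton-backed sparse kernels, and
# that FixedSparsityConfig is importable from deepspeed.ops.sparse_attention;
# the config, shapes, and block size below are example values only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from deepspeed.ops.sparse_attention import FixedSparsityConfig

    bsz, num_heads, seq_len, head_dim = 2, 4, 256, 64

    # the sequence length must be divisible by the block size (16 here)
    config = FixedSparsityConfig(num_heads=num_heads, block=16)
    attn = SparseSelfAttention(sparsity_config=config, max_seq_length=1024).cuda()

    # sparse attention currently supports fp16 only (see the assert in forward)
    q = torch.rand(bsz, num_heads, seq_len, head_dim, device='cuda', dtype=torch.half)
    k = torch.rand(bsz, num_heads, seq_len, head_dim, device='cuda', dtype=torch.half)
    v = torch.rand(bsz, num_heads, seq_len, head_dim, device='cuda', dtype=torch.half)

    out = attn(q, k, v)  # dense output of shape (bsz, num_heads, seq_len, head_dim)
    print(out.shape)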