# flash_blocksparse_attention.py (original author: Tri Dao)
import math

import hydra
import torch
import torch.nn as nn
from einops import rearrange

from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from flash_attn.flash_blocksparse_attn_interface import (
    convert_blockmask,
    flash_blocksparse_attn_func,
)

class FlashBlocksparseAttention(nn.Module):
    """Block-sparse scaled dot-product attention backed by the FlashAttention kernel.

    Arguments
    ---------
        sparsity_config: Hydra config that ``hydra.utils.instantiate`` turns into an
                         object exposing ``make_layout(seq_length)``; the returned
                         layout decides which blocks of the attention are computed.
        softmax_temp: Forwarded to the kernel as ``softmax_scale``. When None the
                      kernel chooses its own default (presumably 1/sqrt(head_dim) —
                      confirm against ``flash_blocksparse_attn_func``).
        attention_dropout: Dropout probability applied inside the kernel while the
                           module is in training mode (default: 0.0). In eval mode
                           0.0 is always used.
        max_seq_length: Longest sequence the layout is built for; rounded up to a
                        multiple of 256 (default: 2048).
        device, dtype: Accepted for factory-kwargs compatibility; unused here.
    """

    def __init__(
        self,
        sparsity_config,
        softmax_temp=None,
        attention_dropout=0.0,
        max_seq_length=2048,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.sparsity_config = hydra.utils.instantiate(sparsity_config)
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

        # Build the sparse layout once for the longest supported sequence. Register it,
        # and its pre-converted form, as buffers so both follow the module across devices.
        max_seq_length = self._round_up(max_seq_length)
        layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("layout", layout)
        blockmask_converted = convert_blockmask(self.layout, causal=False)
        self.register_buffer("blockmask_converted", blockmask_converted)

    @staticmethod
    def _round_up(n, multiple=256):
        """Round ``n`` up to the nearest multiple of ``multiple``."""
        return ((n + multiple - 1) // multiple) * multiple

    def _blockmask_for(self, seqlen):
        """Return the top-left slice of ``self.layout`` needed for ``seqlen`` tokens.

        ``seqlen`` is first rounded up to a multiple of 256. The slice then keeps
        ``rounded // 16`` rows and ``rounded // 256`` columns. Both counts must fit
        inside the stored layout; otherwise an AssertionError is raised.
        """
        seqlen_rounded = self._round_up(seqlen)
        n_rows = seqlen_rounded // 16
        n_cols = seqlen_rounded // 256
        # Both dimensions are part of the condition itself. A condition placed after
        # the comma would only be the assert message and would never be checked.
        assert n_rows <= self.layout.shape[0] and n_cols <= self.layout.shape[1], (
            f"seqlen {seqlen} does not fit the sparse layout of shape "
            f"{tuple(self.layout.shape)}"
        )
        return self.layout[:n_rows, :n_cols]

    def _attend(self, qkv, cu_seqlens, blockmask, max_s, causal, **kernel_kwargs):
        """Invoke the block-sparse kernel; dropout is active only in training mode."""
        return flash_blocksparse_attn_func(
            qkv,
            cu_seqlens,
            blockmask,
            self.dropout_p if self.training else 0.0,
            max_s,
            softmax_scale=self.softmax_temp,
            causal=causal,
            **kernel_kwargs,
        )

    def forward(
        self,
        qkv,
        attn_mask=None,
        key_padding_mask=None,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
        convert_mask=True,
    ):
        """Run block-sparse attention over a packed QKV tensor.

        Arguments
        ---------
            qkv: float16 CUDA tensor holding query, key and value.
                 - When ``cu_seqlens`` is None it is indexed as (B, S, 3, H, D).
                 - Otherwise it is handed to the kernel as-is (presumably packed as
                   (total_tokens, 3, H, D) — confirm with callers).
            attn_mask: Unsupported; must be None.
            key_padding_mask: Optional object whose ``bool_matrix`` attribute is passed
                              to ``unpad_input`` to drop padded tokens. Only consulted
                              when ``cu_seqlens`` is None.
            causal: Forwarded to the kernel.
            cu_seqlens: Optional cumulative sequence offsets for pre-packed input
                        (presumably int32, like the offsets built below).
            max_s: Longest sequence in the batch; required when ``cu_seqlens`` is given.
            need_weights: Unsupported; must be False.
            convert_mask: Only used with ``cu_seqlens``.
                          - If True, the sliced layout is passed to the kernel.
                          - If False, ``self.blockmask_converted`` is passed together
                            with ``convert_mask=False``.

        Returns
        -------
            ``(output, None)``.
            - For (B, S, ...) input, ``output`` is (B, S, H, D).
            - On the ``cu_seqlens`` path, it is whatever the kernel returns.
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        if cu_seqlens is not None:
            # Caller already packed variable-length sequences.
            assert max_s is not None
            # Computed even when unused below, so max_s is always validated against
            # the stored layout.
            blockmask = self._blockmask_for(max_s)
            if convert_mask:
                output = self._attend(qkv, cu_seqlens, blockmask, max_s, causal)
            else:
                output = self._attend(
                    qkv,
                    cu_seqlens,
                    self.blockmask_converted,
                    max_s,
                    causal,
                    convert_mask=False,
                )
            return output, None

        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        blockmask = self._blockmask_for(seqlen)

        if key_padding_mask is None:
            # All sequences have length `seqlen`: flatten the batch and build evenly
            # spaced offsets 0, seqlen, 2*seqlen, ...
            qkv = rearrange(qkv, "b s ... -> (b s) ...")
            cu_seqlens = torch.arange(
                0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
            )
            output = self._attend(qkv, cu_seqlens, blockmask, seqlen, causal)
            output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
        else:
            # Remove padded tokens, attend over the packed tokens, then scatter back.
            key_padding_mask_bool = key_padding_mask.bool_matrix
            nheads = qkv.shape[-2]
            x = rearrange(qkv, "b s three h d -> b s (three h d)")
            x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
            x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
            output_unpad = self._attend(x_unpad, cu_seqlens, blockmask, max_s, causal)
            output = rearrange(
                pad_input(
                    rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
                ),
                "b s (h d) -> b s h d",
                h=nheads,
            )
        return output, None


class FlashBlocksparseMHA(nn.Module):
    """Self-attention layer built from three parts.

    The parts are a fused QKV projection, block-sparse attention and an output
    projection.

    Arguments
    ---------
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of heads; ``embed_dim // num_heads`` must be 16, 32 or 64.
        sparsity_config: Config forwarded to ``FlashBlocksparseAttention``.
        bias: Whether the QKV and output projections carry a bias (default: True).
        batch_first: Must be True; inputs are laid out as (batch, seq, embed_dim).
        attention_dropout: Dropout probability forwarded to the inner attention.
        causal: Whether attention is causal; fixed at construction time.
        max_seq_length: Longest supported sequence, forwarded to the inner attention.
        device, dtype: Factory kwargs for the projections and the inner attention.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        sparsity_config,
        bias=True,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        max_seq_length=2048,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        assert batch_first
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        # One linear layer produces Q, K and V side by side (3 * embed_dim outputs).
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashBlocksparseAttention(
            sparsity_config,
            attention_dropout=attention_dropout,
            max_seq_length=max_seq_length,
            **factory_kwargs,
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(
        self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
    ):
        """Attend over ``x`` (batch, seq, embed_dim) and project back to embed_dim.

        Only ``x`` is read.
        - The two ``x_ignored_*`` positional slots are unused; presumably they mirror
          an ``nn.MultiheadAttention``-style (query, key, value) call.
        - ``attn_mask`` is accepted but NOT forwarded. NOTE(review): a non-None mask is
          silently ignored; causality comes only from ``self.causal``.
        - ``key_padding_mask`` and ``need_weights`` are forwarded to the inner
          attention.

        Returns
        -------
            ``(output, attn_weights)``, where ``output`` is (batch, seq, embed_dim) and
            ``attn_weights`` is whatever the inner attention returns as its second value.
        """
        qkv = self.Wqkv(x)
        # Split the fused projection into (batch, seq, 3, heads, head_dim).
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
        )
        return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights