# mha.py
# Copyright (c) 2022, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
except ImportError:
    flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func = None, None

try:
    from flash_attn.ops.flash_attn_triton import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
except ImportError:
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

try:
    from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None

try:
    import ft_attention
except ImportError:
    ft_attention = None


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax, using FlashAttention kernels.

    Arguments
    ---------
        causal: whether to apply a causal mask by default. Can be overridden per call through
            the ``causal`` argument of ``forward``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        triton: if True, padded (B, S, 3, H, D) inputs use the Triton kernel whenever dropout is
            inactive (attention_dropout == 0, or the module is in eval mode). The Triton kernel
            does not support dropout, so the unpadded CUDA kernel is used otherwise.
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 triton=False):
        super().__init__()
        # The unpadded CUDA kernel serves the non-Triton path and training with dropout.
        if attention_dropout != 0.0 or not triton:
            assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
        # The Triton kernel is used whenever dropout is inactive. That includes eval mode even
        # when attention_dropout > 0, so it must be available whenever triton=True.
        if triton:
            assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.triton = triton

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
                Must be fp16 or bf16 and on a CUDA device.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_unpadded_qkvpacked_func(
                qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )
        else:
            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
            # Triton version doesn't support dropout
            if self.triton and (self.dropout_p == 0 or not self.training):
                output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
            else:
                # Treat the padded batch as B back-to-back sequences, each of length S.
                # The cumulative offsets are then 0, S, 2S, ..., B*S.
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_seqlen = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                                          dtype=torch.int32, device=qkv.device)
                output = flash_attn_unpadded_qkvpacked_func(
                    qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            return output


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax, using FlashAttention kernels,
    with queries and packed keys/values given as separate tensors.

    Arguments
    ---------
        causal: whether to apply a causal mask by default. Can be overridden per call through
            the ``causal`` argument of ``forward``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        triton: if True, padded inputs use the Triton kernel whenever dropout is inactive
            (attention_dropout == 0, or the module is in eval mode). The Triton kernel does
            not support dropout, so the unpadded CUDA kernel is used otherwise.
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 triton=False):
        super().__init__()
        # The unpadded CUDA kernel serves the non-Triton path and training with dropout.
        if attention_dropout != 0.0 or not triton:
            assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed'
        # The Triton kernel is used whenever dropout is inactive. That includes eval mode even
        # when attention_dropout > 0, so it must be available whenever triton=True.
        if triton:
            assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.triton = triton

    def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
                cu_seqlens_k=None, max_seqlen_k=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D), or (total_q, H, D) when
                cu_seqlens is given. Must be fp16 or bf16 and on a CUDA device.
            kv: The tensor containing the key and value. (B, Sk, 2, H, D), or
                (total_k, 2, H, D) when cu_seqlens is given.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        Returns:
        --------
            out: (total_q, H, D) if cu_seqlens is given, else (B, Sq, H, D).
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Previously this re-checked max_seqlen; the key-side length must be validated too.
            assert isinstance(max_seqlen_k, int)
            return flash_attn_unpadded_kvpacked_func(
                q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
                self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
            if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
                output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
            else:
                # Treat the padded batch as B back-to-back sequences of fixed length. The
                # offsets are multiples of Sq for q and multiples of Sk for kv.
                q = rearrange(q, 'b s ... -> (b s) ...')
                kv = rearrange(kv, 'b s ... -> (b s) ...')
                cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                                            dtype=torch.int32, device=q.device)
                cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                                            dtype=torch.int32, device=kv.device)
                output = flash_attn_unpadded_kvpacked_func(
                    q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                    self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            return output


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax (plain PyTorch version).

    Arguments
    ---------
        causal: whether to apply a causal mask by default. Can be overridden per call through
            the ``causal`` argument of ``forward``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        Returns:
        --------
            out: (B, S, H, D)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        # `or` means an explicit softmax_scale of 0 also falls back to 1/sqrt(head_dim).
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, -10000 for masked-out keys.
            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
                                      device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
        return output


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax (plain PyTorch version), with
    queries and packed keys/values given as separate tensors.

    Arguments
    ---------
        causal: whether to apply a causal mask by default. Can be overridden per call through
            the ``causal`` argument of ``forward``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
            causal: if passed, will override self.causal. The causal mask is aligned to the
                top-left: query i may attend to keys 0..i.
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        Returns:
        --------
            out: (B, Sq, H, D)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
        k, v = kv.unbind(dim=2)
        # `or` means an explicit softmax_scale of 0 also falls back to 1/sqrt(head_dim).
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, -10000 for masked-out keys.
            padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
                                      device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
                                                device=scores.device), 1)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
        return output


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.

    ``forward(input)`` returns ``(linear(input), input)``. This matches the call convention
    ``out, residual = layer(x)`` that MHA uses for ``FusedDense(..., return_residual=True)``.
    """

    def forward(self, input: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        # The input tensor itself (not a copy) is returned as the residual.
        return super().forward(input), input


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention
    """

    def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
                 rotary_emb_scale_base=0,
                 fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                 checkpointing=False, device=None, dtype=None) -> None:
        """
            embed_dim: model dimension. Must be divisible by num_heads; the per-head dimension
                is embed_dim // num_heads.
            cross_attn: if True, use separate Wq and Wkv projections (keys/values may come
                from x_kv). Otherwise use a single packed Wqkv projection.
            bias: whether the input projections have a bias. out_proj always has a bias.
            dropout: attention dropout, passed to the inner attention modules.
            softmax_scale, causal: passed to the inner attention modules.
            layer_idx: key into inference_params.key_value_memory_dict. Required for generation.
            dwconv: if True, apply a depthwise Conv1d (kernel 3, padding 2, last 2 outputs
                dropped) to the input projections. Output t then depends only on positions
                t-2..t. Not supported with cu_seqlens or generation.
            rotary_emb_dim, rotary_emb_scale_base: if rotary_emb_dim > 0, apply RotaryEmbedding
                to the packed qkv (self-attention only).
            fused_bias_fc: use FusedDense instead of nn.Linear for the projections.
            use_flash_attn: use FlashSelfAttention / FlashCrossAttention as inner attention.
            return_residual: whether to return the input x along with the output. This is for
                performance reason: for post-norm architecture, returning the input allows us
                to fuse the backward of nn.Linear with the residual connection.
            checkpointing: if True, wrap the inner attention call in torch.utils.checkpoint.
            device, dtype: factory kwargs for the created parameters.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
                                              device=device)

        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (LinearResidual if not fused_bias_fc
                            else partial(FusedDense, return_residual=True))
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        if not self.cross_attn:
            if not self.return_residual:
                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
            else:
                self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
            if self.dwconv:
                self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                            groups=3 * embed_dim, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if not self.return_residual:
                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
            else:
                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
            if self.dwconv:
                self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
                                          groups=embed_dim, **factory_kwargs)
                self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
                                           groups=2 * embed_dim, **factory_kwargs)
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)
        # Also used by the self-attention path during generation (q against the cached kv).
        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                     attention_dropout=dropout)
        # output projection always have the bias (for now)
        self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

    def _update_kv_cache(self, kv, inference_params):
        """Write new keys/values into this layer's cache and return the cached prefix.

        kv: (batch_size, seqlen, 2, nheads, head_dim). It is written at batch offset
            inference_params.batch_size_offset and sequence offset
            inference_params.sequence_len_offset.
        Returns the cache slice for these batch entries over positions
        [0, sequence_len_offset + seqlen), shape (batch_size, sequence_len_offset + seqlen, 2,
        nheads, head_dim).
        """
        assert not self.dwconv, 'Generation does not support dwconv yet'
        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
        # Pre-allocate memory for key-values for inference.
        if self.layer_idx not in inference_params.key_value_memory_dict:
            kv_cache = torch.empty(
                inference_params.max_batch_size, inference_params.max_sequence_len, 2,
                self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
            )
            inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
        else:
            # With fused_ft_kernel, the dict entry was already converted to a (k_cache, v_cache)
            # tuple on the first call, and later steps take the ft_attention path in forward.
            assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
        # Adjust key and value for inference
        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + kv.shape[0]
        assert batch_end <= kv_cache.shape[0]
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + kv.shape[1]
        assert sequence_end <= kv_cache.shape[1]
        # Copy key and values.
        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
        if inference_params.fused_ft_kernel:
            # FT kernel requires different layouts for the k_cache and v_cache.
            assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if kv_cache.dtype == torch.float32 else 8
            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                                packsize=packsize).contiguous()
            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
            inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
        return kv

    def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
            **kwargs: forwarded to the inner attention module.
        Returns:
            out, or (out, x) if return_residual. out has the same leading shape as x.
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                  if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
        if not self.cross_attn:
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
            if inference_params is None:
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv)
                if not self.checkpointing:
                    context = self.inner_attn(qkv, **kwargs)
                else:
                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
            else:
                if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
                    if self.rotary_emb_dim > 0:
                        qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
                    q = qkv[:, :, 0]
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
                else:
                    # Single-token decode step via the FT kernel, using the (k_cache, v_cache)
                    # tuple that _update_kv_cache stored on the prompt step.
                    assert ft_attention is not None
                    context = ft_attention.single_query_attention(
                        *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                        *inference_params.key_value_memory_dict[self.layer_idx],
                        inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                        self.rotary_emb_dim
                    )
                    context = rearrange(context, 'b h d -> b 1 h d')
        else:
            if not self.return_residual:
                q = self.Wq(x)
                kv = self.Wkv(x_kv if x_kv is not None else x)
            else:
                if x_kv is not None:
                    kv, x_kv = self.Wkv(x_kv)
                else:
                    kv, x = self.Wkv(x)
                q = self.Wq(x)
            # The depthwise convs act on the flat (b, s, d) projections, so they must run
            # before the last dimension is split into heads.
            if self.dwconv:
                q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                              'b d s -> b s d').contiguous()
                kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
            q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
            # q and kv are separate tensors here, so the cross-attention module must be used
            # (the self-attention module expects a packed qkv).
            if inference_params is None:
                if not self.checkpointing:
                    context = self.inner_cross_attn(q, kv, **kwargs)
                else:
                    context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
                                                                **kwargs)
            else:
                kv = self._update_kv_cache(kv, inference_params)
                context = self.inner_cross_attn(q, kv, causal=False)
        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
        return out if not self.return_residual else (out, x)


class ParallelMHA(nn.Module):
    """Multi-head self-attention with tensor-parallel projections.

    The packed QKV projection is a ColumnParallelLinear and the output projection is a
    RowParallelLinear, both over ``process_group``. Only self-attention is supported.
    """

    def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
                 softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
                 rotary_emb_scale_base=0,
                 use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
                                              device=device)

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError('fused_dense is not installed')
        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
                                         **factory_kwargs)
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)
        # output projection always have the bias (for now)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs)

    def forward(self, x, seqlen=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            **kwargs: forwarded to the inner attention module.
        Returns:
            out, with the same 2-D or 3-D layout as x.
        """
        qkv = self.Wqkv(x)
        # The head count `h` is inferred from the projection width rather than taken from
        # self.num_heads. Presumably this is the number of heads local to this rank if
        # ColumnParallelLinear shards its output. TODO confirm.
        if seqlen is None:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
        else:
            qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3,
                            d=self.head_dim)
        if self.rotary_emb_dim > 0:
            qkv = self.rotary_emb(qkv)
        if not self.checkpointing:
            context = self.inner_attn(qkv, **kwargs)
        else:
            context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
        if seqlen is None:
            context = rearrange(context, 'b s h d -> b s (h d)')
        else:
            context = rearrange(context, 'b s h d -> (b s) (h d)')
        out = self.out_proj(context)
        return out