mha.py 37.6 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
# Copyright (c) 2023, Tri Dao.
2
3
4
5
6
7

import math
from functools import partial

import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
8
from einops import rearrange, repeat
9

10
11
from flash_attn.utils.distributed import get_dim_for_local_rank

12
try:
Tri Dao's avatar
Tri Dao committed
13
14
15
16
17
18
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
19
except ImportError:
Tri Dao's avatar
Tri Dao committed
20
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
21
22
23
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

try:
Tri Dao's avatar
Tri Dao committed
24
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
25
except ImportError:
Tri Dao's avatar
Tri Dao committed
26
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
27
28
29
30
31
32

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None

33
34
35
36
37
try:
    import ft_attention
except ImportError:
    ft_attention = None

38
39
40
41
42
43
44
45
46
47
48

class FlashSelfAttention(nn.Module):
    """Scaled dot product self-attention with softmax, computed by the FlashAttention kernels.

    Arguments
    ---------
        causal: default for whether to apply a causal mask; ``forward`` can override it
                per call. (default: False)
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        # Both the padded and the variable-length kernels must be importable.
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        # Only the dropout probability is used; it is handed to the kernel.
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Run packed-QKV self-attention through FlashAttention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        if causal is None:
            causal = self.causal
        # Dropout is only active in training mode.
        dropout_p = self.drop.p if self.training else 0.0

        if cu_seqlens is None:
            # Padded (B, S, 3, H, D) input.
            return flash_attn_qkvpacked_func(
                qkv,
                dropout_p,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )

        # Variable-length (total, 3, H, D) input indexed by cu_seqlens.
        assert cu_seqlens.dtype == torch.int32
        assert max_seqlen is not None
        assert isinstance(max_seqlen, int)
        return flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens,
            max_seqlen,
            dropout_p,
            softmax_scale=self.softmax_scale,
            causal=causal,
        )


class FlashCrossAttention(nn.Module):
    """Scaled dot product cross-attention with softmax, computed by the FlashAttention kernels.

    Arguments
    ---------
        causal: default for whether to apply a causal mask; ``forward`` can override it
                per call. (default: False)
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        # Both the padded and the variable-length kernels must be importable.
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        # Only the dropout probability is used; it is handed to the kernel.
        self.drop = nn.Dropout(attention_dropout)

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        When cu_seqlens is given, all four variable-length arguments are required and q / kv
        are expected in the kernel's packed variable-length layout.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Validate the key-side length too (previously max_seqlen was checked twice).
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
        else:
            # Padded inputs: batch sizes and head dims of q and kv must agree.
            batch_size = q.shape[0]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
            )


class SelfAttention(nn.Module):
    """Reference (non-fused) scaled dot product self-attention with softmax.

    Arguments
    ---------
        causal: default for whether to apply a causal mask; ``forward`` can override it
                per call. (default: False)
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime). Only ``None`` selects the default; any explicit
                      value, including 0.0, is used as given.
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        Returns
        -------
            output: (B, S, H, D)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        # Only None means "use the default", so an explicit softmax_scale (even 0.0) is
        # honored, matching the FlashAttention path, which forwards the value unchanged.
        softmax_scale = (
            self.softmax_scale if self.softmax_scale is not None else 1.0 / math.sqrt(q.shape[-1])
        )
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


class CrossAttention(nn.Module):
    """Reference (non-fused) scaled dot product cross-attention with softmax.

    Arguments
    ---------
        causal: default for whether to apply a causal mask; ``forward`` can override it
                per call. (default: False)
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime). Only ``None`` selects the default; any explicit
                      value, including 0.0, is used as given.
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        Returns
        -------
            output: (B, Sq, H, D)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            # Repeat each kv head so there is one kv head per query head.
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        # Only None means "use the default", so an explicit softmax_scale (even 0.0) is
        # honored, matching the FlashAttention path, which forwards the value unchanged.
        softmax_scale = (
            self.softmax_scale if self.softmax_scale is not None else 1.0 / math.sqrt(q.shape[-1])
        )
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            # With a padding mask, the per-sample count of kept keys replaces seqlen_k, so
            # the last query aligns with the last kept key (presumably assumes kept keys
            # come first, i.e. right padding; verify against callers).
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(linear(input), input)``: the projection plus the untouched input."""
        return super().forward(input), input


299
def _update_kv_cache(kv, inference_params, layer_idx):
Tri Dao's avatar
Tri Dao committed
300
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
301
302
303
304
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
Tri Dao's avatar
Tri Dao committed
305
306
307
308
309
310
311
            inference_params.max_batch_size,
            inference_params.max_sequence_len,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        if not inference_params.fused_ft_kernel:
            kv_cache = inference_params.key_value_memory_dict[layer_idx]
        else:
            # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
            # where packsize = 4 if fp32, 8 if fp16 or bf16.
            # v_cache has shape (b, h, s, headdim)
            k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
            kv_cache = None
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.sequence_len_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
    assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
    # Copy key and values.
    if not inference_params.fused_ft_kernel:
        assert kv_cache is not None
        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
        return kv
    else:
        assert inference_params.sequence_len_offset == 0
        # FT kernel requires different layouts for the k_cache and v_cache.
        assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
        packsize = 4 if kv.dtype == torch.float32 else 8
        if kv_cache is not None:
            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
Tri Dao's avatar
Tri Dao committed
343
344
345
346
            k_cache = rearrange(
                kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
            ).contiguous()
            v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
347
348
349
            inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
        else:
            k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
Tri Dao's avatar
Tri Dao committed
350
                kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
351
352
            )
            v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
Tri Dao's avatar
Tri Dao committed
353
                kv[:, :, 1], "b s h d -> b h s d"
354
355
356
357
            )
        return kv


Tri Dao's avatar
Tri Dao committed
358
359
360
361
362
363
364
365
366
def _apply_rotary_single_query_attention(
    qkv,
    inference_params,
    layer_idx,
    rotary_emb_dim,
    rotary_emb_base,
    kv=None,
    rotary_emb_interleaved=False,
):
    """Attend a single new query token against the FT-layout KV cache via ``ft_attention``.

    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
            q of shape (batch_size, 1, nheads, head_dim)
    kv: (batch_size, 1, 2, nheads_kv, head_dim)

    Requires ``inference_params.fused_ft_kernel`` and the ``ft_attention`` extension. The
    cached ``(k_cache, v_cache)`` pair for ``layer_idx`` is sliced to the current batch
    window. The rotary settings are forwarded to the kernel; ``neox_rotary_style`` is the
    negation of ``rotary_emb_interleaved``. Returns the context reshaped to
    ``(batch, 1, nheads, head_dim)``.
    """
    assert inference_params.fused_ft_kernel
    assert ft_attention is not None
    if kv is not None:
        q = rearrange(qkv, "b 1 h d -> b h d")
        k, v = rearrange(kv, "b 1 two h d -> b two h d").unbind(dim=1)
    else:
        q, k, v = rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1)

    lo = inference_params.batch_size_offset
    hi = lo + q.shape[0]
    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
    lengths_per_sample = inference_params.lengths_per_sample
    if lengths_per_sample is not None:
        lengths_per_sample = lengths_per_sample[lo:hi]

    # Positional argument order is dictated by the ft_attention extension.
    context = ft_attention.single_query_attention(
        q,
        k,
        v,
        k_cache[lo:hi],
        v_cache[lo:hi],
        lengths_per_sample,
        None,  # rotary_cos_
        None,  # rotary_sin_
        None,  # nnz_head_idx
        inference_params.sequence_len_offset,
        rotary_emb_dim,
        rotary_emb_base,
        not rotary_emb_interleaved,  # neox_rotary_style
    )
    return rearrange(context, "b h d -> b 1 h d")


405
class MHA(nn.Module):
Tri Dao's avatar
Tri Dao committed
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        dwconv: if True, create depthwise nn.Conv1d layers (kernel_size=3, padding=2) over
            the projections. A single ``dwconv_qkv`` is created for packed self-attention
            (not cross_attn and num_heads_kv == num_heads); otherwise ``dwconv_q`` and
            ``dwconv_kv`` are created.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        # Width of the packed q + k + v projection, and of the packed k + v projection.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # Mirror the packed-qkv condition in forward, the only place dwconv_qkv is applied.
            # Cross-attention yields separate q (embed_dim) and kv (kv_dim) tensors, which
            # a qkv_dim-channel conv cannot process, so it needs dwconv_q / dwconv_kv even
            # when num_heads_kv == num_heads.
            if not self.cross_attn and self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
502

503
504
505
506
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
Tri Dao's avatar
Tri Dao committed
507
508
509
510
511
512
513
514
515
            return torch.empty(
                batch_size,
                max_seqlen,
                2,
                self.num_heads_kv,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
516
517
518
519
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
Tri Dao's avatar
Tri Dao committed
520
521
522
523
524
525
526
527
528
529
530
531
            k_cache = torch.empty(
                batch_size,
                self.num_heads_kv,
                self.head_dim // packsize,
                max_seqlen,
                packsize,
                dtype=dtype,
                device=device,
            )
            v_cache = torch.empty(
                batch_size, self.num_heads_kv, max_seqlen, self.head_dim, dtype=dtype, device=device
            )
532
533
            return k_cache, v_cache

Tri Dao's avatar
Tri Dao committed
534
    def _update_kv_cache(self, kv, inference_params):
Tri Dao's avatar
Tri Dao committed
535
536
537
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
538
        return _update_kv_cache(kv, inference_params, self.layer_idx)
Tri Dao's avatar
Tri Dao committed
539

Tri Dao's avatar
Tri Dao committed
540
541
542
543
544
545
546
547
    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """Run single-query FT attention for this layer, forwarding its rotary settings.

        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
              q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)

        Without rotary embeddings (rotary_emb_dim == 0), a base of 0 and a non-interleaved
        layout are passed to the module-level helper.
        """
        has_rotary = self.rotary_emb_dim > 0
        base = self.rotary_emb.base if has_rotary else 0
        interleaved = self.rotary_emb.interleaved if has_rotary else False
        return _apply_rotary_single_query_attention(
            qkv,
            inference_params,
            self.layer_idx,
            self.rotary_emb_dim,
            base,
            kv=kv,
            rotary_emb_interleaved=interleaved,
        )

Tri Dao's avatar
Tri Dao committed
559
560
561
562
563
564
565
566
567
568
569
    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
570
571
        """
        Arguments:
Tri Dao's avatar
Tri Dao committed
572
573
574
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the is the sum of the sequence lengths in the batch.
575
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
Tri Dao's avatar
Tri Dao committed
576
577
578
579
580
581
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
582
583
584
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
Tri Dao's avatar
Tri Dao committed
585
586
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
587
        """
Tri Dao's avatar
Tri Dao committed
588
589
590
591
592
593
594
595
596
597
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
Tri Dao's avatar
Tri Dao committed
598
599
600
601
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv
Tri Dao's avatar
Tri Dao committed
602

Tri Dao's avatar
Tri Dao committed
603
604
605
606
607
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
Tri Dao's avatar
Tri Dao committed
608
609
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
610
            assert x_kv is None and mixer_subset is None
611
612
613
614
615
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
Tri Dao's avatar
Tri Dao committed
616
617
618
619
620
621
622
623
624
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
Tri Dao's avatar
Tri Dao committed
625
                if self.rotary_emb_dim > 0:
Tri Dao's avatar
Tri Dao committed
626
627
628
629
630
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
Tri Dao's avatar
Tri Dao committed
631
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
Tri Dao's avatar
Tri Dao committed
632
                else:
633
634
                    q = qkv[:, :, 0]
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
635
                    context = self.inner_cross_attn(q, kv)
Tri Dao's avatar
Tri Dao committed
636
637
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
638
        else:
Tri Dao's avatar
Tri Dao committed
639
640
641
642
643
644
645
646
647
648
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
649
            else:
Tri Dao's avatar
Tri Dao committed
650
651
652
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
653
                else:
Tri Dao's avatar
Tri Dao committed
654
                    qkv, x = self.Wqkv(x)
Tri Dao's avatar
Tri Dao committed
655
656
657
658
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
659
            if self.dwconv:
Tri Dao's avatar
Tri Dao committed
660
661
662
663
664
665
666
667
668
669
670
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
Tri Dao's avatar
Tri Dao committed
671
672
673
674
675
676
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
Tri Dao's avatar
Tri Dao committed
677
678
679
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
Tri Dao's avatar
Tri Dao committed
680
                else:
Tri Dao's avatar
Tri Dao committed
681
                    kv = self._update_kv_cache(kv, inference_params)
682
                    context = self.inner_cross_attn(q, kv)
683
            else:
Tri Dao's avatar
Tri Dao committed
684
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
Tri Dao's avatar
Tri Dao committed
685
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
686
        return out if not self.return_residual else (out, x)
Tri Dao's avatar
Tri Dao committed
687
688
689


class ParallelMHA(nn.Module):
    """Multi-head self-attention whose heads are sharded across a tensor-parallel group.

    The fused QKV projection is a ``ColumnParallelLinear`` and the output projection is a
    ``RowParallelLinear``, so each rank of ``process_group`` only computes attention for
    its local heads. Multi-query / grouped-query attention is supported by passing
    ``num_heads_kv < num_heads`` (``num_heads`` must be a multiple of ``num_heads_kv``).
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Arguments:
            embed_dim: model width; must be divisible by num_heads.
            num_heads: total number of query heads (summed over all ranks).
            process_group: process group the heads and projections are split over.
            num_heads_kv: total number of key/value heads; defaults to num_heads (plain MHA).
            dropout / softmax_scale / causal: forwarded to the inner attention modules.
            layer_idx: passed to the KV-cache helpers; required for generation.
            rotary_emb_dim: if > 0, a RotaryEmbedding of this dimension is applied to q/k.
            use_flash_attn: use the FlashAttention-based inner attention modules.
            checkpointing: wrap the inner attention in torch.utils.checkpoint.
            sequence_parallel: forwarded to the column/row-parallel linear layers.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        # Heads owned by this rank. The KV count must be derived from num_heads_kv (not
        # num_heads): it sizes the inference cache and has to match the KV slice that
        # forward() carves out of the local qkv when num_heads_kv < num_heads.
        # NOTE(review): the two counts agree with the Wqkv shard below when the head
        # counts split evenly over world_size — confirm for uneven splits.
        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        # qkv_dim == head_dim * num_heads_kv * (num_heads // num_heads_kv + 2): the output
        # consists of num_heads_kv groups of (query heads of the group + 1 key + 1 value).
        # Using one group as the sharding unit always divides qkv_dim; for plain MHA it is
        # head_dim * 3, whereas head_dim * 3 alone does not divide qkv_dim for many MQA/GQA
        # configurations (e.g. num_heads=32, num_heads_kv=4).
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        """Allocate (uninitialized) KV-cache storage for this rank's KV heads.

        dtype and device default to those of out_proj.weight.

        Returns:
            fused_ft_kernel=False: one tensor of shape
                (batch_size, max_seqlen, 2, num_heads_kv_per_rank, head_dim).
            fused_ft_kernel=True: a (k_cache, v_cache) pair with shapes
                (batch_size, num_heads_kv_per_rank, head_dim // packsize, max_seqlen, packsize)
                and (batch_size, num_heads_kv_per_rank, max_seqlen, head_dim); presumably the
                layout expected by the fused single-query kernel — confirm against ft_attention.
        """
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(
                batch_size,
                max_seqlen,
                2,
                self.num_heads_kv_per_rank,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            # packsize elements are 16 bytes for every supported dtype (8 x 2B or 4 x 4B).
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(
                batch_size,
                self.num_heads_kv_per_rank,
                self.head_dim // packsize,
                max_seqlen,
                packsize,
                dtype=dtype,
                device=device,
            )
            v_cache = torch.empty(
                batch_size,
                self.num_heads_kv_per_rank,
                max_seqlen,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
            return k_cache, v_cache

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

        Thin wrapper around the module-level _update_kv_cache that checks layer_idx is set.
        """
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
              q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)

        Delegates to the module-level helper. When rotary embeddings are disabled, a base
        of 0 and interleaved=False are passed.
        """
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv,
            inference_params,
            self.layer_idx,
            self.rotary_emb_dim,
            rotary_emb_base,
            kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved
            if self.rotary_emb_dim > 0
            else False,
        )

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            seqlen: sequence length used to unflatten x; see above.
            inference_params: generation state (KV cache, sequence_len_offset,
                fused_ft_kernel), or None for a plain forward pass.
            **kwargs: forwarded to the inner attention module when inference_params is None.
        Returns:
            out_proj applied to the merged attention heads.
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        # The fused single-query kernel is only used on incremental decoding steps
        # (offset != 0) and only when the caller asked for it.
        use_single_query_kernel = (
            inference_params is not None
            and inference_params.sequence_len_offset != 0
            and inference_params.fused_ft_kernel
        )
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if use_single_query_kernel:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
            else:
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    q = qkv[:, :, 0]
                    # Use the method (not the module-level helper) so the layer_idx check
                    # applies on both the MHA and the MQA/GQA paths.
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                    context = self.inner_cross_attn(q, kv)
        else:
            # Local qkv layout: this rank's query heads first, then its packed K/V heads.
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if use_single_query_kernel:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
            else:
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    context = self.inner_cross_attn(q, kv)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out