# Copyright (c) 2023, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None

try:
    import ft_attention
except ImportError:
    ft_attention = None


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
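# Usage sketch (illustrative, not executed by this module): FlashSelfAttention expects a
# packed qkv tensor in fp16/bf16 on CUDA. The shapes below are examples only.
#
#     attn = FlashSelfAttention(causal=True, attention_dropout=0.1)
#     qkv = torch.randn(2, 1024, 3, 12, 64, dtype=torch.float16, device="cuda")
#     out = attn(qkv)  # (batch, seqlen, nheads, head_dim) = (2, 1024, 12, 64)
#
# For the unpadded path, flatten qkv to (total, 3, nheads, head_dim) and pass cu_seqlens
# (int32, shape (batch_size + 1,)) together with an int max_seqlen.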


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
            )
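# Usage sketch (illustrative only): FlashCrossAttention takes q and packed kv separately;
# kv may carry fewer heads than q for MQA/GQA, as long as the head dims match.
#
#     cross_attn = FlashCrossAttention()
#     q = torch.randn(2, 128, 12, 64, dtype=torch.bfloat16, device="cuda")
#     kv = torch.randn(2, 512, 2, 4, 64, dtype=torch.bfloat16, device="cuda")
#     out = cross_attn(q, kv)  # (2, 128, 12, 64)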


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
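# Usage sketch (illustrative only): SelfAttention is the pure-PyTorch reference path, so it
# also works on CPU and in fp32. key_padding_mask marks real tokens with True.
#
#     attn = SelfAttention(causal=False, attention_dropout=0.0)
#     qkv = torch.randn(2, 16, 3, 4, 8)
#     mask = torch.ones(2, 16, dtype=torch.bool)
#     out = attn(qkv, key_padding_mask=mask)  # (2, 16, 4, 8)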


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
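# Usage sketch (illustrative only): CrossAttention mirrors FlashCrossAttention in pure
# PyTorch; kv heads are repeated internally when they number fewer than the query heads
# (MQA/GQA).
#
#     cross_attn = CrossAttention(causal=False)
#     q = torch.randn(2, 8, 8, 16)
#     kv = torch.randn(2, 32, 2, 2, 16)  # 2 kv heads expanded to match 8 query heads
#     out = cross_attn(q, kv)            # (2, 8, 8, 16)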


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> tuple:
        return super().forward(input), input
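# Usage sketch (illustrative only): LinearResidual behaves like nn.Linear but also returns
# its input so the caller can fuse the residual connection, matching
# FusedDense(return_residual=True).
#
#     proj = LinearResidual(256, 768)
#     x = torch.randn(4, 128, 256)
#     y, residual = proj(x)  # y: (4, 128, 768), residual is x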


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_sequence_len,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        if not inference_params.fused_ft_kernel:
            kv_cache = inference_params.key_value_memory_dict[layer_idx]
        else:
            # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
            # where packsize = 4 if fp32, 8 if fp16 or bf16.
            # v_cache has shape (b, h, s, headdim)
            k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
            kv_cache = None
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.sequence_len_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
    assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
    # Copy key and values.
    if not inference_params.fused_ft_kernel:
        assert kv_cache is not None
        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
        return kv
    else:
        assert inference_params.sequence_len_offset == 0
        # FT kernel requires different layouts for the k_cache and v_cache.
        assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
        packsize = 4 if kv.dtype == torch.float32 else 8
        if kv_cache is not None:
            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
            k_cache = rearrange(
                kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
            ).contiguous()
            v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
            inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
        else:
            k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
                kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
            )
            v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
                kv[:, :, 1], "b s h d -> b h s d"
            )
        return kv
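# Usage sketch (illustrative only): inference_params is expected to expose max_batch_size,
# max_sequence_len, batch_size_offset, sequence_len_offset, fused_ft_kernel and
# key_value_memory_dict (e.g. the InferenceParams helper used by the generation utilities
# in this repo). During decoding, the new kv slice is written at sequence_len_offset and
# the cache up to the current position is returned.
#
#     kv_step = torch.randn(1, 1, 2, 12, 64, dtype=torch.float16, device="cuda")
#     kv_so_far = _update_kv_cache(kv_step, inference_params, layer_idx=0)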


def _apply_rotary_single_query_attention(
    qkv,
    inference_params,
    layer_idx,
    rotary_emb_dim,
    rotary_emb_base,
    kv=None,
    rotary_emb_interleaved=False,
):
    """
    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
            q of shape (batch_size, 1, nheads, head_dim)
    kv: (batch_size, 1, 2, nheads_kv, head_dim)
    """
    assert inference_params.fused_ft_kernel
    assert ft_attention is not None
    if kv is None:
        q, k, v = rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1)
    else:
        q = rearrange(qkv, "b 1 h d -> b h d")
        k, v = rearrange(kv, "b 1 two h d -> b two h d").unbind(dim=1)
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + q.shape[0]
    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
    lengths_per_sample = (
        inference_params.lengths_per_sample[batch_start:batch_end]
        if inference_params.lengths_per_sample is not None
        else None
    )
    context = ft_attention.single_query_attention(
        q,
        k,
        v,
        k_cache[batch_start:batch_end],
        v_cache[batch_start:batch_end],
        lengths_per_sample,
        None,  # rotary_cos_
        None,  # rotary_sin_
        None,  # nnz_head_idx
        inference_params.sequence_len_offset,
        rotary_emb_dim,
        rotary_emb_base,
        not rotary_emb_interleaved,  # neox_rotary_style
    )
    return rearrange(context, "b h d -> b 1 h d")


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architectures, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(
                batch_size,
                max_seqlen,
                2,
                self.num_heads_kv,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(
                batch_size,
                self.num_heads_kv,
                self.head_dim // packsize,
                max_seqlen,
                packsize,
                dtype=dtype,
                device=device,
            )
            v_cache = torch.empty(
                batch_size, self.num_heads_kv, max_seqlen, self.head_dim, dtype=dtype, device=device
            )
            return k_cache, v_cache

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
              q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)
        """
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv,
            inference_params,
            self.layer_idx,
            self.rotary_emb_dim,
            rotary_emb_base,
            kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved
            if self.rotary_emb_dim > 0
            else False,
        )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    q = qkv[:, :, 0]
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
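# Usage sketch (illustrative only): a causal self-attention block with rotary embeddings
# and FlashAttention, roughly how the GPT models in this repo configure MHA.
#
#     mha = MHA(
#         embed_dim=768, num_heads=12, causal=True, rotary_emb_dim=64,
#         use_flash_attn=True, dtype=torch.float16, device="cuda",
#     )
#     x = torch.randn(2, 1024, 768, dtype=torch.float16, device="cuda")
#     y = mha(x)  # (2, 1024, 768)
#
# Pass num_heads_kv < num_heads for MQA/GQA, or cross_attn=True (optionally with x_kv)
# for cross-attention.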


class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size() if process_group is not None else 1

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        self.num_heads_per_rank = num_heads // self.world_size
        self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert (
            self.num_heads_kv % self.world_size == 0
        ), "num_heads_kv must be divisible by world_size"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(
                batch_size,
                max_seqlen,
                2,
                self.num_heads_kv_per_rank,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(
                batch_size,
                self.num_heads_kv_per_rank,
                self.head_dim // packsize,
                max_seqlen,
                packsize,
                dtype=dtype,
                device=device,
            )
            v_cache = torch.empty(
                batch_size,
                self.num_heads_kv_per_rank,
                max_seqlen,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
            return k_cache, v_cache

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
              q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)
        """
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv,
            inference_params,
            self.layer_idx,
            self.rotary_emb_dim,
            rotary_emb_base,
            kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved
            if self.rotary_emb_dim > 0
            else False,
        )

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    q = qkv[:, :, 0]
                    kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out
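# Usage sketch (illustrative only): ParallelMHA shards heads over a tensor-parallel process
# group; each rank holds num_heads // world_size query heads, with the projections handled
# by ColumnParallelLinear / RowParallelLinear. Assumes torch.distributed is initialized.
#
#     process_group = torch.distributed.new_group()
#     mha = ParallelMHA(
#         embed_dim=1024, num_heads=16, process_group=process_group, causal=True,
#         use_flash_attn=True, sequence_parallel=False, dtype=torch.float16, device="cuda",
#     )
#     x = torch.randn(2, 512, 1024, dtype=torch.float16, device="cuda")
#     y = mha(x)  # (2, 512, 1024)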