# Copyright (c) 2023, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

from flash_attn.utils.distributed import get_dim_for_local_rank

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
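
# A minimal usage sketch for FlashSelfAttention (illustrative shapes and settings; it
# assumes the FlashAttention kernels are installed and a CUDA device is available):
#
#   attn = FlashSelfAttention(causal=True, attention_dropout=0.1)
#   qkv = torch.randn(2, 128, 3, 12, 64, dtype=torch.float16, device="cuda")  # (B, S, 3, H, D)
#   out = attn(qkv)  # (B, S, H, D)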


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
            )
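
# A minimal usage sketch for FlashCrossAttention (illustrative shapes, assuming the
# FlashAttention kernels are installed and tensors are fp16/bf16 on a CUDA device).
# kv may carry fewer heads than q (MQA/GQA), matching the (B, Sk, 2, H_k, D) layout above:
#
#   cross_attn = FlashCrossAttention(causal=False)
#   q = torch.randn(2, 16, 12, 64, dtype=torch.float16, device="cuda")      # (B, Sq, H, D)
#   kv = torch.randn(2, 128, 2, 4, 64, dtype=torch.float16, device="cuda")  # (B, Sk, 2, H_k, D)
#   out = cross_attn(q, kv)  # (B, Sq, H, D)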


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
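
# A minimal usage sketch for the pure-PyTorch SelfAttention reference path (illustrative
# shapes; unlike the Flash variants, this one also runs in fp32 and on CPU):
#
#   attn = SelfAttention(causal=True, attention_dropout=0.0)
#   qkv = torch.randn(2, 128, 3, 12, 64)             # (B, S, 3, H, D)
#   mask = torch.ones(2, 128, dtype=torch.bool)      # True = keep, False = mask out
#   out = attn(qkv, key_padding_mask=mask)           # (B, S, H, D)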


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
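
# The causal mask in CrossAttention.forward is aligned to the bottom-right corner when
# seqlen_q != seqlen_k, matching FlashAttention's convention. A small worked example with
# seqlen_q=2, seqlen_k=5 and no padding (sk=5): col_idx > row_idx + sk - seqlen_q gives
#
#   [[False, False, False, False, True ],
#    [False, False, False, False, False]]
#
# i.e. the last query sees all 5 keys and the one before it sees the first 4.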


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor):
        return super().forward(input), input


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert kv_cache is not None
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
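
# _update_kv_cache only relies on a handful of attributes of `inference_params`. A
# duck-typed container like the sketch below (hypothetical, not the class shipped with
# the library) is enough for it and for the MHA cache paths further down:
#
#   @dataclasses.dataclass
#   class InferenceParamsSketch:
#       max_seqlen: int
#       max_batch_size: int
#       seqlen_offset: int = 0
#       batch_size_offset: int = 0
#       key_value_memory_dict: dict = dataclasses.field(default_factory=dict)
#       lengths_per_sample: Optional[torch.Tensor] = None
#
# During decoding, seqlen_offset advances by the number of tokens already cached, so new
# keys/values land in the correct slice of the preallocated (batch, max_seqlen, 2, H_kv, D) buffer.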


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architectures, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
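
# A minimal usage sketch for MHA (hyperparameters are illustrative; rotary embeddings and
# the Flash kernels are both assumed to be available):
#
#   mha = MHA(
#       embed_dim=768, num_heads=12, num_heads_kv=4,       # GQA: 12 query heads, 4 kv heads
#       causal=True, rotary_emb_dim=64, use_flash_attn=True,
#       device="cuda", dtype=torch.bfloat16,
#   )
#   x = torch.randn(2, 128, 768, device="cuda", dtype=torch.bfloat16)
#   y = mha(x)  # (2, 128, 768)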


class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
            **factory_kwargs,
        )
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallelism, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out
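
# ParallelMHA is the tensor-parallel counterpart of MHA: Wqkv is a ColumnParallelLinear and
# out_proj a RowParallelLinear, so each rank holds only its shard of the query and kv heads
# (see get_dim_for_local_rank). A usage sketch, assuming torch.distributed has already been
# initialized and `group` is the tensor-parallel process group (illustrative settings):
#
#   mha = ParallelMHA(
#       embed_dim=1024, num_heads=16, process_group=group,
#       causal=True, use_flash_attn=True, device="cuda", dtype=torch.float16,
#   )
#   y = mha(torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16))  # (2, 128, 1024)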