mha.py 41 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
# Copyright (c) 2023, Tri Dao.
2
3
4
5
6
7

import math
from functools import partial

import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
8
from einops import rearrange, repeat
9

10
11
from flash_attn.utils.distributed import get_dim_for_local_rank

12
try:
Tri Dao's avatar
Tri Dao committed
13
14
15
16
17
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
18
        flash_attn_with_kvcache,
Tri Dao's avatar
Tri Dao committed
19
    )
20
except ImportError:
Tri Dao's avatar
Tri Dao committed
21
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
    flash_attn_with_kvcache = None
24
25

try:
Tri Dao's avatar
Tri Dao committed
26
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
27
except ImportError:
Tri Dao's avatar
Tri Dao committed
28
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
29
30
31
32
33
34

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None

35
36
37
38
39
try:
    import ft_attention
except ImportError:
    ft_attention = None

40
41
42
43
44
45
46
47
48
49
50

class FlashSelfAttention(nn.Module):
    """Scaled dot-product self-attention computed by the FlashAttention kernels.

    Arguments
    ---------
        causal: default causal-masking flag; ``forward`` can override it per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0); only used while ``self.training`` is True.
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        not_installed = "FlashAttention is not installed"
        assert flash_attn_varlen_qkvpacked_func is not None, not_installed
        assert flash_attn_qkvpacked_func is not None, not_installed
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        is_causal = causal if causal is not None else self.causal
        # Dropout is only active in training mode; the kernel takes the probability directly.
        dropout_p = self.drop.p if self.training else 0.0
        if cu_seqlens is None:
            # Padded (B, S, 3, H, D) layout.
            return flash_attn_qkvpacked_func(
                qkv,
                dropout_p,
                softmax_scale=self.softmax_scale,
                causal=is_causal,
            )
        # Variable-length (unpadded) layout indexed by cu_seqlens.
        assert cu_seqlens.dtype == torch.int32
        assert max_seqlen is not None
        assert isinstance(max_seqlen, int)
        return flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens,
            max_seqlen,
            dropout_p,
            softmax_scale=self.softmax_scale,
            causal=is_causal,
        )
100
101
102
103
104
105
106
107
108
109
110
111


class FlashCrossAttention(nn.Module):
    """Scaled dot-product cross-attention computed by the FlashAttention kernels.

    Arguments
    ---------
        causal: default causal-masking flag; ``forward`` can override it per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0); only used while ``self.training`` is True.
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
                When cu_seqlens is given, q and kv are indexed by cu_seqlens / cu_seqlens_k,
                so presumably they are packed along the token dimension without a batch
                dimension (see the flash_attn varlen API) -- confirm against callers.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        Returns:
        --------
            The attention output produced by flash_attn_varlen_kvpacked_func (unpadded inputs)
            or flash_attn_kvpacked_func (padded inputs).
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Both maximum lengths must be host ints, mirroring the max_seqlen check above.
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
        else:
            # Same batch size and head dim for q and kv; the number of kv heads may differ (MQA/GQA).
            assert kv.shape[0] == q.shape[0] and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
            )
178
179
180
181
182
183
184
185
186
187
188
189


class SelfAttention(nn.Module):
    """Reference (non-fused) scaled dot-product self-attention with softmax.

    Arguments
    ---------
        causal: default causal-masking flag; ``forward`` can override it per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        Returns:
        --------
            out: (B, S, H, D)
        """
        bsz, seq_len = qkv.shape[0], qkv.shape[1]
        use_causal = causal if causal is not None else self.causal
        query, key, value = qkv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])
        logits = torch.einsum("bthd,bshd->bhts", query, key * scale)
        if key_padding_mask is not None:
            # Additive bias: 0 where the key is kept, -10000 where it is masked out.
            bias = torch.full((bsz, seq_len), -10000.0, dtype=logits.dtype, device=logits.device)
            bias.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: adding was measured faster than masked_fill_ on the scores.
            logits = logits + bias[:, None, None, :]
        if use_causal:
            # "triu_tril_cuda_template" is not implemented for BFloat16, so the mask is built
            # in the default float dtype and cast afterwards.
            future = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=logits.device), 1)
            logits = logits + future.to(dtype=logits.dtype)
        weights = torch.softmax(logits, dim=-1, dtype=value.dtype)
        return torch.einsum("bhts,bshd->bthd", self.drop(weights), value)


class CrossAttention(nn.Module):
    """Reference (non-fused) scaled dot-product cross-attention with softmax.

    Arguments
    ---------
        causal: default causal-masking flag; ``forward`` can override it per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        Returns:
        --------
            out: (B, Sq, H, D)
        """
        bsz, len_q = q.shape[0], q.shape[1]
        use_causal = causal if causal is not None else self.causal
        len_k = kv.shape[1]
        assert kv.shape[0] == bsz and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            # Repeat each kv head so that consecutive groups of query heads share one kv head.
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        key, value = kv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        logits = torch.einsum("bthd,bshd->bhts", q, key * scale)
        if key_padding_mask is not None:
            # Additive bias: 0 where the key is kept, -10000 where it is masked out.
            bias = torch.full((bsz, len_k), -10000.0, dtype=logits.dtype, device=logits.device)
            bias.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: adding was measured faster than masked_fill_ on the scores.
            logits = logits + bias[:, None, None, :]
        if use_causal:
            # The causal diagonal is aligned to the end of the keys: query i may attend to key j
            # only when j <= i + n_keys - len_q, where n_keys counts the kept keys per sample.
            rows = torch.arange(len_q, device=q.device, dtype=torch.long)[:, None]
            cols = torch.arange(len_k, device=kv.device, dtype=torch.long)
            if key_padding_mask is None:
                n_keys = len_k
            else:
                n_keys = key_padding_mask.sum(-1)[:, None, None, None]
            logits = logits.masked_fill(cols > rows + n_keys - len_q, -10000.0)
        weights = torch.softmax(logits, dim=-1, dtype=value.dtype)
        return torch.einsum("bhts,bshd->bthd", self.drop(weights), value)


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.

    ``forward`` returns a pair ``(linear(input), input)``, so callers can treat it like
    ``FusedDense(..., return_residual=True)``.
    """

    def forward(self, input: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return the affine projection of ``input`` together with ``input`` itself."""
        return super().forward(input), input


301
def _update_kv_cache(kv, inference_params, layer_idx):
Tri Dao's avatar
Tri Dao committed
302
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
303
304
305
306
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
Tri Dao's avatar
Tri Dao committed
307
308
309
310
311
312
313
            inference_params.max_batch_size,
            inference_params.max_sequence_len,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        if not inference_params.fused_ft_kernel:
            kv_cache = inference_params.key_value_memory_dict[layer_idx]
        else:
            # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
            # where packsize = 4 if fp32, 8 if fp16 or bf16.
            # v_cache has shape (b, h, s, headdim)
            k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
            kv_cache = None
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.sequence_len_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
    assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
    # Copy key and values.
    if not inference_params.fused_ft_kernel:
        assert kv_cache is not None
        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
        return kv
    else:
        assert inference_params.sequence_len_offset == 0
        # FT kernel requires different layouts for the k_cache and v_cache.
        assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
        packsize = 4 if kv.dtype == torch.float32 else 8
        if kv_cache is not None:
            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
Tri Dao's avatar
Tri Dao committed
345
346
347
348
            k_cache = rearrange(
                kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
            ).contiguous()
            v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
349
350
351
            inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
        else:
            k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
Tri Dao's avatar
Tri Dao committed
352
                kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
353
354
            )
            v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
Tri Dao's avatar
Tri Dao committed
355
                kv[:, :, 1], "b s h d -> b h s d"
356
357
358
359
            )
        return kv


Tri Dao's avatar
Tri Dao committed
360
361
362
363
364
365
366
367
368
def _apply_rotary_single_query_attention(
    qkv,
    inference_params,
    layer_idx,
    rotary_emb_dim,
    rotary_emb_base,
    kv=None,
    rotary_emb_interleaved=False,
):
Tri Dao's avatar
Tri Dao committed
369
370
371
372
373
374
375
376
    """
    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
            q of shape (batch_size, 1, nheads, head_dim)
    kv: (batch_size, 1, 2, nheads_kv, head_dim)
    """
    assert inference_params.fused_ft_kernel
    assert ft_attention is not None
    if kv is None:
Tri Dao's avatar
Tri Dao committed
377
        q, k, v = rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1)
Tri Dao's avatar
Tri Dao committed
378
    else:
Tri Dao's avatar
Tri Dao committed
379
380
        q = rearrange(qkv, "b 1 h d -> b h d")
        k, v = rearrange(kv, "b 1 two h d -> b two h d").unbind(dim=1)
Tri Dao's avatar
Tri Dao committed
381
382
383
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + q.shape[0]
    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
Tri Dao's avatar
Tri Dao committed
384
385
386
387
388
    lengths_per_sample = (
        inference_params.lengths_per_sample[batch_start:batch_end]
        if inference_params.lengths_per_sample is not None
        else None
    )
Tri Dao's avatar
Tri Dao committed
389
    context = ft_attention.single_query_attention(
Tri Dao's avatar
Tri Dao committed
390
391
392
        q,
        k,
        v,
Tri Dao's avatar
Tri Dao committed
393
394
395
396
397
398
399
        k_cache[batch_start:batch_end],
        v_cache[batch_start:batch_end],
        lengths_per_sample,
        None,  # rotary_cos_
        None,  # rotary_sin_
        None,  # nnz_head_idx
        inference_params.sequence_len_offset,
Tri Dao's avatar
Tri Dao committed
400
401
402
        rotary_emb_dim,
        rotary_emb_base,
        not rotary_emb_interleaved,  # neox_rotary_style
Tri Dao's avatar
Tri Dao committed
403
    )
Tri Dao's avatar
Tri Dao committed
404
    return rearrange(context, "b h d -> b 1 h d")
Tri Dao's avatar
Tri Dao committed
405
406


407
class MHA(nn.Module):
Tri Dao's avatar
Tri Dao committed
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
            num_heads must be divisible by num_heads_kv.
        cross_attn: if True, build separate Wq / Wkv projections instead of a packed Wqkv.
            Not supported together with rotary embeddings.
        dwconv: if True, add depthwise Conv1d layers (kernel_size=3) on the projections.
        fused_bias_fc: use FusedDense for the linear layers (raises ImportError if unavailable).
        use_flash_attn: use FlashSelfAttention / FlashCrossAttention for the inner attention
            instead of the reference SelfAttention / CrossAttention.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        device, dtype: passed to the linear projections and the depthwise convolutions
            (the rotary embedding only receives ``device``).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # The depthwise convolutions honour device/dtype like every other parameterised
            # submodule, so they are created alongside the projections.
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim, **factory_kwargs
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim,
                    embed_dim,
                    kernel_size=3,
                    padding=2,
                    groups=embed_dim,
                    **factory_kwargs,
                )
                self.dwconv_kv = nn.Conv1d(
                    kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim, **factory_kwargs
                )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
504

505
506
507
508
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
Tri Dao's avatar
Tri Dao committed
509
510
511
512
513
514
515
516
517
            return torch.empty(
                batch_size,
                max_seqlen,
                2,
                self.num_heads_kv,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
518
519
520
521
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
Tri Dao's avatar
Tri Dao committed
522
523
524
525
526
527
528
529
530
531
532
533
            k_cache = torch.empty(
                batch_size,
                self.num_heads_kv,
                self.head_dim // packsize,
                max_seqlen,
                packsize,
                dtype=dtype,
                device=device,
            )
            v_cache = torch.empty(
                batch_size, self.num_heads_kv, max_seqlen, self.head_dim, dtype=dtype, device=device
            )
534
535
            return k_cache, v_cache

Tri Dao's avatar
Tri Dao committed
536
    def _update_kv_cache(self, kv, inference_params):
Tri Dao's avatar
Tri Dao committed
537
538
539
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
540
        return _update_kv_cache(kv, inference_params, self.layer_idx)
Tri Dao's avatar
Tri Dao committed
541

Tri Dao's avatar
Tri Dao committed
542
543
544
545
546
547
548
549
    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """Single-query (decode-step) attention via the module-level FT helper.

        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
              q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)

        Rotary settings come from ``self.rotary_emb`` when rotary is enabled
        (rotary_emb_dim > 0). Otherwise base 0 and non-interleaved are passed.
        """
        if self.rotary_emb_dim > 0:
            rotary_base = self.rotary_emb.base
            interleaved = self.rotary_emb.interleaved
        else:
            rotary_base, interleaved = 0, False
        return _apply_rotary_single_query_attention(
            qkv,
            inference_params,
            self.layer_idx,
            self.rotary_emb_dim,
            rotary_base,
            kv=kv,
            rotary_emb_interleaved=interleaved,
        )

561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention.

        q: queries for the new tokens (batch in dim 0).
        kv: new keys/values, indexed as kv[:, :, 0] / kv[:, :, 1], i.e. presumably
            (batch_size, seqlen, 2, nheads_kv, head_dim) -- confirm against callers.

        The prompt (sequence_len_offset == 0) goes through ``_update_kv_cache`` and
        ``inner_cross_attn``, as does every step when flash_attn_with_kvcache is unavailable or
        FlashAttention is disabled. Later steps call flash_attn_with_kvcache, which receives
        this batch window of the cache.
        """
        if (
            inference_params.sequence_len_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        # Address the same batch rows that _update_kv_cache writes to, so a non-zero
        # batch_size_offset selects the right cache entries and per-sample lengths.
        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][batch_start:batch_end]
        cache_seqlens = (
            inference_params.lengths_per_sample[batch_start:batch_end]
            if inference_params.lengths_per_sample is not None
            else inference_params.sequence_len_offset
        )
        return flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
        )

Tri Dao's avatar
Tri Dao committed
590
591
592
593
594
595
596
597
598
599
600
    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        Returns:
            out_proj applied to the attention context; when return_residual is set, the tuple
            (out, x) where x is the residual returned by the input projection.
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        # The flash path consumes varlen metadata; the non-flash path takes a padding mask.
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        # Rotary position offset: per-sample lengths when tracked, else one shared offset.
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.sequence_len_offset
            )
        )
        rotary_max_seqlen = (
            inference_params.max_sequence_len if inference_params is not None else None
        )
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                # Depthwise conv over the sequence axis; the trailing 2 outputs are dropped.
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        # Reentrant torch checkpoint rejects extra keyword arguments, and
                        # kwargs is never empty here, so bind them into the callable instead.
                        context = torch.utils.checkpoint.checkpoint(
                            partial(self.inner_attn, **kwargs), qkv
                        )
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                # Fused single-token decode step (fused_ft_kernel, offset > 0).
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                # Self-attention with fewer kv heads (MQA/GQA): split the packed projection.
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            if self.dwconv:
                # Apply the depthwise conv while q/kv are still flat (b, s, d) tensors, as in
                # the packed-QKV branch; after splitting into heads the 3-D "b s d" pattern
                # no longer matches and einops would raise.
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        # See above: bind kwargs so reentrant checkpointing accepts the call.
                        context = torch.utils.checkpoint.checkpoint(
                            partial(self.inner_cross_attn, **kwargs), q, kv
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
Tri Dao's avatar
Tri Dao committed
733
734
735


class ParallelMHA(nn.Module):
    """Tensor-parallel multi-head self-attention.

    The packed QKV projection is a ColumnParallelLinear and the output projection a
    RowParallelLinear over ``process_group``. Each rank holds ``num_heads_per_rank``
    query heads and ``num_heads_kv_per_rank`` key/value heads. Setting
    ``num_heads_kv < num_heads`` gives multi-query / grouped-query attention.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        # Number of query / kv heads owned by this rank (split decided by the helper).
        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        # Global width of the packed [q | k | v] projection.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
            **factory_kwargs,
        )
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        """Allocate an uninitialized KV cache for this rank's kv heads.

        dtype defaults to out_proj's weight dtype; the cache lives on out_proj's device.

        Returns:
            fused_ft_kernel=False: one tensor of shape
                (batch_size, max_seqlen, 2, num_heads_kv_per_rank, head_dim).
            fused_ft_kernel=True: a (k_cache, v_cache) pair with shapes
                (batch_size, num_heads_kv_per_rank, head_dim // packsize, max_seqlen, packsize)
                and (batch_size, num_heads_kv_per_rank, max_seqlen, head_dim), where
                packsize is 4 for float32 and 8 for float16/bfloat16.
        """
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(
                batch_size,
                max_seqlen,
                2,
                self.num_heads_kv_per_rank,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(
                batch_size,
                self.num_heads_kv_per_rank,
                self.head_dim // packsize,
                max_seqlen,
                packsize,
                dtype=dtype,
                device=device,
            )
            v_cache = torch.empty(
                batch_size,
                self.num_heads_kv_per_rank,
                max_seqlen,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
            return k_cache, v_cache

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
              q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)
        """
        # With rotary disabled, base 0 / non-interleaved are passed as placeholders.
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv,
            inference_params,
            self.layer_idx,
            self.rotary_emb_dim,
            rotary_emb_base,
            kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved
            if self.rotary_emb_dim > 0
            else False,
        )

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention.

        On the prompt step (sequence_len_offset == 0), or when flash_attn_with_kvcache is
        unavailable or flash attention is disabled, kv is written to the cache through
        _update_kv_cache and inner_cross_attn attends over the returned kv. Otherwise
        flash_attn_with_kvcache receives the cached k/v and the new k/v in a single call.
        """
        if (
            inference_params.sequence_len_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.sequence_len_offset
            )
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            inference_params: generation state (KV cache, offsets); None for training.
            **kwargs: forwarded to the inner attention module when inference_params is None.
        Returns:
            out_proj applied to the attention context. When seqlen is given, the context is
            flattened back to (batch * seqlen, hidden) before the projection.
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        # Rotary position offset: per-sample lengths when tracked, else one shared offset.
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.sequence_len_offset
            )
        )
        rotary_max_seqlen = (
            inference_params.max_sequence_len if inference_params is not None else None
        )
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        # Reentrant torch checkpoint rejects extra keyword arguments, so bind
                        # them into the callable instead of forwarding them to checkpoint().
                        context = torch.utils.checkpoint.checkpoint(
                            partial(self.inner_attn, **kwargs), qkv
                        )
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                # Fused single-token decode step (fused_ft_kernel, offset > 0).
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            # MQA/GQA: the local projection is [q (per-rank heads) | k, v (per-rank kv heads)].
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        # See above: bind kwargs so reentrant checkpointing accepts the call.
                        context = torch.utils.checkpoint.checkpoint(
                            partial(self.inner_cross_attn, **kwargs), q, kv
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out