# Copyright (c) 2023, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

from flash_attn.utils.distributed import get_dim_for_local_rank

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
        )
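
# Worked example (illustrative): for nheads = 8 (a power of 2) the slopes form the geometric
# sequence 2^-1, 2^-2, ..., 2^-8, i.e.
#   get_alibi_slopes(8) == [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
# For non-powers of 2, slopes from the closest lower power of 2 are interleaved with slopes
# computed for twice that size.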


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                deterministic=self.deterministic,
            )
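
# Example usage (illustrative sketch; assumes flash-attn is installed and a CUDA device is available):
#   attn = FlashSelfAttention(causal=True, attention_dropout=0.1)
#   qkv = torch.randn(2, 1024, 3, 16, 64, dtype=torch.float16, device="cuda")  # (B, S, 3, H, D)
#   out = attn(qkv)  # (B, S, H, D)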


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                deterministic=self.deterministic,
            )
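
# Example usage (illustrative sketch; assumes flash-attn is installed and a CUDA device is available):
#   xattn = FlashCrossAttention(causal=False)
#   q = torch.randn(2, 128, 16, 64, dtype=torch.float16, device="cuda")     # (B, Sq, H, D)
#   kv = torch.randn(2, 512, 2, 4, 64, dtype=torch.float16, device="cuda")  # (B, Sk, 2, H_k, D), GQA with 4 kv heads
#   out = xattn(q, kv)  # (B, Sq, H, D)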


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
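
# Example usage (illustrative sketch; shapes are arbitrary, runs on CPU or GPU):
#   attn = SelfAttention(causal=True, attention_dropout=0.1)
#   qkv = torch.randn(2, 128, 3, 8, 64)           # (B, S, 3, H, D)
#   mask = torch.ones(2, 128, dtype=torch.bool)   # True = keep, False = mask out
#   out = attn(qkv, key_padding_mask=mask)        # (B, S, H, D)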


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
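
# Example usage (illustrative sketch; the kv heads are repeated internally for MQA/GQA):
#   xattn = CrossAttention(causal=False)
#   q = torch.randn(2, 64, 8, 64)        # (B, Sq, H, D)
#   kv = torch.randn(2, 256, 2, 2, 64)   # (B, Sk, 2, H_k, D)
#   out = xattn(q, kv)                   # (B, Sq, H, D)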


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
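
# Example (illustrative): after prefilling 10 tokens and then decoding 1 more token for layer_idx=0,
# inference_params.key_value_memory_dict[0] holds the pre-allocated
# (max_batch_size, max_seqlen, 2, nheads, head_dim) buffer, and the slice returned above
# covers positions [0:11] of the current batch.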


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architectures, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
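
# Example usage (illustrative sketch; hyperparameters are arbitrary and assume flash-attn is installed):
#   mha = MHA(embed_dim=1024, num_heads=16, num_heads_kv=4, causal=True, rotary_emb_dim=64,
#             use_flash_attn=True, device="cuda", dtype=torch.float16)
#   x = torch.randn(2, 512, 1024, dtype=torch.float16, device="cuda")
#   y = mha(x)  # (2, 512, 1024)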


class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            num_heads_local = math.ceil(self.num_heads / self.world_size)
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[
                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
                ],
                device=device,
            )
        else:
            alibi_slopes = None

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
            if use_flash_attn
            else CrossAttention
        )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out
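
# Example usage (illustrative sketch; assumes torch.distributed is initialized, e.g. via torchrun,
# and that fused_dense / flash-attn are installed):
#   process_group = torch.distributed.new_group()
#   pmha = ParallelMHA(embed_dim=1024, num_heads=16, process_group=process_group, causal=True,
#                      rotary_emb_dim=64, use_flash_attn=True, device="cuda", dtype=torch.bfloat16)
#   x = torch.randn(2, 512, 1024, dtype=torch.bfloat16, device="cuda")
#   y = pmha(x)  # (2, 512, 1024); the output projection is reduced across the tensor-parallel group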