# Copyright (c) 2023, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
8
from einops import rearrange, repeat
9

10
11
from flash_attn.utils.distributed import get_dim_for_local_rank

12
try:
Tri Dao's avatar
Tri Dao committed
13
14
15
16
17
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
18
        flash_attn_with_kvcache,
Tri Dao's avatar
Tri Dao committed
19
    )
20
except ImportError:
Tri Dao's avatar
Tri Dao committed
21
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
    flash_attn_with_kvcache = None
24
25

try:
Tri Dao's avatar
Tri Dao committed
26
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
27
except ImportError:
Tri Dao's avatar
Tri Dao committed
28
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
29
30
31
32
33
34
35

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# ALiBi slope schedule, adapted from
# https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    """Return a list of ``nheads`` ALiBi slopes, one per attention head.

    For a power-of-two head count ``n`` the slopes are the geometric sequence
    ``r, r**2, ..., r**n`` with ``r = 2 ** (-8 / n)``. For any other head count,
    the slopes for the largest power of two below ``nheads`` come first, followed
    by every other slope of the schedule for twice that power, truncated to the
    number of remaining heads.
    """

    def _power_of_two_slopes(n):
        # First slope and common ratio coincide: 2 ** (-(2 ** (3 - log2(n)))) == 2 ** (-8 / n).
        ratio = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [ratio * ratio**k for k in range(n)]

    exponent = math.log2(nheads)
    if exponent.is_integer():
        return _power_of_two_slopes(nheads)

    lower_pow2 = 2 ** math.floor(exponent)
    # Fill the leftover heads by sampling every other slope of the next schedule up.
    leftover = get_alibi_slopes(2 * lower_pow2)[0::2][: nheads - lower_pow2]
    return _power_of_two_slopes(lower_pow2) + leftover


53
54
55
56
57
58
59
60
61
62
class FlashSelfAttention(nn.Module):
    """Scaled dot-product self-attention computed by the FlashAttention kernels.

    Arguments
    ---------
        causal: default for ``causal`` when ``forward`` is called without it.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        alibi_slopes: optional per-head ALiBi slopes, forwarded to the kernels.
            Stored as a non-persistent buffer (kept out of ``state_dict``).
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
        super().__init__()
        # Both entry points come from an optional import; fail early if it was missing.
        for kernel in (flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func):
            assert kernel is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Run packed-QKV self-attention.

        Arguments
        ---------
            qkv: packed query/key/value. Shape (B, S, 3, H, D) when ``cu_seqlens`` is None,
                otherwise (total, 3, H, D) where total is the sum of the sequence lengths.
            causal: overrides ``self.causal`` when given.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. Cumulative sequence lengths
                used to index into qkv (variable-length mode).
            max_seqlen: int. Maximum sequence length in the batch (required with cu_seqlens).
        Returns:
        --------
            out: (total, H, D) in variable-length mode, else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        is_causal = self.causal if causal is None else causal
        # Dropout is applied inside the kernel, and only while training.
        dropout_p = self.drop.p if self.training else 0.0

        if cu_seqlens is None:
            return flash_attn_qkvpacked_func(
                qkv,
                dropout_p,
                softmax_scale=self.softmax_scale,
                causal=is_causal,
                alibi_slopes=self.alibi_slopes,
            )

        assert cu_seqlens.dtype == torch.int32
        assert max_seqlen is not None
        assert isinstance(max_seqlen, int)
        return flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens,
            max_seqlen,
            dropout_p,
            softmax_scale=self.softmax_scale,
            causal=is_causal,
            alibi_slopes=self.alibi_slopes,
        )
115
116
117
118
119
120
121
122
123
124
125
126


class FlashCrossAttention(nn.Module):
    """Scaled dot-product cross-attention computed by the FlashAttention kernels.

    Arguments
    ---------
        causal: default for ``causal`` when ``forward`` is called without it.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        alibi_slopes: optional per-head ALiBi slopes, forwarded to the kernels.
            Stored as a non-persistent buffer (kept out of ``state_dict``).
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        Returns:
        --------
            The kernel output; in the padded case this is expected to be (B, Sq, H, D).
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Validate the key-side length as well, mirroring the query-side check above.
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
            )
        else:
            batch_size = q.shape[0]
            # kv must share q's batch size and head dimension (the head count may differ: MQA/GQA).
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
            )
196
197
198
199
200
201
202
203
204
205
206
207


class SelfAttention(nn.Module):
    """Reference (non-fused) scaled dot-product self-attention in plain PyTorch.

    Arguments
    ---------
        causal: default for ``causal`` when ``forward`` is called without it.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Attend over a packed QKV tensor.

        Arguments
        ---------
            qkv: packed query/key/value, (B, S, 3, H, D).
            causal: overrides ``self.causal`` when given.
            key_padding_mask: boolean (B, S); True keeps a key, False masks it out.
        Returns:
        --------
            (B, S, H, D) attention output.
        """
        bsz, seq = qkv.shape[0], qkv.shape[1]
        if causal is None:
            causal = self.causal
        query, key, value = qkv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", query, key * scale)

        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, -10000 for masked ones. Adding a bias was
            # measured faster than masked_fill_ (TD, 2022-09-30).
            key_bias = torch.full((bsz, seq), -10000.0, dtype=scores.dtype, device=scores.device)
            key_bias.masked_fill_(key_padding_mask, 0.0)
            scores = scores + key_bias[:, None, None, :]

        if causal:
            # Build in float then cast: "triu_tril_cuda_template" is not implemented for BFloat16.
            future_bias = torch.triu(torch.full((seq, seq), -10000.0, device=scores.device), 1)
            scores = scores + future_bias.to(dtype=scores.dtype)

        probs = torch.softmax(scores, dim=-1, dtype=value.dtype)
        return torch.einsum("bhts,bshd->bthd", self.drop(probs), value)


class CrossAttention(nn.Module):
    """Reference (non-fused) scaled dot-product cross-attention in plain PyTorch.

    Arguments
    ---------
        causal: default for ``causal`` when ``forward`` is called without it.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Attend from queries ``q`` over packed keys/values ``kv``.

        Arguments
        ---------
            q: queries, (B, Sq, H, D).
            kv: packed keys and values, (B, Sk, 2, H_k, D). When H_k != H, each kv head is
                repeated H // H_k times (MQA/GQA).
            causal: overrides ``self.causal`` when given.
            key_padding_mask: boolean (B, Sk); True keeps a key, False masks it out.
        Returns:
        --------
            (B, Sq, H, D) attention output.
        """
        bsz, len_q = q.shape[0], q.shape[1]
        if causal is None:
            causal = self.causal
        len_k = kv.shape[1]
        assert kv.shape[0] == bsz and kv.shape[4] == q.shape[3]
        n_q_heads, n_kv_heads = q.shape[2], kv.shape[3]
        if n_kv_heads != n_q_heads:  # MQA/GQA: share each kv head across a group of q heads
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=n_q_heads // n_kv_heads)
        key, value = kv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, key * scale)

        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, -10000 for masked ones.
            key_bias = torch.full((bsz, len_k), -10000.0, dtype=scores.dtype, device=scores.device)
            key_bias.masked_fill_(key_padding_mask, 0.0)
            scores = scores + key_bias[:, None, None, :]

        if causal:
            # Align the causal diagonal with the end of the valid keys: query i may see key j
            # only when j <= i + (num_valid_keys - len_q).
            rows = torch.arange(len_q, device=q.device, dtype=torch.long)[:, None]
            cols = torch.arange(len_k, device=kv.device, dtype=torch.long)
            if key_padding_mask is None:
                valid_k = len_k
            else:
                valid_k = key_padding_mask.sum(-1)[:, None, None, None]
            scores = scores.masked_fill(cols > rows + valid_k - len_q, -10000.0)

        probs = torch.softmax(scores, dim=-1, dtype=value.dtype)
        return torch.einsum("bhts,bshd->bthd", self.drop(probs), value)


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.

    ``forward`` returns ``(linear(input), input)``, so it can stand in wherever a
    ``FusedDense(..., return_residual=True)`` layer would be used.
    """

    # The annotation is a string so the builtin-generic syntax is never evaluated at runtime.
    def forward(self, input: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        return super().forward(input), input


319
def _update_kv_cache(kv, inference_params, layer_idx):
Tri Dao's avatar
Tri Dao committed
320
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
321
322
323
324
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
Tri Dao's avatar
Tri Dao committed
325
            inference_params.max_batch_size,
326
            inference_params.max_seqlen,
Tri Dao's avatar
Tri Dao committed
327
328
329
330
331
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
332
333
334
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
335
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
336
337
338
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
339
    sequence_start = inference_params.seqlen_offset
340
    sequence_end = sequence_start + kv.shape[1]
341
342
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
343
344
345
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
Tri Dao's avatar
Tri Dao committed
346
347


348
class MHA(nn.Module):
Tri Dao's avatar
Tri Dao committed
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        embed_dim: model width; must be divisible by num_heads (head_dim = embed_dim // num_heads).
        num_heads: number of query heads.
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
            num_heads must be divisible by num_heads_kv.
        cross_attn: if True, build separate Wq / Wkv projections; otherwise one fused Wqkv.
        qkv_proj_bias / out_proj_bias: whether the input / output projections have a bias.
        dropout, softmax_scale, causal: forwarded to the inner attention modules.
        layer_idx: key of this layer's kv cache; required for generation.
        dwconv: add a depthwise Conv1d (kernel_size=3) over the projected q/k/v.
        rotary_emb_dim, rotary_emb_base, rotary_emb_scale_base, rotary_emb_interleaved:
            rotary embedding config (enabled when rotary_emb_dim > 0; not with cross_attn).
        use_alibi: use ALiBi slopes; requires use_flash_attn.
        fused_bias_fc: use FusedDense instead of nn.Linear for the projections.
        use_flash_attn: use the FlashAttention inner modules instead of the PyTorch reference ones.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        checkpointing: activation-checkpoint the inner self-attention in forward.
        device, dtype: factory kwargs for the created parameters.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            # FlashAttention requires fp32 slopes; pin the dtype so a changed
            # torch.get_default_dtype() cannot alter it. NOTE: module-wide casts such as
            # .half() also cast buffers, so they would still change it.
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads), dtype=torch.float32, device=device
            )
        else:
            alibi_slopes = None

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # padding=2 with the trailing 2 outputs trimmed in forward gives a causal depthwise
            # conv. The convs now honor device/dtype like every other sublayer.
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim, **factory_kwargs
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim,
                    embed_dim,
                    kernel_size=3,
                    padding=2,
                    groups=embed_dim,
                    **factory_kwargs,
                )
                self.dwconv_kv = nn.Conv1d(
                    kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim, **factory_kwargs
                )
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
461

462
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate an uninitialized kv cache for this layer.

        Returns a tensor of shape (batch_size, max_seqlen, 2, num_heads_kv, head_dim) on the
        device of ``out_proj.weight``. ``dtype`` defaults to ``out_proj.weight.dtype``.
        """
        weight = self.out_proj.weight
        cache_shape = (batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim)
        return torch.empty(
            cache_shape,
            dtype=weight.dtype if dtype is None else dtype,
            device=weight.device,
        )
474

Tri Dao's avatar
Tri Dao committed
475
    def _update_kv_cache(self, kv, inference_params):
        """Store ``kv`` in this layer's slot of ``inference_params``; return the cached prefix.

        kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
        Delegates to the module-level ``_update_kv_cache`` keyed by ``self.layer_idx``.
        """
        layer = self.layer_idx
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert layer is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, layer)
Tri Dao's avatar
Tri Dao committed
480

481
    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        Only valid after the first step (seqlen_offset > 0) and with use_flash_attn.
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            # Make sure the cos/sin tables cover every position the cache can hold.
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        # Per-sample lengths when provided, otherwise one shared offset for the whole batch.
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        # Forward the ALiBi slopes too; without them cached decoding silently drops the
        # positional bias that the non-cached FlashAttention path applies.
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context
Tri Dao's avatar
Tri Dao committed
518

519
    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention.

        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)

        On the first step (seqlen_offset == 0), or when the kv-cache kernel is unavailable or
        FlashAttention is disabled, the cache is written via ``_update_kv_cache`` and the inner
        cross-attention module is applied. Otherwise ``flash_attn_with_kvcache`` is called with
        the cache slices and the new k/v.
        """
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            # Keep ALiBi consistent with the non-cached FlashAttention path.
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )

Tri Dao's avatar
Tri Dao committed
548
549
550
551
552
553
554
555
556
557
558
    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
559
560
        """
        Arguments:
Tri Dao's avatar
Tri Dao committed
561
562
563
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the is the sum of the sequence lengths in the batch.
564
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
Tri Dao's avatar
Tri Dao committed
565
566
567
568
569
570
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
571
572
573
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
Tri Dao's avatar
Tri Dao committed
574
575
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
576
        """
Tri Dao's avatar
Tri Dao committed
577
578
579
580
581
582
583
584
585
586
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
Tri Dao's avatar
Tri Dao committed
587
588
589
590
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv
Tri Dao's avatar
Tri Dao committed
591

Tri Dao's avatar
Tri Dao committed
592
593
594
595
596
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
597
598
599
600
601
602
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
603
                else inference_params.seqlen_offset
604
605
            )
        )
606
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
607
        batch, seqlen = x.shape[:2]
Tri Dao's avatar
Tri Dao committed
608
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
609
            assert x_kv is None and mixer_subset is None
610
611
612
613
614
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
Tri Dao's avatar
Tri Dao committed
615
616
617
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
618
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
Tri Dao's avatar
Tri Dao committed
619
620
            if (
                inference_params is None
621
                or inference_params.seqlen_offset == 0
622
623
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
Tri Dao's avatar
Tri Dao committed
624
            ):
Tri Dao's avatar
Tri Dao committed
625
                if self.rotary_emb_dim > 0:
626
627
628
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
Tri Dao's avatar
Tri Dao committed
629
630
631
632
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
Tri Dao's avatar
Tri Dao committed
633
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
Tri Dao's avatar
Tri Dao committed
634
                else:
635
636
637
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
Tri Dao's avatar
Tri Dao committed
638
            else:
639
640
641
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
642
        else:
Tri Dao's avatar
Tri Dao committed
643
644
645
646
647
648
649
650
651
652
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
653
            else:
Tri Dao's avatar
Tri Dao committed
654
655
656
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
657
                else:
Tri Dao's avatar
Tri Dao committed
658
                    qkv, x = self.Wqkv(x)
Tri Dao's avatar
Tri Dao committed
659
660
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
661
662
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
663
            if self.dwconv:
Tri Dao's avatar
Tri Dao committed
664
665
666
667
668
669
670
671
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
672
                or inference_params.seqlen_offset == 0
673
674
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
Tri Dao's avatar
Tri Dao committed
675
            ):
Tri Dao's avatar
Tri Dao committed
676
                if self.rotary_emb_dim > 0:
677
678
679
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
Tri Dao's avatar
Tri Dao committed
680
681
682
683
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
Tri Dao's avatar
Tri Dao committed
684
685
686
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
Tri Dao's avatar
Tri Dao committed
687
                else:
688
                    context = self._update_kvcache_attention(q, kv, inference_params)
689
            else:
690
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
691
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
692
        return out if not self.return_residual else (out, x)
Tri Dao's avatar
Tri Dao committed
693
694
695


class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention, tensor-parallel over ``process_group``.

    The fused QKV projection is a ColumnParallelLinear and the output projection is a
    RowParallelLinear, so each rank computes attention only for its local
    ``num_heads_per_rank`` query heads and ``num_heads_kv_per_rank`` key/value heads.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            # Slice the global slopes with the same sharding rule used for
            # num_heads_per_rank: this rank's first head comes right after every head
            # owned by lower ranks. A fixed ceil(num_heads / world_size) stride would
            # misalign (and return the wrong number of slopes) whenever world_size
            # does not evenly divide num_heads.
            head_start = sum(
                get_dim_for_local_rank(self.num_heads, self.world_size, rank)
                for rank in range(self.local_rank)
            )
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[head_start : head_start + self.num_heads_per_rank],
                device=device,
            )
        else:
            alibi_slopes = None

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            # Shard in whole groups: each kv head with its query heads.
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
            if use_flash_attn
            else CrossAttention
        )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate an uninitialized KV cache for this rank's local kv heads.

        Returns a tensor of shape
        (batch_size, max_seqlen, 2, num_heads_kv_per_rank, head_dim) on the device of
        ``out_proj``. It uses ``dtype`` if given, otherwise ``out_proj``'s weight dtype.
        """
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        # Delegates to the module-level helper of the same name.
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        # The inner attention was built with this rank's ALiBi slopes (see __init__).
        # Forward them so cached decoding applies the same bias as the prefill path.
        # Falls back to None when the module carries no slopes.
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            # Keep the ALiBi bias consistent with the non-cached path above, which goes
            # through inner_cross_attn (constructed with this rank's slopes).
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            seqlen: sequence length used to unflatten / reflatten the (batch * seqlen) dimension.
            inference_params: for generation. When given, keys/values are written to its KV
                cache and attention runs against the cache.
            **kwargs: forwarded to the inner attention module (only used when
                inference_params is None).
        Returns:
            The output of ``out_proj`` applied to the attention context.
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        # The fused rotary + kv-cache kernel is only used for decoding steps
        # (seqlen_offset > 0) with flash-attn and a rotary dim that is a multiple of 16.
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            # Grouped-query / multi-query attention: the local projection output is the
            # query heads followed by packed (k, v) heads.
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out