mha.py 42.1 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
# Copyright (c) 2023, Tri Dao.
2
3
4
5
6
7

import math
from functools import partial

import torch
import torch.nn as nn
Tri Dao's avatar
Tri Dao committed
8
from einops import rearrange, repeat
9

10
11
from flash_attn.utils.distributed import get_dim_for_local_rank

12
try:
Tri Dao's avatar
Tri Dao committed
13
14
15
16
17
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
18
        flash_attn_with_kvcache,
Tri Dao's avatar
Tri Dao committed
19
    )
20
except ImportError:
Tri Dao's avatar
Tri Dao committed
21
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
    flash_attn_with_kvcache = None
24
25

try:
Tri Dao's avatar
Tri Dao committed
26
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
27
except ImportError:
Tri Dao's avatar
Tri Dao committed
28
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
29
30
31
32
33
34
35

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    """Return the list of ALiBi slopes, one float per attention head.

    When ``nheads`` is a power of two the slopes form a geometric sequence whose first
    term and ratio are both ``2 ** (-(2 ** -(log2(nheads) - 3)))``. Otherwise the slopes
    for the largest power of two below ``nheads`` come first, followed by every other
    slope computed for twice that power of two, truncated so that exactly ``nheads``
    values are returned.
    """

    def _geometric_slopes(count):
        # First term and common ratio are the same value.
        base = 2 ** (-(2 ** -(math.log2(count) - 3)))
        return [base * base**exponent for exponent in range(count)]

    if not math.log2(nheads).is_integer():
        lower_pow2 = 2 ** math.floor(math.log2(nheads))
        n_extra = nheads - lower_pow2
        leading = _geometric_slopes(lower_pow2)
        interleaved = get_alibi_slopes(2 * lower_pow2)[0::2]
        return leading + interleaved[:n_extra]
    return _geometric_slopes(nheads)


53
54
55
56
57
58
59
60
61
62
class FlashSelfAttention(nn.Module):
    """Scaled dot-product self-attention with softmax, dispatched to the FlashAttention
    packed-QKV kernels.

    Arguments
    ---------
        causal: default value of the ``causal`` flag handed to the kernel; can be
            overridden per call in ``forward``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        window_size: forwarded unchanged to the kernel as ``window_size``.
        alibi_slopes: optional tensor, registered as a non-persistent buffer and
            forwarded to the kernel as ``alibi_slopes``.
        deterministic: forwarded unchanged to the kernel as ``deterministic``.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        # Only the dropout probability is read from this module; the kernel does the dropout.
        self.drop = nn.Dropout(attention_dropout)
        # Non-persistent: the slopes are not written to the state dict.
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in (torch.float16, torch.bfloat16)
        assert qkv.is_cuda
        if causal is None:
            causal = self.causal
        # Dropout is only active in training mode.
        dropout_p = self.drop.p if self.training else 0.0
        kernel_kwargs = dict(
            softmax_scale=self.softmax_scale,
            causal=causal,
            alibi_slopes=self.alibi_slopes,
            window_size=self.window_size,
            deterministic=self.deterministic,
        )
        if cu_seqlens is None:
            # Padded (B, S, 3, H, D) input.
            return flash_attn_qkvpacked_func(qkv, dropout_p, **kernel_kwargs)
        # Unpadded (total, 3, H, D) input indexed through cu_seqlens.
        assert cu_seqlens.dtype == torch.int32
        assert max_seqlen is not None
        assert isinstance(max_seqlen, int)
        return flash_attn_varlen_qkvpacked_func(
            qkv, cu_seqlens, max_seqlen, dropout_p, **kernel_kwargs
        )
129
130
131
132
133
134
135
136
137
138
139
140


class FlashCrossAttention(nn.Module):
    """Scaled dot-product cross-attention with softmax, dispatched to the FlashAttention
    packed-KV kernels.

    Arguments
    ---------
        causal: default value of the ``causal`` flag handed to the kernel; can be
            overridden per call in ``forward``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        alibi_slopes: optional tensor, registered as a non-persistent buffer and
            forwarded to the kernel as ``alibi_slopes``.
        window_size: forwarded unchanged to the kernel as ``window_size``.
        deterministic: forwarded unchanged to the kernel as ``deterministic``.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        # Only the dropout probability is read from this module; the kernel does the dropout.
        self.drop = nn.Dropout(attention_dropout)
        # Non-persistent: the slopes are not written to the state dict.
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        # Dropout is only active in training mode.
        dropout_p = self.drop.p if self.training else 0.0
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Validate the key-side length (the original re-checked max_seqlen here).
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                dropout_p,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            # Padded path: batch size and head dimension of q and kv must agree.
            batch_size = q.shape[0]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                dropout_p,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
224
225
226
227
228
229
230
231
232
233
234
235


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax (plain PyTorch reference path).

    Arguments
    ---------
        causal: default for the causal mask; can be overridden per call in ``forward``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime). Only ``None`` selects the default; an explicit
                      value, including 0.0, is used as given.
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        Returns:
        --------
            output: (B, S, H, D)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        # Test against None rather than truthiness so that an explicit scale of 0.0 is
        # honoured, matching how the FlashAttention classes pass softmax_scale through.
        softmax_scale = (
            1.0 / math.sqrt(q.shape[-1]) if self.softmax_scale is None else self.softmax_scale
        )
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            # Additive mask: 0 where the key is kept, -10000 where it is padded out.
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax (plain PyTorch reference path).

    Arguments
    ---------
        causal: default for the causal mask; can be overridden per call in ``forward``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime). Only ``None`` selects the default; an explicit
                      value, including 0.0, is used as given.
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        Returns:
        --------
            output: (B, Sq, H, D)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            # Repeat each kv head so that the head count matches q.
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        # Test against None rather than truthiness so that an explicit scale of 0.0 is
        # honoured, matching how the FlashAttention classes pass softmax_scale through.
        softmax_scale = (
            1.0 / math.sqrt(q.shape[-1]) if self.softmax_scale is None else self.softmax_scale
        )
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            # Additive mask: 0 where the key is kept, -10000 where it is padded out.
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            # With a padding mask, the key length is the per-sample count of kept keys.
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(linear(input), input)``: the projection plus the unmodified input.

        The annotation is a string so that it is not evaluated at definition time.
        """
        return super().forward(input), input


347
def _update_kv_cache(kv, inference_params, layer_idx):
Tri Dao's avatar
Tri Dao committed
348
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
349
350
351
352
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
Tri Dao's avatar
Tri Dao committed
353
            inference_params.max_batch_size,
354
            inference_params.max_seqlen,
Tri Dao's avatar
Tri Dao committed
355
356
357
358
359
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
360
361
362
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
363
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
364
365
366
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
367
    sequence_start = inference_params.seqlen_offset
368
    sequence_end = sequence_start + kv.shape[1]
369
370
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
371
372
373
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
Tri Dao's avatar
Tri Dao committed
374
375


376
class MHA(nn.Module):
Tri Dao's avatar
Tri Dao committed
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        embed_dim: input/output width of the projections; must be divisible by num_heads.
        num_heads: number of query heads.
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
            num_heads must be divisible by num_heads_kv.
        cross_attn: if True, create separate Wq and Wkv projections instead of a fused Wqkv.
        qkv_proj_bias, out_proj_bias: whether the input / output projections have a bias.
        dropout, softmax_scale, causal: passed to the inner attention modules.
        layer_idx: stored on the module; the KV-cache update asserts it is not None.
        dwconv: if True, create depthwise Conv1d layers (kernel_size=3, padding=2) for the
            projected q/k/v.
        rotary_emb_dim: if > 0, create a RotaryEmbedding of this dimension (not supported
            together with cross_attn).
        rotary_emb_base, rotary_emb_scale_base, rotary_emb_interleaved: passed to
            RotaryEmbedding.
        use_alibi: if True, build ALiBi slopes for num_heads; requires use_flash_attn.
        window_size: any value other than (-1, -1) requires use_flash_attn; passed to the
            FlashAttention modules.
        fused_bias_fc: use FusedDense instead of nn.Linear; raises ImportError if FusedDense
            is unavailable.
        use_flash_attn: select FlashSelfAttention / FlashCrossAttention instead of
            SelfAttention / CrossAttention.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        checkpointing: stored on the module.
        device, dtype: factory arguments for the projection and depthwise-conv layers;
            device is also used for the rotary embedding and the ALiBi slopes.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        # Widths of the fused projections: q uses num_heads heads, k and v use num_heads_kv each.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        # The projection that consumes x also hands x back when return_residual is set.
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # The convs receive the same device/dtype as the linear layers so that the whole
            # module is created consistently when device/dtype are given.
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim, **factory_kwargs
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim,
                    embed_dim,
                    kernel_size=3,
                    padding=2,
                    groups=embed_dim,
                    **factory_kwargs,
                )
                self.dwconv_kv = nn.Conv1d(
                    kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim, **factory_kwargs
                )
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
492

493
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Return an uninitialized KV-cache tensor.

        The tensor has shape (batch_size, max_seqlen, 2, num_heads_kv, head_dim) and is
        created on the device of ``out_proj.weight``. ``dtype`` defaults to the dtype of
        that weight.
        """
        reference_weight = self.out_proj.weight
        if dtype is None:
            dtype = reference_weight.dtype
        cache_shape = (batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim)
        return torch.empty(cache_shape, dtype=dtype, device=reference_weight.device)
505

Tri Dao's avatar
Tri Dao committed
506
    def _update_kv_cache(self, kv, inference_params):
        """Write ``kv`` into this layer's entry of the inference cache and return the result
        of the module-level ``_update_kv_cache``.

        kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
        """
        # Generation preconditions: no depthwise conv, and the layer must know its index.
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        cache_key = self.layer_idx
        return _update_kv_cache(kv, inference_params, cache_key)
Tri Dao's avatar
Tri Dao committed
511

512
    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply
        attention, all through a single ``flash_attn_with_kvcache`` call.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        uses_rotary = self.rotary_emb_dim > 0
        rotary_cos = rotary_sin = None
        if uses_rotary:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            # Refresh the cos/sin tables up to the maximum generation length before reading them.
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos = self.rotary_emb._cos_cached
            rotary_sin = self.rotary_emb._sin_cached
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        # Per-sample lengths take precedence over the shared sequence offset.
        if inference_params.lengths_per_sample is None:
            cache_seqlens = inference_params.seqlen_offset
        else:
            cache_seqlens = inference_params.lengths_per_sample[:batch]
        # Only the FlashAttention cross-attention module carries an alibi_slopes buffer.
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        k_cache, v_cache = kv_cache[:, :, 0], kv_cache[:, :, 1]
        k_new, v_new = kv[:, :, 0], kv[:, :, 1]
        return flash_attn_with_kvcache(
            q,
            k_cache,
            v_cache,
            k_new,
            v_new,
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if uses_rotary else False,
            alibi_slopes=alibi_slopes,
        )
Tri Dao's avatar
Tri Dao committed
551

552
    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention.

        On the first step (seqlen_offset == 0), or when the fused kernel is unavailable or
        FlashAttention is disabled, the cache is updated in Python and the inner
        cross-attention module is called. Otherwise the cache and the new kv are handed to
        ``flash_attn_with_kvcache`` in a single call.
        """
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            cached_kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, cached_kv)

        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        # Per-sample lengths take precedence over the shared sequence offset.
        if inference_params.lengths_per_sample is None:
            cache_seqlens = inference_params.seqlen_offset
        else:
            cache_seqlens = inference_params.lengths_per_sample[:batch]
        # Only the FlashAttention cross-attention module carries an alibi_slopes buffer.
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        k_cache, v_cache = kv_cache[:, :, 0], kv_cache[:, :, 1]
        k_new, v_new = kv[:, :, 0], kv[:, :, 1]
        return flash_attn_with_kvcache(
            q,
            k_cache,
            v_cache,
            k_new,
            v_new,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            alibi_slopes=alibi_slopes,
        )

Tri Dao's avatar
Tri Dao committed
583
584
585
586
587
588
589
590
591
592
593
    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
594
595
        """
        Arguments:
Tri Dao's avatar
Tri Dao committed
596
597
598
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
599
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
Tri Dao's avatar
Tri Dao committed
600
601
602
603
604
605
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
606
607
608
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
Tri Dao's avatar
Tri Dao committed
609
610
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
611
        """
Tri Dao's avatar
Tri Dao committed
612
613
614
615
616
617
618
619
620
621
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
Tri Dao's avatar
Tri Dao committed
622
623
624
625
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv
Tri Dao's avatar
Tri Dao committed
626

Tri Dao's avatar
Tri Dao committed
627
628
629
630
631
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
632
633
634
635
636
637
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
638
                else inference_params.seqlen_offset
639
640
            )
        )
641
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
642
        batch, seqlen = x.shape[:2]
Tri Dao's avatar
Tri Dao committed
643
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
644
            assert x_kv is None and mixer_subset is None
645
646
647
648
649
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
Tri Dao's avatar
Tri Dao committed
650
651
652
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
653
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
Tri Dao's avatar
Tri Dao committed
654
655
            if (
                inference_params is None
656
                or inference_params.seqlen_offset == 0
657
658
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
Tri Dao's avatar
Tri Dao committed
659
            ):
Tri Dao's avatar
Tri Dao committed
660
                if self.rotary_emb_dim > 0:
661
662
663
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
Tri Dao's avatar
Tri Dao committed
664
665
666
667
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
Tri Dao's avatar
Tri Dao committed
668
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
Tri Dao's avatar
Tri Dao committed
669
                else:
670
671
672
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
Tri Dao's avatar
Tri Dao committed
673
            else:
674
675
676
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
677
        else:
Tri Dao's avatar
Tri Dao committed
678
679
680
681
682
683
684
685
686
687
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
688
            else:
Tri Dao's avatar
Tri Dao committed
689
690
691
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
692
                else:
Tri Dao's avatar
Tri Dao committed
693
                    qkv, x = self.Wqkv(x)
Tri Dao's avatar
Tri Dao committed
694
695
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
696
697
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
698
            if self.dwconv:
Tri Dao's avatar
Tri Dao committed
699
700
701
702
703
704
705
706
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
707
                or inference_params.seqlen_offset == 0
708
709
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
Tri Dao's avatar
Tri Dao committed
710
            ):
Tri Dao's avatar
Tri Dao committed
711
                if self.rotary_emb_dim > 0:
712
713
714
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
Tri Dao's avatar
Tri Dao committed
715
716
717
718
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
Tri Dao's avatar
Tri Dao committed
719
720
721
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
Tri Dao's avatar
Tri Dao committed
722
                else:
723
                    context = self._update_kvcache_attention(q, kv, inference_params)
724
            else:
725
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
726
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
727
        return out if not self.return_residual else (out, x)
Tri Dao's avatar
Tri Dao committed
728
729
730


class ParallelMHA(nn.Module):
Tri Dao's avatar
Tri Dao committed
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        """Build the tensor-parallel attention layer for the local rank.

        Arguments:
            embed_dim: model width; must be divisible by num_heads.
            num_heads: total number of query heads across all ranks.
            process_group: process group over which the projections are sharded. Its size
                and this process's rank in it determine how many heads live on this rank.
            num_heads_kv: total number of key/value heads. Defaults to num_heads; num_heads
                must be divisible by it.
            qkv_proj_bias: whether the fused QKV projection has a bias.
            out_proj_bias: whether the output projection has a bias.
            dropout: passed to the inner attention modules as attention_dropout.
            softmax_scale: passed to the inner attention modules.
            causal: passed to the inner attention modules.
            layer_idx: index of this layer; required for generation (see _update_kv_cache).
            rotary_emb_dim: if > 0, a RotaryEmbedding of this dimension is created.
            rotary_emb_base, rotary_emb_scale_base, rotary_emb_interleaved: forwarded to
                RotaryEmbedding as base, scale_base and interleaved.
            use_alibi: if True, compute the ALiBi slopes of the heads owned by this rank and
                hand them to the FlashAttention modules. Requires use_flash_attn.
            window_size: forwarded to the FlashAttention modules. Any value other than
                (-1, -1) requires use_flash_attn.
            use_flash_attn: select FlashSelfAttention / FlashCrossAttention instead of
                SelfAttention / CrossAttention.
            checkpointing: if True, forward() wraps the inner attention call in
                torch.utils.checkpoint.checkpoint.
            sequence_parallel: forwarded to ColumnParallelLinear and RowParallelLinear.
            device, dtype: factory arguments for the parallel linear layers; device is also
                used for the rotary embedding and the ALiBi slopes.

        Raises:
            ImportError: if the fused_dense parallel linear layers could not be imported.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            # Slice the global slope table with the same per-rank head partition that yields
            # num_heads_per_rank. A fixed ceil(num_heads / world_size) chunk only agrees with
            # that partition when num_heads divides evenly; otherwise later ranks would get
            # the wrong number of slopes (and slopes belonging to other ranks' heads).
            head_start = sum(
                get_dim_for_local_rank(self.num_heads, self.world_size, rank)
                for rank in range(self.local_rank)
            )
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[head_start : head_start + self.num_heads_per_rank],
                device=device,
            )
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        # multiple_of makes the column split fall on whole groups of
        # (num_heads // num_heads_kv) query heads plus one key and one value head.
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )
Tri Dao's avatar
Tri Dao committed
846

847
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
Tri Dao's avatar
Tri Dao committed
848
849
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
850
851
852
853
854
855
856
857
858
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )
Tri Dao's avatar
Tri Dao committed
859
860

    def _update_kv_cache(self, kv, inference_params):
Tri Dao's avatar
Tri Dao committed
861
862
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
Tri Dao's avatar
Tri Dao committed
863
864
        return _update_kv_cache(kv, inference_params, self.layer_idx)

865
    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
Tri Dao's avatar
Tri Dao committed
866
        """
867
868
869
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
Tri Dao's avatar
Tri Dao committed
870
        """
871
        assert inference_params is not None and inference_params.seqlen_offset > 0
872
873
874
875
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
876
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
877
878
879
880
881
882
883
884
885
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
886
            else inference_params.seqlen_offset
887
        )
888
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
889
890
891
892
893
894
895
896
897
898
899
900
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
901
            alibi_slopes=alibi_slopes,
Tri Dao's avatar
Tri Dao committed
902
        )
903
        return context
Tri Dao's avatar
Tri Dao committed
904

905
    def _update_kvcache_attention(self, q, kv, inference_params):
906
        """Write kv to inference_params, then do attention"""
907
908
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
909
910
911
912
913
914
915
916
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
917
                else inference_params.seqlen_offset
918
            )
919
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
920
921
922
923
924
925
926
927
928
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
929
                alibi_slopes=alibi_slopes,
930
931
932
            )
            return context

933
    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            seqlen: if given, the projected tensor is unflattened to (batch, seqlen, ...)
                before attention and the context is flattened back before out_proj.
            inference_params: generation state. If None, attention runs without a KV cache.
            **kwargs: passed to the inner attention module when inference_params is None.
        Returns:
            The result of out_proj applied to the attention context.
        """
        generating = inference_params is not None

        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)

        if generating:
            # Per-sample lengths take precedence over the single shared offset.
            per_sample = inference_params.lengths_per_sample
            seqlen_offset = (
                per_sample if per_sample is not None else inference_params.seqlen_offset
            )
            rotary_max_seqlen = inference_params.max_seqlen
        else:
            seqlen_offset = 0
            rotary_max_seqlen = None

        # The fused rotary + cache-update + attention call is only taken while decoding past
        # the first step, with FlashAttention, and with a nonzero rotary dim divisible by 16.
        fused_decode = (
            generating
            and inference_params.seqlen_offset != 0
            and self.rotary_emb_dim != 0
            and self.rotary_emb_dim % 16 == 0
            and self.use_flash_attn
        )

        if self.num_heads_kv == self.num_heads:
            # Same number of Q and KV heads: keep Q, K, V packed in one tensor.
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if fused_decode:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
            else:
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if generating:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
                elif self.checkpointing:
                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self.inner_attn(qkv, **kwargs)
        else:
            # Fewer KV heads than Q heads: the local Q columns come first, then K and V.
            q_width = self.num_heads_per_rank * self.head_dim
            q = rearrange(qkv[..., :q_width], "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(
                qkv[..., q_width:],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if fused_decode:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
            else:
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if generating:
                    context = self._update_kvcache_attention(q, kv, inference_params)
                elif self.checkpointing:
                    context = torch.utils.checkpoint.checkpoint(
                        self.inner_cross_attn, q, kv, **kwargs
                    )
                else:
                    context = self.inner_cross_attn(q, kv, **kwargs)

        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        return self.out_proj(context)