# Copyright (c) 2023, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

from flash_attn.utils.distributed import get_dim_for_local_rank

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
        )
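
# Example (illustrative): for a power-of-two head count the slopes form a geometric
# sequence, e.g. get_alibi_slopes(8) == [1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256],
# one positive slope per head. MHA below turns this list into a tensor and passes it to
# the flash-attn kernels as `alibi_slopes` when use_alibi=True.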


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
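

# Usage sketch (illustrative; assumes flash-attn is installed and a CUDA device is
# available). FlashSelfAttention consumes packed qkv of shape (B, S, 3, H, D) in
# fp16/bf16:
#
#     attn = FlashSelfAttention(causal=True, attention_dropout=0.1)
#     qkv = torch.randn(2, 128, 3, 8, 64, dtype=torch.float16, device="cuda")
#     out = attn(qkv)  # (2, 128, 8, 64)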


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
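

# Usage sketch (illustrative; assumes flash-attn is installed and a CUDA device is
# available). q is (B, Sq, H, D) and kv is (B, Sk, 2, H_k, D); H_k may be smaller than H
# for MQA/GQA as long as H is a multiple of H_k:
#
#     cross_attn = FlashCrossAttention(causal=False)
#     q = torch.randn(2, 16, 8, 64, dtype=torch.bfloat16, device="cuda")
#     kv = torch.randn(2, 128, 2, 2, 64, dtype=torch.bfloat16, device="cuda")
#     out = cross_attn(q, kv)  # (2, 16, 8, 64)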


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
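

# Usage sketch (illustrative): SelfAttention is the pure-PyTorch fallback, so it also
# works on CPU and in fp32. key_padding_mask is a boolean (B, S) mask with True = keep:
#
#     attn = SelfAttention(causal=True)
#     qkv = torch.randn(2, 128, 3, 8, 64)
#     mask = torch.ones(2, 128, dtype=torch.bool)
#     out = attn(qkv, key_padding_mask=mask)  # (2, 128, 8, 64)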


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
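

# Usage sketch (illustrative): the pure-PyTorch cross-attention fallback; kv heads are
# repeated internally when H_k < H (MQA/GQA):
#
#     cross_attn = CrossAttention(causal=False)
#     q = torch.randn(2, 16, 8, 64)
#     kv = torch.randn(2, 128, 2, 2, 64)
#     out = cross_attn(q, kv)  # (2, 16, 8, 64)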


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input
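

# Usage sketch (illustrative): LinearResidual returns (output, input) so the caller can
# fuse the residual add, mirroring FusedDense(return_residual=True):
#
#     proj = LinearResidual(512, 1536)
#     y, x = proj(torch.randn(2, 128, 512))  # y: (2, 128, 1536); x is the input itself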


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
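

# Illustration (a hedged sketch of the `inference_params` object this module expects; in
# flash-attn it is provided by the generation utilities). The attributes read in this
# file are max_batch_size, max_seqlen, seqlen_offset, batch_size_offset,
# key_value_memory_dict and lengths_per_sample, so a minimal stand-in for
# experimentation could look like:
#
#     from types import SimpleNamespace
#     params = SimpleNamespace(
#         max_batch_size=2, max_seqlen=256, seqlen_offset=0, batch_size_offset=0,
#         key_value_memory_dict={}, lengths_per_sample=None,
#     )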


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for a post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
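

# Usage sketch (illustrative; assumes flash-attn and its rotary kernels are installed
# and a CUDA device is available): a causal GQA block with rotary embeddings.
# embed_dim=1024 with 16 query heads and 4 kv heads gives head_dim=64:
#
#     mha = MHA(
#         embed_dim=1024, num_heads=16, num_heads_kv=4, causal=True,
#         rotary_emb_dim=64, use_flash_attn=True, device="cuda", dtype=torch.bfloat16,
#     )
#     x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda")
#     out = mha(x)  # (2, 128, 1024)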


class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            num_heads_local = math.ceil(self.num_heads / self.world_size)
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[
                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
                ],
                device=device,
            )
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out
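

# Usage sketch (illustrative; requires the fused_dense extension and an initialized
# torch.distributed process group, typically one rank per GPU). Wqkv is sharded over the
# group by ColumnParallelLinear and the output projection is reduced by
# RowParallelLinear:
#
#     torch.distributed.init_process_group(backend="nccl")
#     group = torch.distributed.group.WORLD
#     mha = ParallelMHA(
#         embed_dim=1024, num_heads=16, process_group=group, causal=True,
#         use_flash_attn=True, sequence_parallel=False, device="cuda", dtype=torch.bfloat16,
#     )
#     out = mha(torch.randn(2, 128, 1024, dtype=torch.bfloat16, device="cuda"))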