mha.py 37 KB
Newer Older
1
2
3
4
5
6
7
8
9
# Copyright (c) 2022, Tri Dao.

import math
from functools import partial
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

Tri Dao's avatar
Tri Dao committed
10
from einops import rearrange, repeat
11
12

try:
Tri Dao's avatar
Tri Dao committed
13
14
    from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func
    from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
15
except ImportError:
Tri Dao's avatar
Tri Dao committed
16
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
17
18
19
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

try:
Tri Dao's avatar
Tri Dao committed
20
    from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear
21
except ImportError:
Tri Dao's avatar
Tri Dao committed
22
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
23
24
25
26
27
28

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None

29
30
31
32
33
try:
    import ft_attention
except ImportError:
    ft_attention = None

34
35
36
37
38
39
40
41
42
43
44

class FlashSelfAttention(nn.Module):
    """Scaled dot-product self-attention computed by the FlashAttention kernels.

    Arguments
    ---------
        causal: default causal setting, used when ``forward`` is called without ``causal``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        # Both the padded and the variable-length entry points must have been importable.
        assert flash_attn_varlen_qkvpacked_func is not None, 'FlashAttention is not installed'
        assert flash_attn_qkvpacked_func is not None, 'FlashAttention is not installed'
        self.causal = causal
        self.softmax_scale = softmax_scale
        # Only the probability ``self.drop.p`` is used; the kernel applies dropout itself.
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Run packed-QKV attention through FlashAttention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        use_causal = causal if causal is not None else self.causal
        # Dropout is active only in training mode.
        dropout_p = self.drop.p if self.training else 0.0
        if cu_seqlens is None:
            # Padded (B, S, 3, H, D) layout.
            return flash_attn_qkvpacked_func(qkv, dropout_p,
                                             softmax_scale=self.softmax_scale, causal=use_causal)
        # Variable-length (total, 3, H, D) layout indexed by cu_seqlens.
        assert cu_seqlens.dtype == torch.int32
        assert max_seqlen is not None
        assert isinstance(max_seqlen, int)
        return flash_attn_varlen_qkvpacked_func(
            qkv, cu_seqlens, max_seqlen, dropout_p,
            softmax_scale=self.softmax_scale, causal=use_causal
        )
85
86
87
88
89
90
91
92
93
94
95
96


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        causal: default causal setting, used when ``forward`` is called without ``causal``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, 'FlashAttention is not installed'
        assert flash_attn_kvpacked_func is not None, 'FlashAttention is not installed'
        self.causal = causal
        self.softmax_scale = softmax_scale
        # Only the probability ``self.drop.p`` is used; the kernel applies dropout itself.
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
                cu_seqlens_k=None, max_seqlen_k=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
                When cu_seqlens is given, q and kv are instead passed in the packed
                variable-length layout indexed by cu_seqlens / cu_seqlens_k.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        # Dropout is active only in training mode.
        dropout_p = self.drop.p if self.training else 0.0
        if cu_seqlens is not None:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Validate the key-side max length (previously max_seqlen was re-checked by mistake).
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
                dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )
        # Padded layout: batch sizes and head dims of q and kv must agree.
        assert kv.shape[0] == q.shape[0] and kv.shape[4] == q.shape[3]
        return flash_attn_kvpacked_func(q, kv, dropout_p,
                                        causal=causal, softmax_scale=self.softmax_scale)
143
144
145
146
147
148
149
150
151
152
153
154


class SelfAttention(nn.Module):
    """Reference scaled dot-product self-attention written with plain PyTorch ops.

    Arguments
    ---------
        causal: default causal setting, used when ``forward`` is called without ``causal``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Compute softmax attention over a packed QKV tensor.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        Returns
        -------
            (B, S, H, D) attention output.
        """
        bsz, seq_len = qkv.shape[0], qkv.shape[1]
        is_causal = causal if causal is not None else self.causal
        query, key, value = qkv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])
        logits = torch.einsum('bthd,bshd->bhts', query, key * scale)
        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, -10000 for masked ones, broadcast over heads/queries.
            bias = torch.full((bsz, seq_len), -10000.0, dtype=logits.dtype,
                              device=logits.device)
            bias.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            logits = logits + bias[:, None, None, :]
        if is_causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16',
            # so the mask is built in float and cast afterwards.
            upper = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=logits.device), 1)
            logits = logits + upper.to(dtype=logits.dtype)
        probs = torch.softmax(logits, dim=-1, dtype=value.dtype)
        return torch.einsum('bhts,bshd->bthd', self.drop(probs), value)


class CrossAttention(nn.Module):
    """Reference scaled dot-product cross-attention written with plain PyTorch ops.

    Arguments
    ---------
        causal: default causal setting, used when ``forward`` is called without ``causal``.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Compute softmax attention of q over a packed KV tensor.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        Returns
        -------
            (B, Sq, H, D) attention output.
        """
        bsz, len_q = q.shape[0], q.shape[1]
        is_causal = causal if causal is not None else self.causal
        len_k = kv.shape[1]
        assert kv.shape[0] == bsz and kv.shape[4] == q.shape[3]
        heads_q, heads_kv = q.shape[2], kv.shape[3]
        if heads_kv != heads_q:  # MQA/GQA: each kv head is shared by a group of query heads
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=heads_q // heads_kv)
        key, value = kv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        logits = torch.einsum('bthd,bshd->bhts', q, key * scale)
        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, -10000 for masked ones, broadcast over heads/queries.
            bias = torch.full((bsz, len_k), -10000.0, dtype=logits.dtype,
                              device=logits.device)
            bias.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            logits = logits + bias[:, None, None, :]
        if is_causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16', so build in float.
            # The mask is anchored at the top-left corner: query i sees keys 0..i.
            upper = torch.triu(torch.full((len_q, len_k), -10000.0, device=logits.device), 1)
            logits = logits + upper.to(dtype=logits.dtype)
        probs = torch.softmax(logits, dim=-1, dtype=value.dtype)
        return torch.einsum('bhts,bshd->bthd', self.drop(probs), value)


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.

    ``forward`` returns ``(linear(input), input)``, the same pair that
    ``FusedDense(..., return_residual=True)`` produces, so callers can unpack either one.
    """

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # The input is returned unchanged as the residual branch.
        return super().forward(input), input


255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
def _update_kv_cache(kv, inference_params, layer_idx):
    """Store ``kv`` in the per-layer inference cache and return the keys/values to attend to.

    kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

    On the first call for ``layer_idx`` a dense (max_batch, max_seq, 2, nheads, head_dim)
    buffer is allocated and registered in ``inference_params.key_value_memory_dict``.
    Without the fused FT kernel, returns the cache slice covering every position written so
    far for this batch window. With the FT kernel, the cache is (re)stored as a
    ``(k_cache, v_cache)`` pair in FT layout and the input ``kv`` is returned unchanged.
    """
    memory = inference_params.key_value_memory_dict
    fused = inference_params.fused_ft_kernel
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx in memory:
        if fused:
            # FT layout: k_cache is (b, h, headdim / packsize, s, packsize) with
            # packsize = 4 for fp32 and 8 for fp16/bf16; v_cache is (b, h, s, headdim).
            k_cache, v_cache = memory[layer_idx]
            kv_cache = None
        else:
            kv_cache = memory[layer_idx]
    else:
        # First use of this layer: pre-allocate a dense key/value buffer.
        kv_cache = torch.empty(
            inference_params.max_batch_size, inference_params.max_sequence_len, 2,
            num_heads, head_dim, dtype=kv.dtype, device=kv.device
        )
        memory[layer_idx] = kv_cache

    b0 = inference_params.batch_size_offset
    b1 = b0 + kv.shape[0]
    s0 = inference_params.sequence_len_offset
    s1 = s0 + kv.shape[1]
    if kv_cache is not None:
        batch_cap, seq_cap = kv_cache.shape[0], kv_cache.shape[1]
    else:
        batch_cap, seq_cap = v_cache.shape[0], v_cache.shape[2]
    assert b1 <= batch_cap
    assert s1 <= seq_cap

    if not fused:
        assert kv_cache is not None
        kv_cache[b0:b1, s0:s1, ...] = kv
        return kv_cache[b0:b1, :s1, ...]

    # Fused FT path: only the prompt (offset 0) is written here.
    assert inference_params.sequence_len_offset == 0
    assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if kv.dtype == torch.float32 else 8
    if kv_cache is None:
        # Cache already in FT layout: scatter k and v into their separate buffers.
        k_cache[b0:b1, :, :, :s1, :] = rearrange(
            kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
        )
        v_cache[b0:b1, :, :s1, :] = rearrange(
            kv[:, :, 1], 'b s h d -> b h s d'
        )
    else:
        # Freshly allocated dense cache: fill it, then convert it to FT layout in place of it.
        kv_cache[b0:b1, s0:s1, ...] = kv
        k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                            packsize=packsize).contiguous()
        v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
        memory[layer_idx] = (k_cache, v_cache)
    return kv


Tri Dao's avatar
Tri Dao committed
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotary_emb_dim,
                                         rotary_emb_base, kv=None, rotary_emb_interleaved=False):
    """Run one single-query decoding step through ``ft_attention.single_query_attention``.

    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
            q of shape (batch_size, 1, nheads, head_dim)
    kv: (batch_size, 1, 2, nheads_kv, head_dim)

    The cached (k_cache, v_cache) pair for ``layer_idx`` is sliced to the current batch
    window and passed to the kernel together with the rotary settings. Returns the context
    reshaped to (batch_size, 1, nheads, head_dim).
    """
    assert inference_params.fused_ft_kernel
    assert ft_attention is not None
    if kv is not None:
        q = rearrange(qkv, 'b 1 h d -> b h d')
        k, v = rearrange(kv, 'b 1 two h d -> b two h d').unbind(dim=1)
    else:
        q, k, v = rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1)
    start = inference_params.batch_size_offset
    stop = start + q.shape[0]
    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
    per_sample_len = None
    if inference_params.lengths_per_sample is not None:
        per_sample_len = inference_params.lengths_per_sample[start:stop]
    # The kernel's flag is "neox style", which is the non-interleaved rotary layout.
    neox_rotary_style = not rotary_emb_interleaved
    context = ft_attention.single_query_attention(
        q, k, v,
        k_cache[start:stop],
        v_cache[start:stop],
        per_sample_len,
        None,  # rotary_cos_
        None,  # rotary_sin_
        None,  # nnz_head_idx
        inference_params.sequence_len_offset,
        rotary_emb_dim, rotary_emb_base,
        neox_rotary_style,
    )
    return rearrange(context, 'b h d -> b 1 h d')


343
344
345
346
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention
    """

    def __init__(self, embed_dim, num_heads, num_heads_kv=None, cross_attn=False,
                 qkv_proj_bias=True, out_proj_bias=True,
                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
                 rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
        """
            num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
            dwconv: if True, apply a depthwise Conv1d (kernel size 3, left-aligned so each
                position only sees itself and the two previous positions) to the projected
                q/k/v along the sequence axis before attention. Not supported together with
                cu_seqlens or generation (inference_params).
            return_residual: whether to return the input x along with the output. This is for
                performance reason: for post-norm architecture, returning the input allows us
                to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
                                              scale_base=rotary_emb_scale_base,
                                              interleaved=rotary_emb_interleaved, device=device)

        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (LinearResidual if not fused_bias_fc
                            else partial(FusedDense, return_residual=True))
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # Build exactly the conv modules that forward() uses: the packed-qkv path
            # (self-attention with num_heads_kv == num_heads) convolves the fused projection;
            # every other path (cross-attention, MQA/GQA) convolves q and kv separately.
            if not self.cross_attn and self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(qkv_dim, qkv_dim, kernel_size=3, padding=2,
                                            groups=qkv_dim, **factory_kwargs)
            else:
                self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
                                          groups=embed_dim, **factory_kwargs)
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2,
                                           groups=kv_dim, **factory_kwargs)
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                     attention_dropout=dropout)
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        """Allocate an uninitialized key/value cache for generation.

        Returns a dense (batch_size, max_seqlen, 2, num_heads_kv, head_dim) tensor when
        ``fused_ft_kernel`` is False, otherwise a ``(k_cache, v_cache)`` pair in the FT layout
        (k: (b, h_kv, head_dim // packsize, s, packsize), v: (b, h_kv, s, head_dim)).
        """
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim,
                               dtype=dtype, device=device)
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(batch_size, self.num_heads_kv, self.head_dim // packsize,
                                  max_seqlen, packsize, dtype=dtype, device=device)
            v_cache = torch.empty(batch_size, self.num_heads_kv, max_seqlen, self.head_dim,
                                  dtype=dtype, device=device)
            return k_cache, v_cache

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
        """
        assert not self.dwconv, 'Generation does not support dwconv yet'
        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
              q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)
        """
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )

    def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                mixer_subset=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                  if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                # Convolve along the sequence axis; dropping the last 2 outputs keeps it causal.
                qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
            if (inference_params is None or inference_params.sequence_len_offset == 0
                    or not inference_params.fused_ft_kernel):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv,
                                                                    **kwargs)
                else:
                    q = qkv[:, :, 0]
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., :self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim:]
            if self.dwconv:
                # The convs act on the flat (b, s, d) projections, so they must run before the
                # head split below (after it, q/kv are 4-D/5-D and 'b s d' would not match).
                q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                              'b d s -> b s d').contiguous()
                kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
            q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
            kv = rearrange(kv, '... (two hkv d) -> ... two hkv d', two=2, d=self.head_dim)
            if (inference_params is None or inference_params.sequence_len_offset == 0
                    or not inference_params.fused_ft_kernel):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
                                                                    **kwargs)
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
        return out if not self.return_residual else (out, x)
Tri Dao's avatar
Tri Dao committed
562
563
564
565
566
567


class ParallelMHA(nn.Module):
    """Multi-head self-attention with its heads split across a tensor-parallel process group.

    Each rank owns ``num_heads // world_size`` query heads and ``num_heads_kv // world_size``
    key/value heads. The fused QKV projection is a ``ColumnParallelLinear`` and the output
    projection a ``RowParallelLinear``, both built on ``process_group``. When
    ``num_heads_kv < num_heads`` the layer performs multi-query / grouped-query attention.
    """

    def __init__(self, embed_dim, num_heads, process_group, num_heads_kv=None,
                 qkv_proj_bias=True, out_proj_bias=True,
                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
                 sequence_parallel=True, device=None, dtype=None) -> None:
        """
        Arguments:
            embed_dim: model hidden size; must be divisible by num_heads.
            num_heads: total number of query heads (summed over all ranks).
            process_group: process group for the parallel linear layers; when None the
                world size is taken to be 1.
            num_heads_kv: total number of key/value heads (defaults to num_heads). num_heads
                must be divisible by it, and it must be divisible by the world size.
            qkv_proj_bias, out_proj_bias: whether the QKV / output projections use a bias.
            dropout: attention dropout rate given to the inner attention modules.
            softmax_scale: softmax temperature given to the inner attention modules.
            causal: causal flag given to the inner attention modules.
            layer_idx: index of this layer; required for generation (it is handed to the
                KV-cache helpers).
            rotary_emb_dim, rotary_emb_base, rotary_emb_scale_base, rotary_emb_interleaved:
                rotary embedding settings; no rotary embedding is built when rotary_emb_dim == 0.
            use_flash_attn: use FlashSelfAttention / FlashCrossAttention instead of
                SelfAttention / CrossAttention.
            checkpointing: run the inner attention under torch.utils.checkpoint when not
                doing inference.
            sequence_parallel: forwarded to the parallel linear layers.
            device, dtype: factory kwargs for the projection layers.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size() if process_group is not None else 1

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        self.num_heads_per_rank = num_heads // self.world_size
        self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert self.num_heads_kv % self.world_size == 0, "num_heads_kv must be divisible by world_size"
        self.head_dim = self.embed_dim // num_heads
        # Global width of the fused projection: query heads, then key heads and value heads.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
                                              scale_base=rotary_emb_scale_base,
                                              interleaved=rotary_emb_interleaved, device=device)

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError('fused_dense is not installed')
        self.Wqkv = ColumnParallelLinear(embed_dim, qkv_dim, process_group,
                                         bias=qkv_proj_bias,
                                         sequence_parallel=sequence_parallel, **factory_kwargs)
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        # inner_attn consumes packed qkv (num_heads_kv == num_heads, no inference);
        # inner_cross_attn handles separate q / kv (GQA, or attending to the KV cache).
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                     attention_dropout=dropout)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
                                          bias=out_proj_bias,
                                          sequence_parallel=sequence_parallel, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        """Allocate (uninitialized) key/value cache storage for this rank's KV heads.

        Arguments:
            batch_size, max_seqlen: cache capacity.
            dtype: cache dtype; defaults to the dtype of out_proj.weight.
            fused_ft_kernel: choose the layout used by the fused decoding path.
        Returns:
            If fused_ft_kernel is False: one tensor of shape
                (batch_size, max_seqlen, 2, num_heads_kv_per_rank, head_dim).
            Otherwise: (k_cache, v_cache) with shapes
                (batch_size, num_heads_kv_per_rank, head_dim // packsize, max_seqlen, packsize)
                and (batch_size, num_heads_kv_per_rank, max_seqlen, head_dim), where packsize
                is 4 for float32 and 8 for float16/bfloat16.
        """
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv_per_rank,
                               self.head_dim, dtype=dtype, device=device)
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            # 4 x fp32 or 8 x fp16/bf16 elements: presumably one 16-byte pack per group,
            # as expected by the fused kernel's K layout (TODO confirm against ft_attention).
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(batch_size, self.num_heads_kv_per_rank,
                                  self.head_dim // packsize,
                                  max_seqlen, packsize, dtype=dtype, device=device)
            v_cache = torch.empty(batch_size, self.num_heads_kv_per_rank, max_seqlen,
                                  self.head_dim, dtype=dtype, device=device)
            return k_cache, v_cache

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
        """
        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
              q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)
        """
        # Same requirement as _update_kv_cache: layer_idx is handed to the cache helper.
        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            seqlen: sequence length used to unflatten the projected x; None if x is 3-D.
            inference_params: generation state (KV cache, sequence_len_offset,
                fused_ft_kernel); None for training / plain forward.
            **kwargs: extra arguments for the inner attention module; only used when
                inference_params is None.
        Returns:
            out_proj applied to the attention context. The context is (batch, seqlen, ·), or is
            flattened to (batch * seqlen, ·) first when seqlen is given.
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        # Only a decoding step (offset > 0) with fused_ft_kernel takes the fused
        # single-query path; everything else goes through the attention modules.
        use_attn_modules = (inference_params is None or inference_params.sequence_len_offset == 0
                            or not inference_params.fused_ft_kernel)
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
            if use_attn_modules:
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        # checkpoint() rejects unknown keyword arguments in reentrant mode,
                        # so bind them into the callable instead of forwarding them.
                        context = torch.utils.checkpoint.checkpoint(
                            partial(self.inner_attn, **kwargs), qkv)
                else:
                    q = qkv[:, :, 0]
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            # Per-rank layout of the projection: query heads first, then packed key/value heads.
            q = rearrange(qkv[..., :self.num_heads_per_rank * self.head_dim],
                          "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(qkv[..., self.num_heads_per_rank * self.head_dim:],
                           "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if use_attn_modules:
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        # See above: keyword arguments must not reach checkpoint() itself.
                        context = torch.utils.checkpoint.checkpoint(
                            partial(self.inner_cross_attn, **kwargs), q, kv)
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        context = rearrange(context, 'b s h d -> b s (h d)')
        if seqlen is not None:
            context = rearrange(context, 'b s d -> (b s) d')
        out = self.out_proj(context)
        return out