# Copyright (c) 2022, Tri Dao.

import math
from functools import partial
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from einops import rearrange

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
except ImportError:
    flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func = None, None

try:
    from flash_attn.ops.flash_attn_triton import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
except ImportError:
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

try:
    from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


class FlashSelfAttention(nn.Module):
    """Scaled dot-product self-attention with softmax, computed by FlashAttention kernels.

    Arguments
    ---------
        causal: whether to apply a causal mask (forwarded to the kernel).
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        triton: if True, padded (B, S, ...) inputs use the Triton kernel whenever no
                dropout is active; unpadded inputs always use the CUDA kernel.
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 triton=False):
        super().__init__()
        # Exactly one kernel is required at construction time: the Triton one when it is
        # requested and there is no dropout, the CUDA one otherwise.
        if triton and attention_dropout == 0.0:
            assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
        else:
            assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.triton = triton

    def forward(self, qkv, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        # Dropout is only active in training mode.
        effective_dropout = self.dropout_p if self.training else 0.0

        if cu_seqlens is not None:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_unpadded_qkvpacked_func(
                qkv, cu_seqlens, max_seqlen, effective_dropout,
                softmax_scale=self.softmax_scale, causal=self.causal
            )

        batch_size, seqlen = qkv.shape[:2]
        # Triton version doesn't support dropout, so it is only used when none is applied.
        if self.triton and effective_dropout == 0:
            return flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)

        # Every sequence has length `seqlen`: flatten the batch and describe it with
        # evenly spaced cumulative offsets for the unpadded kernel.
        offsets = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                               device=qkv.device)
        flat_out = flash_attn_unpadded_qkvpacked_func(
            rearrange(qkv, 'b s ... -> (b s) ...'), offsets, seqlen, effective_dropout,
            softmax_scale=self.softmax_scale, causal=self.causal
        )
        return rearrange(flat_out, '(b s) ... -> b s ...', b=batch_size)


class FlashCrossAttention(nn.Module):
    """Scaled dot-product cross-attention with softmax, computed by FlashAttention kernels.

    Arguments
    ---------
        causal: whether to apply a causal mask (forwarded to the kernel).
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        triton: if True, padded (B, S, ...) inputs use the Triton kernel whenever no
                dropout is active; unpadded inputs always use the CUDA kernel.
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 triton=False):
        super().__init__()
        if attention_dropout != 0.0 or not triton:
            assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed'
        if attention_dropout == 0.0 and triton:
            assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.triton = triton

    def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        Returns:
        --------
            out: the kernel output; (B, Sq, H, D) for padded inputs. For unpadded inputs it
                is returned as produced by flash_attn_unpadded_kvpacked_func.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Validate the key-side max length (previously max_seqlen was checked twice).
            assert isinstance(max_seqlen_k, int)
            return flash_attn_unpadded_kvpacked_func(
                q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
                self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=self.causal
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
            if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
                output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
            else:
                # Fixed-length batches: flatten and build evenly spaced cumulative offsets
                # for both the query and key/value sides.
                q = rearrange(q, 'b s ... -> (b s) ...')
                kv = rearrange(kv, 'b s ... -> (b s) ...')
                cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                                            dtype=torch.int32, device=q.device)
                cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                                            dtype=torch.int32, device=kv.device)
                output = flash_attn_unpadded_kvpacked_func(
                    q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                    self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=self.causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            return output


class SelfAttention(nn.Module):
    """Reference (non-fused) scaled dot-product self-attention with softmax.

    Arguments
    ---------
        causal: if True, each query position only attends to itself and earlier keys.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal, self.softmax_scale, self.dropout_p = causal, softmax_scale, attention_dropout

    def forward(self, qkv, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        Returns:
        --------
            out: (B, S, H, D)
        """
        bsz, length = qkv.shape[0], qkv.shape[1]
        query, key, value = qkv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])
        logits = torch.einsum('bthd,bshd->bhts', query, key * scale)

        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, -10000 for padded ones, broadcast over heads
            # and query positions. Adding is faster than masked_fill_ on the scores.
            key_bias = torch.full((bsz, length), -10000.0, dtype=logits.dtype,
                                  device=logits.device)
            key_bias.masked_fill_(key_padding_mask, 0.0)
            logits = logits + rearrange(key_bias, 'b s -> b 1 1 s')

        if self.causal:
            # "triu_tril_cuda_template" is not implemented for BFloat16, so the mask is
            # built in the default float dtype and cast afterwards.
            future = torch.full((length, length), -10000.0, device=logits.device)
            logits = logits + torch.triu(future, 1).to(dtype=logits.dtype)

        weights = torch.softmax(logits, dim=-1, dtype=value.dtype)
        drop_rate = self.dropout_p if self.training else 0.0
        weights = F.dropout(weights, drop_rate)
        return torch.einsum('bhts,bshd->bthd', weights, value)


class CrossAttention(nn.Module):
    """Reference (non-fused) scaled dot-product cross-attention with softmax.

    Arguments
    ---------
        causal: if True, query position t only attends to key positions <= t
                (mask built with torch.triu over a (Sq, Sk) grid).
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal, self.softmax_scale, self.dropout_p = causal, softmax_scale, attention_dropout

    def forward(self, q, kv, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        Returns:
        --------
            out: (B, Sq, H, D)
        """
        bsz, len_q = q.shape[0], q.shape[1]
        len_k = kv.shape[1]
        assert kv.shape[0] == bsz and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
        key, value = kv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        logits = torch.einsum('bthd,bshd->bhts', q, key * scale)

        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, -10000 for padded ones, broadcast over heads
            # and query positions. Adding is faster than masked_fill_ on the scores.
            key_bias = torch.full((bsz, len_k), -10000.0, dtype=logits.dtype,
                                  device=logits.device)
            key_bias.masked_fill_(key_padding_mask, 0.0)
            logits = logits + rearrange(key_bias, 'b s -> b 1 1 s')

        if self.causal:
            # "triu_tril_cuda_template" is not implemented for BFloat16, so the mask is
            # built in the default float dtype and cast afterwards.
            future = torch.full((len_q, len_k), -10000.0, device=logits.device)
            logits = logits + torch.triu(future, 1).to(dtype=logits.dtype)

        weights = torch.softmax(logits, dim=-1, dtype=value.dtype)
        drop_rate = self.dropout_p if self.training else 0.0
        weights = F.dropout(weights, drop_rate)
        return torch.einsum('bhts,bshd->bthd', weights, value)


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.

    ``forward(input)`` returns the pair ``(linear(input), input)``, so callers can treat
    this class and ``FusedDense(return_residual=True)`` interchangeably.
    """

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # The second element is the unmodified input, passed through for the residual path.
        return super().forward(input), input


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention
    """

    def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
                 softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
                 rotary_emb_scale_base=0,
                 fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                 checkpointing=False, device=None, dtype=None) -> None:
        """
            embed_dim: model dimension; must be divisible by num_heads.
            num_heads: number of attention heads (head_dim = embed_dim // num_heads).
            cross_attn: if True, queries are projected from x (Wq) and keys/values from
                x_kv (Wkv); otherwise a single packed Wqkv projection is used.
            bias: whether the input projections have a bias. The output projection
                always has one.
            dropout: attention dropout, forwarded to the inner attention module.
            softmax_scale: softmax temperature, forwarded to the inner attention module.
            causal: forwarded to the inner attention module.
            dwconv: if True, apply a depthwise Conv1d (kernel 3, padding 2, last 2 outputs
                dropped, so each position mixes itself and the two preceding positions)
                to the projected q/k/v before splitting heads.
            rotary_emb_dim: if > 0, apply RotaryEmbedding of this dimension to qkv.
                Not supported together with cross_attn.
            rotary_emb_scale_base: forwarded to RotaryEmbedding as scale_base.
            fused_bias_fc: use FusedDense instead of nn.Linear for the projections.
            use_flash_attn: use FlashSelfAttention / FlashCrossAttention instead of
                SelfAttention / CrossAttention.
            return_residual: whether to return the input x along with the output. This is for
                performance reason: for post-norm architecture, returning the input allows us
                to fuse the backward of nn.Linear with the residual connection.
            checkpointing: if True, run the inner attention under torch.utils.checkpoint.
            device, dtype: factory kwargs for the parameters created here.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
                                              device=device)

        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        # Residual-returning projection: forward returns (output, input).
        linear_resid_cls = (LinearResidual if not fused_bias_fc
                            else partial(FusedDense, return_residual=True))
        if not self.cross_attn:
            if not self.return_residual:
                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
            else:
                self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
            if self.dwconv:
                self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                            groups=3 * embed_dim, **factory_kwargs)
            inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if not self.return_residual:
                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
            else:
                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
            if self.dwconv:
                self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
                                          groups=embed_dim, **factory_kwargs)
                self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
                                           groups=2 * embed_dim, **factory_kwargs)
            inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)
        # output projection always have the bias (for now)
        self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

    def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            **kwargs: extra keyword arguments forwarded to the inner attention module.
        Returns:
            out, or the tuple (out, x) when return_residual is True.
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn

        kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                  if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
        # torch.utils.checkpoint (reentrant mode) rejects extra keyword arguments, and
        # `kwargs` always carries at least one key here, so bind them into the callable.
        attn_fn = partial(self.inner_attn, **kwargs)
        if not self.cross_attn:
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                # Conv over the sequence axis; dropping the last 2 outputs keeps it causal.
                qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(qkv)
            if not self.checkpointing:
                context = attn_fn(qkv)
            else:
                context = torch.utils.checkpoint.checkpoint(attn_fn, qkv)
        else:
            if not self.return_residual:
                q = self.Wq(x)
                kv = self.Wkv(x_kv if x_kv is not None else x)
            else:
                if x_kv is not None:
                    kv, x_kv = self.Wkv(x_kv)
                else:
                    kv, x = self.Wkv(x)
                q = self.Wq(x)
            if self.dwconv:
                # The depthwise conv expects flat (b, s, d) projections, so it must run
                # before splitting heads (same order as the self-attention branch).
                q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                              'b d s -> b s d').contiguous()
                kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
            q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
            if not self.checkpointing:
                context = attn_fn(q, kv)
            else:
                context = torch.utils.checkpoint.checkpoint(attn_fn, q, kv)
        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
        return out if not self.return_residual else (out, x)


class ParallelMHA(nn.Module):
    """Multi-head self-attention with tensor-parallel projections.

    Wqkv is a ColumnParallelLinear and out_proj a RowParallelLinear, both over
    ``process_group``. Only self-attention is implemented (there is no cross-attention path).
    """

    def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
                 softmax_scale=None, causal=False, rotary_emb_dim=0, rotary_emb_scale_base=0,
                 use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
        """
            embed_dim: model dimension; must be divisible by num_heads.
            num_heads: number of attention heads (head_dim = embed_dim // num_heads).
            process_group: process group passed to the parallel linear layers.
            bias: whether Wqkv has a bias. The output projection always has one.
            dropout, softmax_scale, causal: forwarded to the inner attention module.
            rotary_emb_dim: if > 0, apply RotaryEmbedding of this dimension to qkv.
            rotary_emb_scale_base: forwarded to RotaryEmbedding as scale_base.
            use_flash_attn: use FlashSelfAttention instead of SelfAttention.
            checkpointing: if True, run the inner attention under torch.utils.checkpoint.
            device, dtype: factory kwargs for the parameters created here.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
        self.embed_dim = embed_dim
        self.causal = causal
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
                                              device=device)

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError('fused_dense is not installed')
        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
                                         **factory_kwargs)
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)
        # output projection always have the bias (for now)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs)

    def forward(self, x, seqlen=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            **kwargs: extra keyword arguments forwarded to the inner attention module.
        Returns:
            out with the same leading layout as x.
        """
        qkv = self.Wqkv(x)
        # The number of heads is inferred from the width of the Wqkv output.
        if seqlen is None:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
        else:
            qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3,
                            d=self.head_dim)
        if self.rotary_emb_dim > 0:
            qkv = self.rotary_emb(qkv)
        # Bind kwargs into the callable: reentrant torch.utils.checkpoint rejects extra
        # keyword arguments.
        attn_fn = partial(self.inner_attn, **kwargs)
        if not self.checkpointing:
            context = attn_fn(qkv)
        else:
            context = torch.utils.checkpoint.checkpoint(attn_fn, qkv)
        if seqlen is None:
            context = rearrange(context, 'b s h d -> b s (h d)')
        else:
            context = rearrange(context, 'b s h d -> (b s) (h d)')
        out = self.out_proj(context)
        return out