# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.parallel_state import (
    get_data_parallel_group,
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.utils import divide

from .enums import AttnMaskType
from .transformer_config import TransformerConfig

try:
    import transformer_engine  # pylint: disable=unused-import

    HAVE_TE = True
    from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim
except ImportError:
    HAVE_TE = False
    SplitAlongDim = None
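
# When Transformer Engine is available, its SplitAlongDim helper is used for the
# fused QKV split in SelfAttention.get_query_key_value_tensors(); otherwise the
# code falls back to torch.split.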


@dataclass
class SelfAttentionSubmodules:
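    """Submodule specs used to build a self-attention layer."""
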
    linear_qkv: Union[ModuleSpec, type] = None
    core_attention: Union[ModuleSpec, type] = None
    linear_proj: Union[ModuleSpec, type] = None
    q_layernorm: Union[ModuleSpec, type] = None
    k_layernorm: Union[ModuleSpec, type] = None


@dataclass
class CrossAttentionSubmodules:
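    """Submodule specs used to build a cross-attention layer."""
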
    linear_q: Union[ModuleSpec, type] = None
    linear_kv: Union[ModuleSpec, type] = None
    core_attention: Union[ModuleSpec, type] = None
    linear_proj: Union[ModuleSpec, type] = None


class Attention(MegatronModule, ABC):
    """Attention layer abstract class.

    This layer contains only the modules that are common to the "self attn"
    and "cross attn" specializations.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
    ):
        super().__init__(config=config)

        self.config = config
        self.layer_number = layer_number
        self.attn_mask_type = attn_mask_type
        self.attention_type = attention_type

        # For normal attention without query groups, num_query_groups == num_attention_heads,
        # so these two projection sizes will be the same.
        self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
        self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = divide(
            self.query_projection_size, self.config.num_attention_heads
        )
        self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
        self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
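        # Hypothetical example: with num_attention_heads=32, num_query_groups=8,
        # kv_channels=128, and tensor-parallel world size 4, this gives
        # query_projection_size=4096, kv_projection_size=1024, and, per partition,
        # 8 attention heads and 2 query groups with head dim 128.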

        self.core_attention = build_module(
            submodules.core_attention,
            config=self.config,
            layer_number=self.layer_number,
            attn_mask_type=self.attn_mask_type,
            attention_type=self.attention_type,
        )

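        # With recompute_granularity == 'selective', only the core attention block
        # is recomputed during the backward pass rather than being saved.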
        self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'

        # Output.
        self.linear_proj = build_module(
            submodules.linear_proj,
            self.query_projection_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name='proj',
        )

    def _checkpointed_attention_forward(
        self,
        query,
        key,
        value,
        attention_mask,
        rotary_pos_emb=None,
        attn_mask_type=None,
        packed_seq_params=None,
    ):
        """Forward method with selective activation checkpointing."""

        def custom_forward(*inputs):
            query = inputs[0]
            key = inputs[1]
            value = inputs[2]
            attention_mask = inputs[3]
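            # inputs[4] is rotary_pos_emb; it is threaded through the checkpoint
            # call but not consumed inside this custom forward.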
            attn_mask_type = inputs[5]
            attn_mask_type = AttnMaskType(attn_mask_type.item())
            output_ = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )
            return output_

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
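        # The enum value is wrapped in a tensor, presumably so it can round-trip
        # through the activation checkpointing machinery with the other inputs.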
        attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)
        hidden_states = tensor_parallel.checkpoint(
            custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type
        )

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
        """Allocate memory to store kv cache during inference."""

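        # Buffer shape: [max_seq_len, batch, num_query_groups_per_partition, head_dim],
        # matching the [s, b, ng, hn] layout of the key/value tensors.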
        return torch.empty(
            inference_max_sequence_length,
            batch_size,
            self.num_query_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb):
        """
        Saves the generated key and value tensors to the end of the buffers in inference_params.
        Returns the full size keys and values from the provided inference_params, as well as
        adjusted rotary_pos_emb.

        Returns a tuple: (key, value, rotary_pos_emb)

        """
        attn_mask_type = self.attn_mask_type
        if inference_params is None:
            return key, value, rotary_pos_emb, attn_mask_type

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if self.layer_number not in inference_params.key_value_memory_dict:
            inf_max_seq_length = inference_params.max_sequence_length
            inf_max_batch_size = inference_params.max_batch_size
            inference_key_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, key.dtype
            )
            inference_value_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, value.dtype
            )
            inference_params.key_value_memory_dict[self.layer_number] = (
                inference_key_memory,
                inference_value_memory,
            )
        else:
            # Get the pre-allocated buffers for this layer
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number
            ]

        if inference_params.sequence_len_offset > 0:
            # A nonzero offset means we are past the prompt forward step,
            # so masking needs to be turned off.
            attn_mask_type = AttnMaskType.no_mask

        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + key.size(1)
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + key.size(0)
        assert sequence_end <= inference_key_memory.size(0)
        # Copy key and values.
        inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
        inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
        key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
        value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

        # Adjust the query and key rotary positional embeddings for the current step.
        if rotary_pos_emb is None:
            return key, value, rotary_pos_emb, attn_mask_type

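        # The query embedding covers only the tokens generated in this step, while
        # the key embedding spans everything cached so far.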
        q_pos_emb, k_pos_emb = rotary_pos_emb
        q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
        k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
        rotary_pos_emb = (q_pos_emb, k_pos_emb)

        return key, value, rotary_pos_emb, attn_mask_type

    @abstractmethod
    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        This method needs to be implemented based on whether the derived class
        is "self-attn" or "cross-attn".
        """

    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_params=None,
        rotary_pos_emb=None,
        packed_seq_params=None,
    ):
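        """Projects hidden_states to query/key/value, applies rotary embeddings and
        inference-time KV-cache adjustments, runs core attention, and returns the
        output projection as (output, bias).
        """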
        # hidden_states: [sq, b, h]

        # For self attention we just duplicate the rotary_pos_emb if it isn't already a tuple.
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb,) * 2

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

        # ===================================================
        # Adjust key, value, and rotary_pos_emb for inference
        # ===================================================
        key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
            inference_params, key, value, rotary_pos_emb
        )

        if packed_seq_params is not None:
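            # Packed input flattens all sequences along the first dim, so the batch
            # dimension is a dummy one and is dropped here.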
            query = query.squeeze(1)
            key = key.squeeze(1)
            value = value.squeeze(1)

        # ================================================
        # relative positional embedding (rotary embedding)
        # ================================================
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb

            if packed_seq_params is not None:
                cu_seqlens_q = packed_seq_params.cu_seqlens_q
                cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
            else:
                cu_seqlens_q = cu_seqlens_kv = None
            query = apply_rotary_pos_emb(
                query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q
            )
            key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)

            # TODO: a positional embedding could also be applied to value_layer to give
            # it absolute positional information; otherwise only relative positional
            # information takes effect.
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        if self.checkpoint_core_attention and self.training:
            core_attn_out = self._checkpointed_attention_forward(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )
        else:
            core_attn_out = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )

        if packed_seq_params is not None:
            # Reshape to the same output shape as the unpacked case:
            # (t, np, hn) -> (t, b=1, h=np*hn),
            # where t is the pack size = sum(sq_i).
            # Note that batch is a dummy dimension in the packed case.
            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.linear_proj(core_attn_out)

        return output, bias


class SelfAttention(Attention):
    """Self-attention layer class

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
    ):
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
            attention_type="self",
        )

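        # Fused QKV projection: np * hn query channels plus 2 * ng * hn key/value
        # channels, produced by a single linear layer.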
        self.linear_qkv = build_module(
            submodules.linear_qkv,
            self.config.hidden_size,
            self.query_projection_size + 2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear or self.config.add_qkv_bias,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name='qkv',
        )

        if submodules.q_layernorm is not None:
            self.q_layernorm = build_module(
                submodules.q_layernorm,
                hidden_size=self.hidden_size_per_attention_head,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.q_layernorm = None

        if submodules.k_layernorm is not None:
            self.k_layernorm = build_module(
                submodules.k_layernorm,
                hidden_size=self.hidden_size_per_attention_head,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.k_layernorm = None

    def run_realtime_tests(self):
        """Performs a consistency check.

        This function makes sure that tensors across devices are the same during an experiment.
        This is often not guaranteed to be so because of silent hardware failures (eg, memory
        corruption loading a checkpoint, network traffic corruption encountered during
        data transmission).

        (TODO) In the future, more tensors should be checked across the training run and
        checked every X iterations. This is left for future work. Equality of tensors is probably
        not required; transmitting hashes is sufficient."""

        if not self.config.qk_layernorm:
            return

        # check that all tensor parallel and data parallel ranks have the same
        # Q & K layernorm parameters.
        rank = get_data_parallel_rank()
        inputs = torch.stack(
            [
                self.q_layernorm.weight.data,
                self.q_layernorm.bias.data,
                self.k_layernorm.weight.data,
                self.k_layernorm.bias.data,
            ]
        )
        dp_list = [torch.empty_like(inputs) for _ in range(get_data_parallel_world_size())]
        dp_list[rank] = inputs
        torch.distributed.all_gather(dp_list, inputs, group=get_data_parallel_group())

        def _compare(srcs, tgts, names, parallelism):
            assert len(srcs) == len(tgts) == len(names)
            for src, tgt, name in zip(srcs, tgts, names):
                # `i` (the remote rank) and `rank` are read from the enclosing
                # scope at call time.
                assert torch.all(src == tgt), (
                    f"Discrepancy between {name} in {parallelism} ranks {i} and {rank}. "
                    f"Diff: {torch.norm(src - tgt)}"
                )

        for i, dp in enumerate(dp_list):
            q_w, q_b, k_w, k_b = torch.unbind(dp)
            _compare(
                [q_w, q_b, k_w, k_b],
                [
                    self.q_layernorm.weight.data,
                    self.q_layernorm.bias.data,
                    self.k_layernorm.weight.data,
                    self.k_layernorm.bias.data,
                ],
                ["q_w", "q_b", "k_w", "k_b"],
                "DP",
            )

        rank = get_tensor_model_parallel_rank()
        tp_list = [torch.empty_like(inputs) for _ in range(get_tensor_model_parallel_world_size())]
        tp_list[rank] = inputs
        torch.distributed.all_gather(tp_list, inputs, group=get_tensor_model_parallel_group())

        for i, tp in enumerate(tp_list):
            q_w, q_b, k_w, k_b = torch.unbind(tp)
            _compare(
                [q_w, q_b, k_w, k_b],
                [
                    self.q_layernorm.weight.data,
                    self.q_layernorm.bias.data,
                    self.k_layernorm.weight.data,
                    self.k_layernorm.bias.data,
                ],
                ["q_w", "q_b", "k_w", "k_b"],
                "TP",
            )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """
        Derives `query`, `key` and `value` tensors from `hidden_states`.
        """
        # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
        mixed_qkv, _ = self.linear_qkv(hidden_states)

        # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
        new_tensor_shape = mixed_qkv.size()[:-1] + (
            self.num_query_groups_per_partition,
            (
                (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
                * self.hidden_size_per_attention_head
            ),
        )
        mixed_qkv = mixed_qkv.view(*new_tensor_shape)

        split_arg_list = [
            (
                self.num_attention_heads_per_partition
                // self.num_query_groups_per_partition
                * self.hidden_size_per_attention_head
            ),
            self.hidden_size_per_attention_head,
            self.hidden_size_per_attention_head,
        ]
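        # Hypothetical example: with np/ng = 4 and hn = 128, split_arg_list is
        # [512, 128, 128]: four query heads' worth of channels, then one key and
        # one value head, per group.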

        if SplitAlongDim is not None:

            # [sq, b, ng, (np/ng + 2) * hn]
            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
        else:

            # [sq, b, ng, (np/ng + 2) * hn]
            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)

        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
        query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)

        if self.q_layernorm is not None:
            query = self.q_layernorm(query)

        if self.k_layernorm is not None:
            key = self.k_layernorm(key)

        if self.config.test_mode:
            self.run_realtime_tests()

        return query, key, value


class CrossAttention(Attention):
    """Cross-attention layer class

    Cross-attention layer takes input with size [s, b, h] and context with size
    [s, b, h] and returns output of the same size.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: CrossAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
    ):
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
            attention_type="cross",
        )

        if self.config.num_query_groups != self.config.num_attention_heads:
            raise ValueError("Group query attention is not currently supported in cross attention.")
        assert self.query_projection_size == self.kv_projection_size

        self.linear_q = build_module(
            submodules.linear_q,
            self.config.hidden_size,
            self.query_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
            is_expert=False,
        )

        self.linear_kv = build_module(
            submodules.linear_kv,
            self.config.hidden_size,
            2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
            is_expert=False,
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
        from `key_value_states`.
        """
        # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
        mixed_kv, _ = self.linear_kv(key_value_states)

        # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
        new_tensor_shape = mixed_kv.size()[:-1] + (
            self.num_attention_heads_per_partition,
            2 * self.hidden_size_per_attention_head,
        )
        mixed_kv = mixed_kv.view(*new_tensor_shape)

        # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
        (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2)

        # Attention head [sq, b, h] --> [sq, b, hp]
        query, _ = self.linear_q(hidden_states)

        # [sq, b, hp] --> [sq, b, np, hn]
        new_tensor_shape = query.size()[:-1] + (
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        query = query.view(*new_tensor_shape)

        return query, key, value