"research/neural_gpu/neural_gpu_trainer.py" did not exist on "574c981c140b43b4e66e7c43d6e1247b3acc842a"
attention.py 26.9 KB
Newer Older
liangjing's avatar
v1  
liangjing committed
1
2
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
xingjinliang's avatar
xingjinliang committed
3
4
from dataclasses import dataclass
from typing import Tuple, Union
liangjing's avatar
v1  
liangjing committed
5
6

import torch
xingjinliang's avatar
xingjinliang committed
7
from torch import Tensor
liangjing's avatar
v1  
liangjing committed
8

xingjinliang's avatar
xingjinliang committed
9
10
11
12
13
14
15
16
17
18
19
20
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.models.common.embeddings.rope_utils import (
    apply_rotary_pos_emb,
    apply_rotary_pos_emb_with_cos_sin,
)
from megatron.core.parallel_state import (
    get_data_parallel_group,
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
liangjing's avatar
v1  
liangjing committed
21
22
)
from megatron.core.transformer.module import MegatronModule
xingjinliang's avatar
xingjinliang committed
23
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
liangjing's avatar
v1  
liangjing committed
24
25
26
27
28
from megatron.core.utils import divide

from .enums import AttnMaskType
from .transformer_config import TransformerConfig

xingjinliang's avatar
xingjinliang committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Optional dependency: flash-attn provides the fused KV-cache decoding kernel used by
# Attention.flash_decoding(). When unavailable, the name is set to None and
# flash_decoding() asserts on it. `except Exception` keeps this best-effort (a broken
# flash-attn install should not prevent importing this module) without also trapping
# KeyboardInterrupt / SystemExit the way a bare `except:` would.
try:
    from flash_attn import flash_attn_with_kvcache
except Exception:  # pylint: disable=broad-except
    flash_attn_with_kvcache = None


# Optional dependency: Transformer Engine supplies SplitAlongDim, which
# SelfAttention.get_query_key_value_tensors prefers over torch.split when present.
try:
    import transformer_engine  # pylint: disable=unused-import

    HAVE_TE = True
    from megatron.core.extensions.transformer_engine import SplitAlongDim
except ImportError:
    HAVE_TE = False
    SplitAlongDim = None


@dataclass
class SelfAttentionSubmodules:
    """
    Bundle of module specs used to build a self-attention layer.

    Every field accepts either a ``ModuleSpec`` or a module class and defaults to
    ``None``. ``SelfAttention`` skips building the query/key layernorms when the
    corresponding field is left as ``None``.
    """

    # Fused projection that produces query, key and value in one matmul.
    linear_qkv: Union[ModuleSpec, type] = None
    # The attention kernel itself (softmax(QK^T)V or a fused equivalent).
    core_attention: Union[ModuleSpec, type] = None
    # Output projection applied to the attention result.
    linear_proj: Union[ModuleSpec, type] = None
    # Optional normalization applied to the query heads.
    q_layernorm: Union[ModuleSpec, type] = None
    # Optional normalization applied to the key heads.
    k_layernorm: Union[ModuleSpec, type] = None

@dataclass
class CrossAttentionSubmodules:
    """
    Bundle of module specs used to build a cross-attention layer.

    Every field accepts either a ``ModuleSpec`` or a module class and defaults
    to ``None``.
    """

    # Projection of the layer input into queries.
    linear_q: Union[ModuleSpec, type] = None
    # Fused projection of the context (key_value_states) into keys and values.
    linear_kv: Union[ModuleSpec, type] = None
    # The attention kernel itself.
    core_attention: Union[ModuleSpec, type] = None
    # Output projection applied to the attention result.
    linear_proj: Union[ModuleSpec, type] = None

liangjing's avatar
v1  
liangjing committed
69
70
71
72
73
74
75
76
77

class Attention(MegatronModule, ABC):
    """Attention layer abstract class.

    This layer only contains common modules required for the "self attn" and
    "cross attn" specializations. Subclasses must implement
    ``get_query_key_value_tensors``; everything else (KV-cache handling for
    inference, rotary embeddings, core attention, output projection) is shared.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        cp_comm_type: str = None,
    ):
        """
        Args:
            config: transformer configuration.
            submodules: specs for ``core_attention`` and ``linear_proj`` (plus the
                subclass-specific input projections).
            layer_number: index of this layer; also used as the key into
                ``inference_params.key_value_memory_dict`` for the KV cache.
            attn_mask_type: default mask type passed to the core attention.
            attention_type: ``"self"`` or ``"cross"``; forwarded to core attention.
            cp_comm_type: forwarded unchanged to the core attention builder.
        """
        super().__init__(config=config)

        self.config = config
        self.layer_number = layer_number
        self.attn_mask_type = attn_mask_type
        self.attention_type = attention_type

        # For normal attention without groups, num_query_groups == num_attention_heads,
        # so these two will be the same
        self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
        self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = divide(
            self.query_projection_size, self.config.num_attention_heads
        )
        self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
        self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

        self.core_attention = build_module(
            submodules.core_attention,
            config=self.config,
            layer_number=self.layer_number,
            attn_mask_type=self.attn_mask_type,
            attention_type=self.attention_type,
            cp_comm_type=cp_comm_type,
        )

        # Selective recompute: only the core attention is checkpointed (see forward()).
        self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'

        # Output.
        self.linear_proj = build_module(
            submodules.linear_proj,
            self.query_projection_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name='proj',
        )

    def _checkpointed_attention_forward(
        self,
        query,
        key,
        value,
        attention_mask,
        rotary_pos_emb=None,
        attn_mask_type=None,
        attention_bias=None,
        packed_seq_params=None,
    ):
        """Forward method with selective activation checkpointing.

        Runs ``self.core_attention`` under ``tensor_parallel.checkpoint`` so its
        activations are recomputed in the backward pass instead of stored.
        """

        def custom_forward(*inputs):
            query = inputs[0]
            key = inputs[1]
            value = inputs[2]
            attention_mask = inputs[3]
            # inputs[4] is rotary_pos_emb; it is passed through checkpoint but not
            # used here because RoPE has already been applied to query/key.
            # The mask type travels as an int tensor and is rebuilt into the enum.
            attn_mask_type = inputs[5]
            attn_mask_type = AttnMaskType(attn_mask_type.item())
            output_ = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                attention_bias=attention_bias,
                packed_seq_params=packed_seq_params,
            )
            return output_

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        # Encode the enum as a tensor so it can be handed to checkpoint() as a
        # positional input alongside the activation tensors.
        attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)
        hidden_states = tensor_parallel.checkpoint(
            custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type
        )

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype):
        """Allocate memory to store kv cache during inference.

        Returns an uninitialized tensor of shape
        ``[inference_max_sequence_length, batch_size, num_query_groups_per_partition, dim]``
        on the current CUDA device.
        """

        return torch.empty(
            inference_max_sequence_length,
            batch_size,
            self.num_query_groups_per_partition,
            dim,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def _adjust_key_value_for_inference(
        self,
        inference_params: InferenceParams,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        rotary_pos_emb: Tensor,
        rotary_pos_cos: Tensor = None,
        rotary_pos_sin: Tensor = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Union[Tuple[Tensor, Tensor], None], AttnMaskType]:
        """
        Saves the generated key and value tensors to the end of the buffers in inference_params.
        Returns the full size keys and values from the provided inference_params, as well as
        adjusted rotary_pos_emb.

        The KV buffers for this layer are allocated lazily on first use and stored in
        ``inference_params.key_value_memory_dict[self.layer_number]``.

        With ``config.flash_decode``, RoPE is applied to ``query`` and ``key`` here
        (before ``key`` is written to the cache), using the precomputed
        ``rotary_pos_cos`` / ``rotary_pos_sin`` rows for the positions being written.

        Returns a tuple: (query, key, value, rotary_pos_emb, attn_mask_type).
        When ``inference_params`` is None the inputs are returned unchanged together
        with ``self.attn_mask_type``. Once ``sequence_len_offset > 0`` the mask type
        is switched to ``AttnMaskType.no_mask``.
        """
        attn_mask_type = self.attn_mask_type
        if inference_params is None:
            return query, key, value, rotary_pos_emb, attn_mask_type

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if self.layer_number not in inference_params.key_value_memory_dict:
            inf_max_seq_length = inference_params.max_sequence_length
            inf_max_batch_size = inference_params.max_batch_size
            inference_key_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, key.shape[-1], key.dtype
            )
            inference_value_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, value.shape[-1], value.dtype
            )
            inference_params.key_value_memory_dict[self.layer_number] = (
                inference_key_memory,
                inference_value_memory,
            )
        else:
            # Get the pre-allocated buffers for this layer
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number
            ]

        if inference_params.sequence_len_offset > 0:
            # This should mean that we are past the prompt forward_step
            # and so we need to turn off masking
            attn_mask_type = AttnMaskType.no_mask

        # Region of the cache that the new key/value tokens occupy.
        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + key.size(1)
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + key.size(0)
        assert sequence_end <= inference_key_memory.size(0)

        if self.config.flash_decode:
            assert (
                rotary_pos_cos is not None and rotary_pos_sin is not None
            ), "Flash decoding requires precomputed cos and sin tensors"
            # Positions [sequence_start, sequence_end) are the tokens being written.
            # For prefill this is [0, sequence_end); for single-token decode it is
            # [sequence_end - 1, sequence_end). Slicing by the actual range also keeps
            # the positions correct when more than one token is decoded at once.
            # Queries and keys cover the same positions here, so one slice serves both.
            cos = rotary_pos_cos[sequence_start:sequence_end]
            sin = rotary_pos_sin[sequence_start:sequence_end]

            # Flash Decoding assumes that the keys stored in the KV Cache already have RoPE applied.
            # Apply RoPE before we store the keys to make it compatible with flash decoding kernel.
            key = apply_rotary_pos_emb_with_cos_sin(key, cos, sin)
            query = apply_rotary_pos_emb_with_cos_sin(query, cos, sin)

        # Copy key and values.
        inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
        inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
        # Attend over everything cached so far for this batch slice.
        key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
        value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

        # adjust the key rotary positional embedding
        if rotary_pos_emb is None:
            return query, key, value, rotary_pos_emb, attn_mask_type

        # Queries only need the embeddings for the new positions; keys need all
        # positions up to sequence_end because key now spans the whole cache prefix.
        q_pos_emb, k_pos_emb = rotary_pos_emb
        q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
        k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
        rotary_pos_emb = (q_pos_emb, k_pos_emb)

        return query, key, value, rotary_pos_emb, attn_mask_type

    @abstractmethod
    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        This method needs to be implemented based on whether the derived class
        is "self-attn" or "cross-attn".
        """

    def flash_decoding(
        self,
        sequence_len_offset: Tensor,
        query_layer: Tensor,
        key_layer: Tensor,
        value_layer: Tensor,
        inference_key_memory: Tensor,
        inference_value_memory: Tensor,
        rotary_cos: Tensor,
        rotary_sin: Tensor,
    ) -> Tensor:
        """
        The flash decoding kernel will do the following in a single execution:
        1. Compute RoPE embedding with precomputed cos & sin tensors
        2. Update the KV Cache
        3. Performs the flash attention operation

        Inputs and caches are sequence-first; they are permuted to batch-first
        for the kernel. The kernel output is returned as-is (batch-first), and the
        caller in ``forward`` transposes it back.

        Raises:
            AssertionError: if the flash-attn package is not available.
        """
        assert flash_attn_with_kvcache is not None, (
            "Flash Decoding requires the flash_attn_with_kvcache kernel, "
            "available in the flash-attn package."
        )
        # NOTE(review): the -1 relative to sequence_len_offset is inherited from
        # upstream; confirm it matches how the caller advances the offset.
        cache_seqlens = sequence_len_offset - 1
        q = query_layer.permute(1, 0, 2, 3)
        k = key_layer.permute(1, 0, 2, 3)
        v = value_layer.permute(1, 0, 2, 3)
        k_cache = inference_key_memory.permute(1, 0, 2, 3)
        v_cache = inference_value_memory.permute(1, 0, 2, 3)

        # The kernel expects the rotary tables in the same dtype as the activations.
        if rotary_cos is not None:
            rotary_cos = rotary_cos.to(query_layer.dtype)
        if rotary_sin is not None:
            rotary_sin = rotary_sin.to(query_layer.dtype)

        out = flash_attn_with_kvcache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            k=k,
            v=v,
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            rotary_interleaved=False,
        )
        return out

    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_params=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
    ):
        """
        Perform a forward pass through the attention module.

        Args:
            hidden_states: layer input, [sq, b, h].
            attention_mask: mask passed to the core attention.
            key_value_states: context tensor for cross attention (ignored by self attention).
            inference_params: KV-cache state for incremental decoding, or None.
            rotary_pos_emb: RoPE embedding, or a (query, key) pair of embeddings.
                Ignored when ``config.flash_decode`` is set.
            rotary_pos_cos, rotary_pos_sin: precomputed RoPE tables; only allowed
                (and required) with ``config.flash_decode``.
            attention_bias: forwarded to the core attention.
            packed_seq_params: packed-sequence metadata (cu_seqlens, qkv_format), or None.

        Returns:
            ``(output, bias)`` as returned by ``self.linear_proj`` (built with
            ``skip_bias_add=True``).
        """

        # hidden_states: [sq, b, h]
        if self.config.flash_decode:
            rotary_pos_emb = None
        else:
            assert rotary_pos_cos is None and rotary_pos_sin is None

        # For self attention we just duplicate the rotary_pos_emb if it isn't already
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb,) * 2

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

        # ===================================================
        # Adjust key, value, and rotary_pos_emb for inference
        # ===================================================

        # This branch only runs in the decode phase of flash decoding and returns after the linear
        # projection. This conditional is not used in the prefill phase or non-flash-decoding cases.
        if (
            self.config.flash_decode
            and inference_params is not None
            and self.layer_number
            in inference_params.key_value_memory_dict  # Decode phase if key already exists
        ):
            assert inference_params.sequence_len_offset is not None
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number
            ]
            output = self.flash_decoding(
                sequence_len_offset=inference_params.sequence_len_offset,
                query_layer=query,
                key_layer=key,
                value_layer=value,
                inference_key_memory=inference_key_memory,
                inference_value_memory=inference_value_memory,
                rotary_cos=rotary_pos_cos,
                rotary_sin=rotary_pos_sin,
            )
            # Back to sequence-first, then merge heads: [sq, b, np * hn].
            out = output.transpose(0, 1).contiguous()
            context_layer = out.view(out.size(0), out.size(1), -1)
            output, bias = self.linear_proj(context_layer)
            return output, bias

        query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
            inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin
        )

        if packed_seq_params is not None:
            # Packed sequences use a dummy batch dimension of size 1; drop it.
            query = query.squeeze(1)
            key = key.squeeze(1)
            value = value.squeeze(1)

        # ================================================
        # relative positional embedding (rotary embedding)
        # ================================================
        if rotary_pos_emb is not None and not self.config.flash_decode:
            q_pos_emb, k_pos_emb = rotary_pos_emb

            if packed_seq_params is not None:
                # Prefer the padded cumulative lengths when they are provided.
                if packed_seq_params.cu_seqlens_q_padded is not None:
                    cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded
                else:
                    cu_seqlens_q = packed_seq_params.cu_seqlens_q
                if packed_seq_params.cu_seqlens_kv_padded is not None:
                    cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded
                else:
                    cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
            else:
                cu_seqlens_q = cu_seqlens_kv = None
            query = apply_rotary_pos_emb(
                query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q
            )
            key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)

            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        if self.checkpoint_core_attention and self.training:
            core_attn_out = self._checkpointed_attention_forward(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                attention_bias=attention_bias,
                packed_seq_params=packed_seq_params,
            )
        else:
            core_attn_out = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                attention_bias=attention_bias,
                packed_seq_params=packed_seq_params,
            )

        if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
            # reshape to same output shape as unpacked case
            # (t, np, hn) -> (t, b=1, h=np*hn)
            # t is the pack size = sum (sq_i)
            # note that batch is a dummy dimension in the packed case
            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.linear_proj(core_attn_out)

        return output, bias


class SelfAttention(Attention):
    """Self-attention layer class

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size. Query, key and value are produced by a
    single fused projection (``linear_qkv``), with optional per-head
    layernorms on the queries and keys.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
        cp_comm_type: str = None,
    ):
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
            attention_type="self",
            cp_comm_type=cp_comm_type,
        )

        # Fused projection: one output for the queries plus two (key, value)
        # outputs of kv_projection_size each.
        self.linear_qkv = build_module(
            submodules.linear_qkv,
            self.config.hidden_size,
            self.query_projection_size + 2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear or self.config.add_qkv_bias,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name='qkv',
        )

        if submodules.q_layernorm is not None:
            self.q_layernorm = build_module(
                submodules.q_layernorm,
                hidden_size=self.hidden_size_per_attention_head,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.q_layernorm = None

        if submodules.k_layernorm is not None:
            self.k_layernorm = build_module(
                submodules.k_layernorm,
                hidden_size=self.hidden_size_per_attention_head,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.k_layernorm = None

    def run_realtime_tests(self):
        """Performs a consistency check.

        This function makes sure that tensors across devices are the same during an experiment.
        This is often not guaranteed to be so because of silent hardware failures (eg, memory
        corruption loading a checkpoint, network traffic corruption encountered during
        data transmission).

        The Q & K layernorm weights and biases are all-gathered across the data
        parallel group and then across the tensor parallel group, and every
        gathered copy must equal this rank's copy exactly.

        (TODO) In the future, more tensors should be checked across the training run and
        checked every X iterations. This is left for future work. Equality of tensors is probably
        not required; transmitting hashes is sufficient.

        Raises:
            AssertionError: if any rank's parameters differ from this rank's.
        """

        if not self.config.qk_layernorm:
            return

        # check that all tensor parallel and data parallel ranks have the same
        # Q & K layernorm parameters.
        local_params = [
            self.q_layernorm.weight.data,
            self.q_layernorm.bias.data,
            self.k_layernorm.weight.data,
            self.k_layernorm.bias.data,
        ]
        param_names = ["q_w", "q_b", "k_w", "k_b"]
        inputs = torch.stack(local_params)

        def _check_group(world_size, local_rank, group, parallelism):
            # Gather every rank's stacked parameters and compare each to ours.
            # All values are passed explicitly (no closure over loop variables).
            gathered = [torch.empty_like(inputs) for _ in range(world_size)]
            gathered[local_rank] = inputs
            torch.distributed.all_gather(gathered, inputs, group=group)

            for other_rank, remote in enumerate(gathered):
                for src, tgt, name in zip(torch.unbind(remote), local_params, param_names):
                    assert torch.all(src == tgt), (
                        f"Discrepancy between {name} in {parallelism} ranks "
                        f"{other_rank} and {local_rank}. "
                        f"Diff: {torch.norm(src - tgt)}"
                    )

        _check_group(
            get_data_parallel_world_size(),
            get_data_parallel_rank(),
            get_data_parallel_group(),
            "DP",
        )
        _check_group(
            get_tensor_model_parallel_world_size(),
            get_tensor_model_parallel_rank(),
            get_tensor_model_parallel_group(),
            "TP",
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """
        Derives `query`, `key` and `value` tensors from `hidden_states`.

        ``key_value_states`` exists only to match the ``Attention`` interface and is ignored.

        Returns:
            query: [sq, b, np, hn]
            key, value: [sq, b, ng, hn]
        """
        # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
        mixed_qkv, _ = self.linear_qkv(hidden_states)

        heads_per_group = (
            self.num_attention_heads_per_partition // self.num_query_groups_per_partition
        )
        head_dim = self.hidden_size_per_attention_head

        # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
        new_tensor_shape = mixed_qkv.size()[:-1] + (
            self.num_query_groups_per_partition,
            (heads_per_group + 2) * head_dim,
        )
        mixed_qkv = mixed_qkv.view(*new_tensor_shape)

        # Within each group the layout is: np/ng query heads, then one key head,
        # then one value head.
        split_arg_list = [heads_per_group * head_dim, head_dim, head_dim]

        if SplitAlongDim is not None:

            # [sq, b, ng, (np/ng + 2) * hn]
            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
        else:

            # [sq, b, ng, (np/ng + 2) * hn]
            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)

        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
        query = query.reshape(query.size(0), query.size(1), -1, head_dim)

        if self.q_layernorm is not None:
            query = self.q_layernorm(query)

        if self.k_layernorm is not None:
            key = self.k_layernorm(key)

        if self.config.test_mode:
            self.run_realtime_tests()

        return query, key, value


class CrossAttention(Attention):
    """Cross-attention layer.

    Queries are projected from the layer input (size [s, b, h]); keys and values
    are projected from a separate context tensor (size [s, b, h]). The output has
    the same size as the layer input. Grouped-query attention is not supported.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: CrossAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
        cp_comm_type: str = None,
    ):
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
            attention_type="cross",
            cp_comm_type=cp_comm_type,
        )

        cfg = self.config
        is_grouped = cfg.num_query_groups != cfg.num_attention_heads
        if is_grouped:
            raise ValueError("Group query attention is not currently supported in cross attention.")
        assert self.query_projection_size == self.kv_projection_size

        # Both input projections share every option except their width.
        common_kwargs = dict(
            config=cfg,
            init_method=cfg.init_method,
            gather_output=False,
            bias=cfg.add_bias_linear,
            skip_bias_add=False,
            is_expert=False,
        )

        # Queries: [h] -> [np * hn]
        self.linear_q = build_module(
            submodules.linear_q,
            cfg.hidden_size,
            self.query_projection_size,
            **common_kwargs,
        )

        # Keys and values, fused: [h] -> [2 * np * hn]
        self.linear_kv = build_module(
            submodules.linear_kv,
            cfg.hidden_size,
            2 * self.kv_projection_size,
            **common_kwargs,
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
        from `key_value_states`.

        Returns:
            query: [sq, b, np, hn]
            key, value: [sk, b, np, hn]
        """
        n_heads = self.num_attention_heads_per_partition
        head_dim = self.hidden_size_per_attention_head

        # Context [sk, b, h] -> [sk, b, np * 2 * hn] -> [sk, b, np, 2 * hn]
        kv_proj, _ = self.linear_kv(key_value_states)
        kv_proj = kv_proj.view(*kv_proj.size()[:-1], n_heads, 2 * head_dim)

        # [sk, b, np, 2 * hn] -> two tensors of [sk, b, np, hn]
        key, value = tensor_parallel.split_tensor_along_last_dim(kv_proj, 2)

        # Input [sq, b, h] -> [sq, b, np * hn] -> [sq, b, np, hn]
        q_proj, _ = self.linear_q(hidden_states)
        query = q_proj.view(*q_proj.size()[:-1], n_heads, head_dim)

        return query, key, value