Commit 7c19b3a8 authored by wangsen's avatar wangsen
Browse files

Initial commit

parents
Pipeline #1721 failed with stages
in 0 seconds
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .module import MegatronModule
from .spec_utils import ModuleSpec, build_module
from .transformer_config import TransformerConfig
from .transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from importlib.metadata import version
from typing import Union
import torch
from pkg_resources import packaging
from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.parallel_state import (
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
try:
from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim
except ImportError:
#print("Do not support transformer_engine")
SplitAlongDim = None
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import divide
from .enums import AttnMaskType
from .transformer_config import TransformerConfig
@dataclass
class SelfAttentionSubmodules:
linear_qkv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None
q_layernorm: Union[ModuleSpec, type] = None
k_layernorm: Union[ModuleSpec, type] = None
@dataclass
class CrossAttentionSubmodules:
linear_q: Union[ModuleSpec, type] = None
linear_kv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None
class Attention(MegatronModule, ABC):
"""Attention layer abstract class.
This layer only contains common modules required for the "self attn" and
"cross attn" specializations.
"""
def __init__(
self,
config: TransformerConfig,
submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
):
super().__init__(config=config)
self.config = config
self.layer_number = layer_number
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type
# For normal attention without groups, num_query_groups == num_attention_heads,
# so these two will be the same
self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = divide(
self.query_projection_size, self.config.num_attention_heads
)
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
self.core_attention = build_module(
submodules.core_attention,
config=self.config,
layer_number=self.layer_number,
attn_mask_type=self.attn_mask_type,
attention_type=self.attention_type,
)
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
# Output.
self.linear_proj = build_module(
submodules.linear_proj,
self.query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='proj',
)
def _checkpointed_attention_forward(
self,
query,
key,
value,
attention_mask,
rotary_pos_emb=None,
attn_mask_type=None,
packed_seq_params=None,
):
"""Forward method with selective activation checkpointing."""
def custom_forward(*inputs):
query = inputs[0]
key = inputs[1]
value = inputs[2]
attention_mask = inputs[3]
attn_mask_type = inputs[5]
attn_mask_type = AttnMaskType(attn_mask_type.item())
output_ = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
return output_
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)
hidden_states = tensor_parallel.checkpoint(
custom_forward,
False,
query,
key,
value,
attention_mask,
rotary_pos_emb,
attn_mask_type,
)
return hidden_states
def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
"""Allocate memory to store kv cache during inference."""
return torch.empty(
inference_max_sequence_length,
batch_size,
self.num_query_groups_per_partition,
self.hidden_size_per_attention_head,
dtype=dtype,
device=torch.cuda.current_device(),
)
def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb):
"""
Saves the generated key and value tensors to the end of the buffers in inference_params.
Returns the full size keys and values from the provided inference_params, as well as
adjusted rotary_pos_emb.
Returns a tuple: (key, value, rotary_pos_emb)
"""
attn_mask_type = self.attn_mask_type
if inference_params is None:
return key, value, rotary_pos_emb, attn_mask_type
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
is_first_step = False
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_length = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_length, inf_max_batch_size, key.dtype
)
inference_value_memory = self._allocate_memory(
inf_max_seq_length, inf_max_batch_size, value.dtype
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
inference_value_memory,
)
is_first_step = True
else:
# Get the pre-allocated buffers for this layer
inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
self.layer_number
]
attn_mask_type = AttnMaskType.no_mask
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
# adjust the key rotary positional embedding
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# need to cross check this condition during inference
# if not set_inference_key_value_memory:
if not is_first_step:
# In inference, we compute one token at a time.
# Select the correct positional embedding
# (only the last token in the sequence)
q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
else:
# In the first forward pass of inference,
# we use the entire provided prefix.
# q_pos_emb here has the rope embeddings of the entire
# prefix + to-be-generated output so
# we slice to just the prefix.
q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)
return key, value, rotary_pos_emb, attn_mask_type
@abstractmethod
def get_query_key_value_tensors(self, hidden_states, key_value_states):
"""
This method needs to be implemented based on whether the derived class
is "self-attn" or "cross-attn".
"""
def forward(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
):
# hidden_states: [sq, b, h]
# For self attention we just duplicate the rotary_pos_emb if it isn't already
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2
# =====================
# Query, Key, and Value
# =====================
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
# ===================================================
# Adjust key, value, and rotary_pos_emb for inference
# ===================================================
key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
inference_params, key, value, rotary_pos_emb
)
if packed_seq_params is not None:
query = query.squeeze(1)
key = key.squeeze(1)
value = value.squeeze(1)
# ================================================
# relative positional embedding (rotary embedding)
# ================================================
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
else:
cu_seqlens_q = cu_seqlens_kv = None
query = apply_rotary_pos_emb(
query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q,
)
key = apply_rotary_pos_emb(
key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv,
)
# TODO, can apply positional embedding to value_layer so it has
# absolute positional embedding.
# otherwise, only relative positional embedding takes effect
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
# ==================================
# core attention computation
# ==================================
if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
else:
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
if packed_seq_params is not None:
# reshape to same output shape as unpacked case
# (t, np, hn) -> (t, b=1, h=np*hn)
# t is the pack size = sum (sq_i)
# note that batch is a dummy dimension in the packed case
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.linear_proj(core_attn_out)
return output, bias
class SelfAttention(Attention):
"""Self-attention layer class
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(
self,
config: TransformerConfig,
submodules: SelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
attention_type="self",
)
self.linear_qkv = build_module(
submodules.linear_qkv,
self.config.hidden_size,
self.query_projection_size + 2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='qkv',
)
if submodules.q_layernorm is not None:
self.q_layernorm = build_module(
submodules.q_layernorm,
hidden_size=self.hidden_size_per_attention_head,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.q_layernorm = None
if submodules.k_layernorm is not None:
self.k_layernorm = build_module(
submodules.k_layernorm,
hidden_size=self.hidden_size_per_attention_head,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.k_layernorm = None
def run_realtime_tests(self):
"""Performs a consistency check.
This function makes sure that tensors across devices are the same during an experiment.
This is often not guaranteed to be so because of silent hardware failures (eg, memory
corruption loading a checkpoint, network traffic corruption encountered during data transmission).
(TODO) In the future, more tensors should be checked across the training run and
checked every X iterations. This is left for future work. Equality of tensors is probably not
required; transmitting hashes is sufficient."""
if not self.config.qk_layernorm:
return
# check that all tensor parallel and data parallel ranks have the same
# Q & K layernorm parameters.
rank = get_data_parallel_rank()
inputs = torch.stack(
[
self.q_layernorm.weight.data,
self.q_layernorm.bias.data,
self.k_layernorm.weight.data,
self.k_layernorm.bias.data,
]
)
dp_list = [torch.empty_like(inputs) for _ in range(get_data_parallel_world_size())]
dp_list[rank] = inputs
torch.distributed.all_gather(dp_list, inputs, group=get_data_parallel_group())
def _compare(srcs, tgts, names, parallelism):
assert len(srcs) == len(tgts) == len(names)
for src, tgt, name in zip(srcs, tgts, names):
assert torch.all(
src == tgt
), f"Discrepancy between {name} in {parallelism} ranks {i} and {rank}. Diff: {torch.norm(src - tgt)}"
for i, dp in enumerate(dp_list):
q_w, q_b, k_w, k_b = torch.unbind(dp)
_compare(
[q_w, q_b, k_w, k_b],
[
self.q_layernorm.weight.data,
self.q_layernorm.bias.data,
self.k_layernorm.weight.data,
self.k_layernorm.bias.data,
],
["q_w", "q_b", "k_w", "k_b"],
"DP",
)
rank = get_tensor_model_parallel_rank()
tp_list = [torch.empty_like(inputs) for _ in range(get_tensor_model_parallel_world_size())]
tp_list[rank] = inputs
torch.distributed.all_gather(tp_list, inputs, group=get_tensor_model_parallel_group())
for i, tp in enumerate(tp_list):
q_w, q_b, k_w, k_b = torch.unbind(tp)
_compare(
[q_w, q_b, k_w, k_b],
[
self.q_layernorm.weight.data,
self.q_layernorm.bias.data,
self.k_layernorm.weight.data,
self.k_layernorm.bias.data,
],
["q_w", "q_b", "k_w", "k_b"],
"TP",
)
def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
"""
Derives `query`, `key` and `value` tensors from `hidden_states`.
"""
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
mixed_qkv, _ = self.linear_qkv(hidden_states)
# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_qkv.size()[:-1] + (
self.num_query_groups_per_partition,
(
(self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
* self.hidden_size_per_attention_head
),
)
mixed_qkv = mixed_qkv.view(*new_tensor_shape)
split_arg_list = [
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
]
if SplitAlongDim is not None:
# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,)
else:
# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
if self.q_layernorm is not None:
query = self.q_layernorm(query)
if self.k_layernorm is not None:
key = self.k_layernorm(key)
if self.config.test_mode:
self.run_realtime_tests()
return query, key, value
class CrossAttention(Attention):
"""Cross-attention layer class
Cross-attention layer takes input with size [s, b, h] and context with size
[s, b, h] and returns output of the same size.
"""
def __init__(
self,
config: TransformerConfig,
submodules: CrossAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
attention_type="cross",
)
if self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Group query attention is not currently supported in cross attention."
)
assert self.query_projection_size == self.kv_projection_size
self.linear_q = build_module(
submodules.linear_q,
self.config.hidden_size,
self.query_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=False,
is_expert=False,
)
self.linear_kv = build_module(
submodules.linear_kv,
self.config.hidden_size,
2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=False,
is_expert=False,
)
def get_query_key_value_tensors(self, hidden_states, key_value_states):
"""
Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
from `key_value_states`.
"""
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv, _ = self.linear_kv(key_value_states)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv = mixed_kv.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query, _ = self.linear_q(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query = query.view(*new_tensor_shape)
return query, key, value
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import os
from importlib.metadata import version
from typing import Callable
import torch
import transformer_engine as te
from pkg_resources import packaging
from torch import Tensor
from megatron.core import ModelParallelConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_tensor_model_parallel_group,
)
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
_te_version = packaging.version.Version(version("transformer-engine"))
def _get_extra_te_kwargs(config: TransformerConfig):
extra_transformer_engine_kwargs = {
"params_dtype": config.params_dtype,
}
if _te_version >= packaging.version.Version("0.12.0"):
if config.use_cpu_initialization:
extra_transformer_engine_kwargs["device"] = 'cpu'
else:
extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
return extra_transformer_engine_kwargs
def condition_init_method(config, init_method):
return init_method if config.perform_initialization else (lambda w: None)
class TENorm:
"""
A conditional wrapper to initialize an instance of Transformer-Engine's
`LayerNorm` or `RMSNorm` based on input
"""
# TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
def __new__(
cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5,
):
if config.normalization == "LayerNorm":
instance = te.pytorch.LayerNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
elif config.normalization == "RMSNorm":
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"
instance = te.pytorch.RMSNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
else:
raise Exception('Only LayerNorm and RMSNorm are curently supported')
return instance
class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: str = None,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
extra_kwargs = _get_extra_te_kwargs(config)
if _te_version >= packaging.version.Version("0.8.0"):
if self.config.tp_comm_overlap:
if _te_version > packaging.version.Version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
extra_kwargs["ub_overlap_rs"] = (
self.config.tp_comm_overlap_rs
if hasattr(self.config, "tp_comm_overlap_rs")
else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs
)
else:
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs
if _te_version > packaging.version.Version("1.0.0"):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
super().__init__(
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker
if get_cuda_rng_tracker().is_initialized()
else None,
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
**extra_kwargs,
)
def forward(self, x):
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
self.config = config
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
# Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
if _te_version >= packaging.version.Version("0.11.0"):
extra_kwargs["normalization"] = self.config.normalization
elif self.config.normalization != "LayerNorm":
raise ValueError(
f"Transformer Engine v{_te_version} does not support {self.config.normalization}."
)
if _te_version >= packaging.version.Version("0.8.0"):
if self.config.tp_comm_overlap:
extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad
extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad
if _te_version > packaging.version.Version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
if _te_version > packaging.version.Version("1.6.0.dev0"):
extra_kwargs["ub_overlap_rs_dgrad"] = (
self.config.tp_comm_overlap_rs_dgrad
if hasattr(self.config, "tp_comm_overlap_rs_dgrad")
else False
)
else:
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
if _te_version > packaging.version.Version("1.0.0"):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
super().__init__(
in_features=input_size,
out_features=output_size,
eps=self.config.layernorm_epsilon,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker
if get_cuda_rng_tracker().is_initialized()
else None,
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode="column",
return_layernorm_output=False,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
**extra_kwargs,
)
def forward(self, x):
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 0, bias sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
class TEColumnParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `ColumnParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 0, bias sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
class TERowParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `RowParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
if not input_is_parallel:
raise ValueError(
"Transformer Engine linear layers do not support input_is_parallel = False"
)
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=False, # We don't currently use this for row parallel layers
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 1, bias not sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 1}, sharded_offsets
)
class TEDotProductAttention(te.pytorch.DotProductAttention):
"""
Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
has "flash attention" enabled.
Note that if Megatron's parallel_state has not been initialized yet, the
tp_group and cp_group passed to TE will be None and must be set later
via set_tensor_parallel_group() and set_context_parallel_group().
"""
cp_stream: torch.cuda.Stream = None
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
):
self.config = config
self.te_forward_mask_type = False
self.qkv_format: str = 'sbhd'
if self.config.apply_query_key_layer_scaling != bool(
int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs = {}
if _te_version >= packaging.version.Version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if _te_version >= packaging.version.Version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older version don't need attention_type
if _te_version > packaging.version.Version("0.12.0"):
self.te_forward_mask_type = True
# Only Transformer-Engine version >= 1.0.0 supports context parallelism
if _te_version >= packaging.version.Version("1.0.0"):
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
else:
assert (
self.config.context_parallel_size == 1
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if self.config.deterministic_mode:
if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
raise RuntimeError(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
)
if config.window_size is not None:
# Check version
assert _te_version >= packaging.version.Version(
"1.2.0"
), f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support sliding window attention."
extra_kwargs['window_size'] = config.window_size
super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
attention_dropout=self.config.attention_dropout
if attention_dropout is None
else attention_dropout,
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker
if get_cuda_rng_tracker().is_initialized()
else None,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType,
packed_seq_params: PackedSeqParams = None,
):
packed_seq_kwargs = (
dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
)
# overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set after init
if self.config.apply_rope_fusion and _te_version > packaging.version.Version("0.13.0"):
self.qkv_format = 'bshd'
qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
if _te_version < packaging.version.Version("1.3.0"):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H copies (#555)
# These two arguments did not exist prior to 1.3.0
packed_seq_kwargs.pop("max_seqlen_q", None)
packed_seq_kwargs.pop("max_seqlen_kv", None)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
# In PyTorch, the following two tensors are in fact the same:
# Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
# Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
# Stride for a dimension that is 1 has no meaning, so tensors created two different ways
# can have same shape but different strides.
# We unify them to the first one to pass the stride check in TE
if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
value = value.as_strided(value.shape, key.stride())
if self.te_forward_mask_type:
core_attn_out = super().forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type.name,
**packed_seq_kwargs,
)
else:
core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs,)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
return core_attn_out.transpose(0, 1)
else:
return core_attn_out
class TEDelayedScaling(te.common.recipe.DelayedScaling):
"""
Wrapper for the Transformer-Engine's `DelayedScaling` layer.
"""
def __init__(
self,
config: ModelParallelConfig,
fp8_format: int,
override_linear_precision: tuple = (False, False, False),
):
extra_kwargs = _get_extra_te_kwargs(config)
if _te_version >= packaging.version.Version("1.6.0.dev0"):
extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention
super().__init__(
margin=config.fp8_margin,
interval=config.fp8_interval,
fp8_format=fp8_format,
amax_compute_algo=config.fp8_amax_compute_algo,
amax_history_len=config.fp8_amax_history_len,
override_linear_precision=override_linear_precision,
**extra_kwargs,
)
def te_checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
):
from transformer_engine.pytorch.distributed import checkpoint
if _te_version >= packaging.version.Version("1.5.0"):
return checkpoint(
forward_func,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
distribute_saved_activations=distribute_saved_activations,
get_rng_state_tracker=get_rng_state_tracker,
tp_group=tp_group,
)
else:
return checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
try:
from transformer_engine.pytorch.attention import _SplitAlongDim
SplitAlongDim = _SplitAlongDim.apply
except ImportError:
SplitAlongDim = None
try:
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
except ImportError:
get_cpu_offload_context = None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import math
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import attention_mask_func
from megatron.core.utils import divide
class DotProductAttention(MegatronModule):
"""
Region where selective activation recomputation is applied.
This region is memory intensive but less compute intensive which
makes activation checkpointing more efficient for LLMs (20B+).
See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
We use the following notation:
h: hidden size
n: number of attention heads
p: number of tensor model parallel partitions
b: batch size
s: sequence length
"""
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
):
super().__init__(config=config)
self.config: TransformerConfig = config
assert (
self.config.context_parallel_size == 1
), "Context parallelism is only supported by TEDotProductAttention!"
assert (
self.config.window_size is None
), "Sliding Window Attention is only supported by TEDotProductAttention!"
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type # unused for now
projection_size = self.config.kv_channels * self.config.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = divide(projection_size, world_size)
self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.config.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.config.fp16,
input_in_bf16=self.config.bf16,
attn_mask_type=self.attn_mask_type,
scaled_masked_softmax_fusion=self.config.masked_softmax_fusion,
mask_func=attention_mask_func,
softmax_in_fp32=self.config.attention_softmax_in_fp32,
scale=coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(
self.config.attention_dropout if attention_dropout is None else attention_dropout
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType = None,
packed_seq_params: PackedSeqParams = None,
):
assert packed_seq_params is None, (
"Packed sequence is not supported by DotProductAttention."
"Please use TEDotProductAttention instead."
)
# ===================================
# Raw attention scores. [b, n/p, s, s]
# ===================================
# expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn]
# This is a noop for normal attention where ng == np. When using group query attention this
# creates a view that has the keys and values virtually repeated along their dimension to
# match the number of queries.
# attn_mask_type is not used.
if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
key = key.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
value = value.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
# [b, np, sq, sk]
output_size = (
query.size(1),
query.size(2),
query.size(0),
key.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
# This will be a simple view when doing normal attention, but in group query attention
# the key and value tensors are repeated to match the queries so you can't use simple strides
# to extract the queries.
query = query.reshape(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key = key.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
(output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query.transpose(0, 1), # [b * np, sq, hn]
key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.config.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (
value.size(1),
value.size(2),
query.size(0),
value.size(3),
)
# change view [sk, b * np, hn]
value = value.view(value.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context = torch.bmm(attention_probs, value.transpose(0, 1))
# change view [b, np, sq, hn]
context = context.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context = context.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
context = context.view(*new_context_shape)
return context
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import enum
# can we get rid of this?
# it's being used in pipeline schedules
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
# class LayerType(enum.Enum):
# encoder = 1
# decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
no_mask = 3 # only used for TE
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment