Commit d444a97a authored by yangzhong

First upload

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
class LanguageModelEmbedding(MegatronModule):
"""Language model embeddings.
Args:
config (TransformerConfig): config object with all necessary configs for TransformerBlock
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This
is used for positional embedding
position_embedding_type (Literal['learned_absolute', 'rope', 'none']): Type of position
embedding; only 'learned_absolute' adds a learned position embedding inside this
module. Defaults to 'learned_absolute'.
num_tokentypes (int): Set to 0 without binary head, and 2 with a binary head. Defaults to 0.
scatter_to_sequence_parallel (bool): Set to False to disable scatter of embedding
across sequence parallel region. Defaults to True.
"""
def __init__(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
num_tokentypes: int = 0,
scatter_to_sequence_parallel: bool = True,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
self.num_tokentypes = num_tokentypes
self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
self.reduce_scatter_embeddings = (
(not self.add_position_embedding)
and self.num_tokentypes <= 0
and self.config.sequence_parallel
and self.scatter_to_sequence_parallel
)
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
)
# Position embedding (serial).
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
self.max_sequence_length, self.config.hidden_size
)
# Initialize the position embeddings.
if self.config.perform_initialization:
self.config.init_method(self.position_embeddings.weight)
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.config.hidden_size
)
# Initialize the token-type embeddings.
if self.config.perform_initialization:
self.config.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None) -> Tensor:
"""Forward pass of the embedding module.
Args:
input_ids (Tensor): The input tokens
position_ids (Tensor): The position id's used to calculate position embeddings
tokentype_ids (Tensor): The token type ids. Used when args.bert_binary_head is
set to True. Defaults to None
Returns:
Tensor: The output embeddings
"""
word_embeddings = self.word_embeddings(input_ids)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
else:
embeddings = word_embeddings
if not self.reduce_scatter_embeddings:
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
# [b s h] -> [s b h] (So that it can be added with embeddings)
tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
embeddings = embeddings + tokentype_embedding
else:
assert self.tokentype_embeddings is None
# If the input flag for fp32 residual connection is set, convert for float.
if self.config.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.config.sequence_parallel:
if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
# `scatter_to_sequence_parallel_region` returns a view, which prevents
# the original tensor from being garbage collected. Clone to facilitate GC.
# Has a small runtime cost (~0.5%).
if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
embeddings = embeddings.clone()
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
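# Illustrative usage sketch: a minimal, single-process illustration of the data-layout
# contract of LanguageModelEmbedding.forward, where token and position ids come in as
# [b, s] and the returned embeddings are [s, b, h]. A plain torch.nn.Embedding stands in
# for tensor_parallel.VocabParallelEmbedding so the sketch runs without initializing
# model parallelism; the underscore-prefixed names and sizes are hypothetical.
if __name__ == "__main__":
    _vocab, _hidden, _b, _s = 128, 64, 2, 16
    _word_emb = torch.nn.Embedding(_vocab, _hidden)      # stand-in for VocabParallelEmbedding
    _pos_emb = torch.nn.Embedding(_s, _hidden)           # stand-in for the learned position table
    _input_ids = torch.randint(0, _vocab, (_b, _s))                # [b, s]
    _position_ids = torch.arange(_s).unsqueeze(0).expand(_b, -1)   # [b, s]
    _embeddings = _word_emb(_input_ids) + _pos_emb(_position_ids)  # [b, s, h]
    _embeddings = _embeddings.transpose(0, 1).contiguous()         # [s, b, h], as in forward()
    assert _embeddings.shape == (_s, _b, _hidden)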
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from megatron.core.transformer.transformer_config import TransformerConfig
import logging
import torch
from torch import Tensor
from megatron.core import parallel_state
from megatron.core.utils import is_te_min_version
logger = logging.getLogger(__name__)
# Prefer fused RoPE from Apex as we need the `transpose_output_memory` argument for the bshd trick.
# See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2469.
try:
from apex.transformer.functional import fused_apply_rotary_pos_emb
except ImportError:
try:
from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb
except ImportError:
fused_apply_rotary_pos_emb = None
try:
from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb_thd
except ImportError:
try:
from apex.transformer.functional import fused_apply_rotary_pos_emb_thd
except ImportError:
fused_apply_rotary_pos_emb_thd = None
try:
from flash_attn.layers.rotary import apply_rotary_emb as apply_rotary_emb_flash
except ImportError:
apply_rotary_emb_flash = None
__all__ = ['apply_rotary_emb_flash']
def get_pos_emb_on_this_cp_rank(pos_emb: Tensor, seq_dim: int) -> Tensor:
"""Get the position embedding on the current context parallel rank.
Args:
pos_emb (Tensor): Positional embedding tensor
seq_dim (int): Sequence dimension
"""
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
cp_idx = torch.tensor(
[cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
).cuda(non_blocking=True)
pos_emb = pos_emb.view(
*pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :]
)
pos_emb = pos_emb.index_select(seq_dim, cp_idx)
pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
return pos_emb
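# Illustrative note (comment only): with cp_size == 2 the sequence axis is viewed as
# 2 * cp_size == 4 equal chunks and rank r keeps chunks r and (2 * cp_size - 1 - r),
# e.g. rank 0 keeps chunks 0 and 3 while rank 1 keeps chunks 1 and 2, matching the
# load-balanced sequence sharding used by context parallelism.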
def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor:
"""Change sign so the last dimension becomes [-odd, +even]
Args:
x (Tensor): Input tensor
Returns:
Tensor: Tensor rotated half
"""
if not rotary_interleaved:
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x_new = torch.stack((-x2, x1), dim=-1)
return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
def _apply_rotary_pos_emb_bshd(
t: Tensor,
freqs: Tensor,
rotary_interleaved: bool = False,
multi_latent_attention: bool = False,
mscale: float = 1.0,
) -> Tensor:
"""Apply rotary positional embedding to input tensor T.
check https://kexue.fm/archives/8265 for detailed formulas
Args:
t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
Returns:
Tensor: The input tensor after applying RoPE
"""
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
if multi_latent_attention:
x1 = t[..., 0::2]
x2 = t[..., 1::2]
t = torch.cat((x1, x2), dim=-1)
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
cos_ = (torch.cos(freqs) * mscale).to(t.dtype)
sin_ = (torch.sin(freqs) * mscale).to(t.dtype)
t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
return torch.cat((t, t_pass), dim=-1)
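# Worked micro-example (comment only): for a single position with rotary angle
# theta = pi / 2 and a two-dimensional head, t = [x1, x2] becomes
# [x1 * cos(theta) - x2 * sin(theta), x2 * cos(theta) + x1 * sin(theta)] = [-x2, x1],
# e.g. [1, 0] -> [0, 1]; _rotate_half supplies the [-x2, x1] term that multiplies sin.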
def _get_thd_freqs_on_this_cp_rank(cp_rank: int, cp_size: int, x: Tensor, freqs: Tensor) -> Tensor:
if cp_size > 1:
cp_seg = x.size(0) // 2
full_seqlen = cp_size * x.size(0)
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
else:
return freqs[: x.size(0)]
def _apply_rotary_pos_emb_thd(
t: Tensor,
cu_seqlens: Tensor,
freqs: Tensor,
rotary_interleaved: bool = False,
multi_latent_attention: bool = False,
mscale: float = 1.0,
) -> Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
Args:
t (Tensor): Input tensor T is of shape [t, h, d]
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
_apply_rotary_pos_emb_bshd(
x.unsqueeze(1),
_get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs),
rotary_interleaved=rotary_interleaved,
multi_latent_attention=multi_latent_attention,
mscale=mscale,
)
for x in torch.split(t, seqlens)
]
).squeeze(1)
def apply_rotary_pos_emb(
t: Tensor,
freqs: Tensor,
config: TransformerConfig,
cu_seqlens: Optional[Tensor] = None,
mscale: float = 1.0,
):
"""
Reroute to the appropriate apply_rotary_pos_emb function depending on
fused/unfused kernels, or bshd (conventional) / thd (packed seq) format
"""
if config.apply_rope_fusion:
if cu_seqlens is None:
assert fused_apply_rotary_pos_emb is not None, "apply_rope_fusion is not available."
return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
else:
assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
if not is_te_min_version("1.11.0", check_equality=False):
raise ValueError("Only TE >= 1.12 supports RoPE fusion for THD format with CP.")
return fused_apply_rotary_pos_emb_thd(
t,
cu_seqlens,
freqs,
cp_size=cp_size,
cp_rank=parallel_state.get_context_parallel_rank(),
)
else:
return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs)
else:
if cu_seqlens is None:
return _apply_rotary_pos_emb_bshd(
t,
freqs,
rotary_interleaved=config.rotary_interleaved,
multi_latent_attention=config.multi_latent_attention,
mscale=mscale,
)
else:
return _apply_rotary_pos_emb_thd(
t,
cu_seqlens,
freqs,
rotary_interleaved=config.rotary_interleaved,
multi_latent_attention=config.multi_latent_attention,
mscale=mscale,
)
def apply_rotary_pos_emb_with_cos_sin(
t: Tensor, cos: Tensor, sin: Tensor, rotary_interleaved: bool = False
) -> Tensor:
"""
This function applies rotary positional embedding to the target tensor t
using precomputed cos and sin of size (seq_len, d_rot / 2)
"""
cos = cos.to(t.dtype)
sin = sin.to(t.dtype)
if apply_rotary_emb_flash is None:
# Combine cos and sin into freqs
freqs = torch.stack([cos, sin], dim=-1).flatten(start_dim=-2)
# Expand freqs to match t's shape
while freqs.dim() < t.dim():
freqs = freqs.unsqueeze(1)
freqs = freqs.expand(t.shape[:-1] + (-1,))
y = _apply_rotary_pos_emb_bshd(
t,
freqs,
rotary_interleaved=rotary_interleaved,
multi_latent_attention=False,
mscale=1.0,
)
else:
# Use Flash Attention's optimized kernel for rotary embedding
t = t.permute(1, 0, 2, 3)
y = apply_rotary_emb_flash(t, cos, sin, rotary_interleaved)
y = y.permute(1, 0, 2, 3)
return y
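# Minimal CPU sanity check (illustrative, hypothetical names): the angle table is built
# the same way RotaryEmbedding does (per-pair angles duplicated across the two halves),
# and since RoPE is a pure rotation, the per-position vector norm is preserved by the
# unfused bshd path above.
if __name__ == "__main__":
    _s, _b, _h, _d = 8, 2, 4, 16
    _t = torch.randn(_s, _b, _h, _d)
    _inv_freq = 1.0 / (10000 ** (torch.arange(0, _d, 2, dtype=torch.float32) / _d))
    _angles = torch.outer(torch.arange(_s, dtype=torch.float32), _inv_freq)   # [s, d / 2]
    _freqs = torch.cat((_angles, _angles), dim=-1)[:, None, None, :]          # [s, 1, 1, d]
    _out = _apply_rotary_pos_emb_bshd(_t, _freqs)
    assert _out.shape == _t.shape
    assert torch.allclose(_t.norm(dim=-1), _out.norm(dim=-1), atol=1e-4)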
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.inference_params import InferenceParams
from megatron.core.packed_seq_params import PackedSeqParams
import logging
import math
from functools import lru_cache
import torch
from torch import Tensor, nn
from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rope_utils import ( # for backward compatibility; pylint: disable=unused-import
_apply_rotary_pos_emb_bshd,
_apply_rotary_pos_emb_thd,
_rotate_half,
apply_rotary_pos_emb,
get_pos_emb_on_this_cp_rank,
)
logger = logging.getLogger(__name__)
__all__ = ['RotaryEmbedding']
class RotaryEmbedding(nn.Module):
"""Rotary Embedding for language model.
Args:
kv_channels (int): Projection weights dimension in multi-head attention. Obtained
from transformer config
rotary_percent (float): Percent of rotary dimension to use for rotary position
embeddings.
rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings.
Defaults to False.
seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE
for longer sequences. The value must be a float larger than 1.0. Defaults to None
rotary_base (int, optional): Base period for rotary position embeddings. Defaults to
10000.
rope_scaling (bool, optional): Apply rope scaling as used in llama 3.1
use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly
on the GPU. Defaults to False
"""
def __init__(
self,
kv_channels: int,
rotary_percent: float,
rotary_interleaved: bool = False,
seq_len_interpolation_factor: float = None,
rotary_base: int = 10000,
rope_scaling: bool = False,
use_cpu_initialization: bool = False,
) -> None:
super().__init__()
dim = kv_channels
if rotary_percent < 1.0:
dim = int(dim * rotary_percent)
self.rotary_interleaved = rotary_interleaved
self.seq_len_interpolation_factor = seq_len_interpolation_factor
device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()
self.inv_freq = 1.0 / (
rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
)
if rope_scaling:
self.inv_freq = self._apply_scaling(self.inv_freq)
def _apply_scaling(
self,
freqs,
factor=8,
low_freq_factor=1,
high_freq_factor=4,
original_max_position_embeddings=8192,
):
# This implementation is adapted from:
# https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343
factor = factor # `8` in the original implementation
low_freq_factor = low_freq_factor # `1` in the original implementation
high_freq_factor = high_freq_factor # `4` in the original implementation
old_context_len = original_max_position_embeddings # `8192` in the original implementation
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / freqs
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smoothed_inv_freq = (
1 - smooth_factor
) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
return inv_freq_llama
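# Illustrative numbers for the branches above (comment only): with the default
# factor=8, low_freq_factor=1, high_freq_factor=4 and old_context_len=8192,
# frequency components with wavelength > 8192 tokens are divided by 8, components
# with wavelength < 2048 tokens are left unchanged, and the band in between is
# linearly blended between the two via smooth_factor.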
def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor:
"""Generates matrix of frequencies based on positions in the sequence,
used to create positional encodings"""
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset
)
if self.seq_len_interpolation_factor is not None:
seq *= 1 / self.seq_len_interpolation_factor
freqs = torch.outer(seq, self.inv_freq) # [seq len, dim]
return freqs
def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> tuple[Tensor, Tensor]:
"""Cosine and sine values for RoPE are precomputed for all positions up to the maximum
sequence length"""
freqs = self.get_freqs_non_repeated(max_seq_len, offset)
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return cos, sin
@lru_cache(maxsize=32)
def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
"""Forward pass of RoPE embedding.
Args:
max_seq_len (int): Maximum size of sequence
offset (int, optional): RoPE offset. Defaults to 0.
packed_seq (bool, optional): Whether to use packed sequence. Defaults to False.
Returns:
Tensor: Embeddings after applying RoPE.
"""
if self.inv_freq.device.type == 'cpu':
# move `inv_freq` to GPU once at the first micro-batch forward pass
self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device())
freqs = self.get_freqs_non_repeated(max_seq_len, offset)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
if not self.rotary_interleaved:
emb = torch.cat((freqs, freqs), dim=-1)
else:
emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
freqs.shape[0], -1
)
# emb [seq_length, .., dim]
emb = emb[:, None, None, :]
if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq:
# slice rotary_pos_emb along sequence dimension and select the partition of the current
# CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0)
return emb
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
state_dict.pop(f'{prefix}inv_freq', None)
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def get_rotary_seq_len(
self,
inference_params: InferenceParams,
transformer: TransformerBlock,
transformer_input: Tensor,
transformer_config: TransformerConfig,
packed_seq_params: PackedSeqParams,
) -> float:
"""Function to get the rotary sequence length.
Args:
inference_params : Used during Inference time
transformer (TransformerBlock): The transformer block (decoder/encoder) used
by the model
transformer_input (Tensor): Input tensor to the transformer
transformer_config (TransformerConfig): Transformer config used by the model
packed_seq_params (PackedSeqParams): Packed sequence params
Returns:
float: The rotary sequence length
"""
if packed_seq_params is not None:
# max_seqlen is the max sequence length in the packed sequence before being divided
# by the tp and cp size.
return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv)
elif inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
if transformer.input_tensor is not None:
rotary_seq_len = transformer.input_tensor.size(0)
else:
rotary_seq_len = transformer_input.size(0)
if transformer_config.sequence_parallel:
rotary_seq_len *= transformer_config.tensor_model_parallel_size
rotary_seq_len *= transformer_config.context_parallel_size
return rotary_seq_len
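# Minimal CPU usage sketch (illustrative, hypothetical sizes): with
# use_cpu_initialization=True the angle tables can be inspected without a GPU.
# forward() is intentionally not called here because it moves inv_freq to the
# current CUDA device.
if __name__ == "__main__":
    _rope = RotaryEmbedding(kv_channels=16, rotary_percent=1.0, use_cpu_initialization=True)
    _freqs = _rope.get_freqs_non_repeated(max_seq_len=8)   # [seq_len, dim / 2] == [8, 8]
    _cos, _sin = _rope.get_cos_sin(max_seq_len=8)
    assert _freqs.shape == (8, 8)
    assert _cos.shape == _sin.shape == (8, 8)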
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from __future__ import annotations
import logging
import math
from functools import lru_cache
import torch
from torch import Tensor
from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rope_utils import get_pos_emb_on_this_cp_rank
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
logger = logging.getLogger(__name__)
class YarnRotaryEmbedding(RotaryEmbedding):
"""Yarn Rotary Embedding for language model.
Args:
kv_channels (int): Projection weights dimension in multi-head attention. Obtained from
transformer config
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings.
Defaults to False.
seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for
longer sequences. The value must be a float larger than 1.0. Defaults to None
rotary_base (float, optional): Base period for rotary position embeddings. Defaults to
10000.
use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly on
the GPU. Defaults to False
scaling_factor (float, optional): Scaling factor for Yarn RoPE. Defaults to 1.0.
original_max_position_embeddings (int, optional): Original maximum position embeddings
length. Defaults to 4096.
beta_fast (float, optional): Fast beta value for Yarn RoPE. Defaults to 32.
beta_slow (float, optional): Slow beta value for Yarn RoPE. Defaults to 1.
mscale (float, optional): Mscale value for Yarn RoPE. Defaults to 1.
mscale_all_dim (float, optional): Mscale all dim value for Yarn RoPE. Defaults to 0.
"""
def __init__(
self,
kv_channels: int,
rotary_percent: float = 1.0,
rotary_interleaved: bool = False,
seq_len_interpolation_factor: float = None,
rotary_base: float = 10000.0,
use_cpu_initialization: bool = False,
scaling_factor: float = 1.0,
original_max_position_embeddings: int = 4096,
beta_fast: float = 32.0,
beta_slow: float = 1.0,
mscale: float = 1.0,
mscale_all_dim: float = 0.0,
):
self.dim = kv_channels
self.rotary_base = rotary_base
self.scaling_factor = scaling_factor
self.original_max_position_embeddings = original_max_position_embeddings
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = mscale
self.mscale_all_dim = mscale_all_dim
device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()
self.inv_freq_extra = 1.0 / (
self.rotary_base
** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
)
self.inv_freq_inter = 1.0 / (
self.scaling_factor
* self.rotary_base
** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
)
super().__init__(
kv_channels,
rotary_percent,
rotary_interleaved,
seq_len_interpolation_factor,
rotary_base,
use_cpu_initialization=use_cpu_initialization,
)
@lru_cache(maxsize=32)
def forward(self, max_seq_len: int, offset: int = 0) -> Tensor:
"""Forward pass of Yarn Rotary Embedding.
Args:
max_seq_len (int): Maximum size of sequence
offset (int, optional): RoPE offset. Defaults to 0.
Returns:
Tensor: Embeddings after applying Yarn RoPE.
"""
assert (
not self.rotary_interleaved
), "Yarn RoPE does not support interleaved rotary embeddings"
if self.inv_freq_extra.device.type == 'cpu':
# move `inv_freq_extra` to GPU once at the first micro-batch forward pass
self.inv_freq_extra = self.inv_freq_extra.to(device=torch.cuda.current_device())
if self.inv_freq_inter.device.type == 'cpu':
# move `inv_freq_inter` to GPU once at the first micro-batch forward pass
self.inv_freq_inter = self.inv_freq_inter.to(device=torch.cuda.current_device())
low, high = _yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.dim,
self.rotary_base,
self.original_max_position_embeddings,
)
inv_freq_mask = 1.0 - _yarn_linear_ramp_mask(low, high, self.dim // 2).to(
device=self.inv_freq_extra.device, dtype=torch.float32
)
inv_freq = self.inv_freq_inter * (1 - inv_freq_mask) + self.inv_freq_extra * inv_freq_mask
seq = (
torch.arange(
max_seq_len, device=self.inv_freq_extra.device, dtype=self.inv_freq_extra.dtype
)
+ offset
)
freqs = torch.outer(seq, inv_freq)
_mscale = float(
_yarn_get_mscale(self.scaling_factor, self.mscale)
/ _yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
)
emb = torch.cat((freqs, freqs), dim=-1)
# emb [seq_length, .., dim]
emb = emb[:, None, None, :]
if parallel_state.get_context_parallel_world_size() > 1:
# slice rotary_pos_emb along sequence dimension
# and select the partition of the current CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0)
return emb, _mscale
# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(
num_rotations: float, dim: int, rotary_base: float = 10000, max_position_embeddings: int = 2048
) -> float:
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(rotary_base)
)
# Find dim range bounds based on rotations
def _yarn_find_correction_range(
low_rot: float,
high_rot: float,
dim: int,
rotary_base: float = 10000,
max_position_embeddings: int = 2048,
) -> tuple[int, int]:
low = math.floor(_yarn_find_correction_dim(low_rot, dim, rotary_base, max_position_embeddings))
high = math.ceil(_yarn_find_correction_dim(high_rot, dim, rotary_base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def _yarn_linear_ramp_mask(min: float, max: float, dim: int) -> Tensor:
if min == max:
max += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def _yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
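# Minimal CPU sketch of the YaRN helpers above (illustrative, hypothetical values):
# for dim=128, base=10000 and a 4096-token original context, beta_fast=32 and
# beta_slow=1 bound the dimension band that gets smoothly interpolated, and a
# scaling factor of 4 yields mscale = 0.1 * ln(4) + 1, roughly 1.139.
if __name__ == "__main__":
    _low, _high = _yarn_find_correction_range(
        32.0, 1.0, dim=128, rotary_base=10000, max_position_embeddings=4096
    )
    _ramp = _yarn_linear_ramp_mask(_low, _high, 64)   # values in [0, 1], length dim // 2
    assert 0 <= _low <= _high <= 127
    assert _ramp.min().item() == 0.0 and _ramp.max().item() == 1.0
    assert abs(_yarn_get_mscale(4.0, 1.0) - 1.1386) < 1e-3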
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import os
from typing import Optional, Tuple
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
class LanguageModule(MegatronModule):
"""Base language module that has common helper functions used across GPT, BERT etc.
Args:
config (TransformerConfig): Input transformer config for the model
"""
def __init__(self, config: TransformerConfig) -> None:
super().__init__(config=config)
self._set_attention_backend()
# pylint: disable=line-too-long
def _set_attention_backend(self):
"""Set attention backend
Transformer Engine works on an opt-out basis: by default all three attention backend flags are set to 1. If the user chooses a particular attention backend, we set the other two to 0; if the user chooses local, we set all three TE env variables to 0.
"""
def check_and_set_env_variable(
env_variable_name: str, expected_value: int, attn_type: AttnBackend
) -> None:
current_value = os.getenv(env_variable_name)
assert current_value is None or current_value == str(
expected_value
), f'{env_variable_name} set to {current_value}, but expected {expected_value} for attention backend type {attn_type.name}. unset NVTE_FLASH_ATTN, NVTE_FUSED_ATTN and NVTE_UNFUSED_ATTN. Use the --attention-backend argument if you want to choose between (flash/fused/unfused/auto/local). Default is auto.'
os.environ[env_variable_name] = str(expected_value)
if self.config.attention_backend == AttnBackend.local:
check_and_set_env_variable("NVTE_FLASH_ATTN", 0, AttnBackend.flash)
check_and_set_env_variable("NVTE_FUSED_ATTN", 0, AttnBackend.flash)
check_and_set_env_variable("NVTE_UNFUSED_ATTN", 0, AttnBackend.flash)
elif self.config.attention_backend == AttnBackend.flash:
check_and_set_env_variable("NVTE_FLASH_ATTN", 1, AttnBackend.flash)
check_and_set_env_variable("NVTE_FUSED_ATTN", 0, AttnBackend.flash)
check_and_set_env_variable("NVTE_UNFUSED_ATTN", 0, AttnBackend.flash)
elif self.config.attention_backend == AttnBackend.fused:
check_and_set_env_variable("NVTE_FLASH_ATTN", 0, AttnBackend.fused)
check_and_set_env_variable("NVTE_FUSED_ATTN", 1, AttnBackend.fused)
check_and_set_env_variable("NVTE_UNFUSED_ATTN", 0, AttnBackend.fused)
elif self.config.attention_backend == AttnBackend.unfused:
check_and_set_env_variable("NVTE_FLASH_ATTN", 0, AttnBackend.unfused)
check_and_set_env_variable("NVTE_FUSED_ATTN", 0, AttnBackend.unfused)
check_and_set_env_variable("NVTE_UNFUSED_ATTN", 1, AttnBackend.unfused)
elif self.config.attention_backend == AttnBackend.auto:
check_and_set_env_variable("NVTE_FLASH_ATTN", 1, AttnBackend.auto)
check_and_set_env_variable("NVTE_FUSED_ATTN", 1, AttnBackend.auto)
check_and_set_env_variable("NVTE_UNFUSED_ATTN", 1, AttnBackend.auto)
def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
"""Computes the language model loss (Cross entropy across vocabulary)
Args:
labels (Tensor): The labels of dimension [batch size, seq length]
logits (Tensor): The final logits returned by the output layer of the transformer model
Returns:
Tensor: Loss tensor of dimensions [batch size, sequence_length]
"""
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
if self.config.cross_entropy_loss_fusion:
loss = fused_vocab_parallel_cross_entropy(logits, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
def setup_embeddings_and_output_layer(self) -> None:
"""Sets up embedding layer in first stage and output layer in last stage.
This function initializes word embeddings in the final stage when we are
using pipeline parallelism and sharing word embeddings, and sets up param
attributes on the embedding and output layers.
"""
# Set `is_embedding_or_output_parameter` attribute.
if self.pre_process:
self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
if self.post_process and self.output_layer.weight is not None:
self.output_layer.weight.is_embedding_or_output_parameter = True
if not self.share_embeddings_and_output_weights:
return
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self.shared_embedding_or_output_weight().zero_out_wgrad = True
return
if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
self.shared_embedding_or_output_weight().shared_embedding = True
if self.post_process and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.output_layer.weight.data.fill_(0)
self.output_layer.weight.shared = True
self.output_layer.weight.shared_embedding = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, all-reduce the grads of the two word_embeddings
# layers before the optimizer step to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_output_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule.embedding_warning_printed = True
def shared_embedding_or_output_weight(self) -> Tensor:
"""Gets the emedding weight or output logit weights when share embedding and output weights set to True.
Returns:
Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight
"""
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.output_layer.weight
return None
def sharded_state_dict(
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
"""Sharded state dict implementation that handles the output layer weights tying.
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the LanguageModel
"""
assert not sharded_offsets, "Unexpected sharded offsets"
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
output_layer_weight_key = f'{prefix}output_layer.weight'
output_layer_bias_key = f'{prefix}output_layer.bias'
if self.share_embeddings_and_output_weights:
self.tie_embeddings_and_output_weights_state_dict(
sharded_state_dict, output_layer_weight_key, first_stage_word_emb_key
)
elif self.post_process:
# Make sure the output layer follows the embeddings padding logic
sharded_state_dict[output_layer_weight_key].allow_shape_mismatch = True
# Regardless of sharing the output weights with embeddings, we must handle the bias padding
if self.post_process and output_layer_bias_key in sharded_state_dict:
sharded_state_dict[output_layer_bias_key].allow_shape_mismatch = True
return sharded_state_dict
def tie_embeddings_and_output_weights_state_dict(
self,
sharded_state_dict: ShardedStateDict,
output_layer_weight_key: str,
first_stage_word_emb_key: str,
) -> None:
"""Ties the embedding and output weights in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie
output_layer_weight_key (str): key of the output layer weight in the state dict.
This entry will be replaced with a tied version
first_stage_word_emb_key (str): this must be the same as the
ShardedTensor.key of the first stage word embeddings.
Returns: None, acts in-place
"""
if not self.post_process:
# No output layer
assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
return
if self.pre_process:
# Output layer is equivalent to the embedding already
return
# Replace the default output layer with a one sharing the weights with the embedding
del sharded_state_dict[output_layer_weight_key]
tensor = self.shared_embedding_or_output_weight()
last_stage_word_emb_replica_id = (
1, # copy of first stage embedding
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
tensor=tensor,
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
)
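# Illustrative note on the tying above (comment only): on the last pipeline stage with
# share_embeddings_and_output_weights=True, '<prefix>output_layer.weight' is re-registered
# as a ShardedTensor whose key is the first stage's '<prefix>embedding.word_embeddings.weight',
# with replica_id marking it as a copy, so both pipeline stages map to a single tensor
# in the checkpoint.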
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron Vision Module."""
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
# Note: This is only a stub at the moment. This will be expanded in follow-up changes.
class VisionModule(MegatronModule):
"""Base vision module that has common helper functions used across CLIP, ViT, etc.
Args:
config (TransformerConfig): Input transformer config for the model
"""
def __init__(self, config: TransformerConfig) -> None:
super().__init__(config=config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .gpt_model import GPTModel
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import warnings
from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.utils import is_te_min_version
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_gpt_layer_with_transformer_engine_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with TE modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = _get_mlp_module_spec(
use_te=True,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
if multi_latent_attention:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=TEColumnParallelLinear,
linear_q_down_proj=TEColumnParallelLinear,
linear_q_up_proj=TEColumnParallelLinear,
linear_kv_down_proj=TEColumnParallelLinear,
linear_kv_up_proj=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
kv_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
else:
# TENorm significantly harms convergence when used
# for QKLayerNorm if TE Version < 1.9;
# we instead use the Apex implementation.
qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def get_gpt_layer_local_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec for an implementation using only modules in Megatron-Core.
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with Megatron-Core modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = _get_mlp_module_spec(
use_te=False,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
if multi_latent_attention:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=ColumnParallelLinear,
linear_q_down_proj=ColumnParallelLinear,
linear_q_up_proj=ColumnParallelLinear,
linear_kv_down_proj=ColumnParallelLinear,
linear_kv_up_proj=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=LNImpl if qk_layernorm else IdentityOp,
kv_layernorm=LNImpl if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
else:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=LNImpl if qk_layernorm else IdentityOp,
k_layernorm=LNImpl if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def _get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MLP/MoE"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
else:
# Mixture of experts with modules in megatron core.
return get_moe_module_spec(
use_te=use_te,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
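# Illustrative (hypothetical REPL session, comment only): building a spec allocates no
# parameters, so it can be inspected without initializing model parallelism:
#   >>> spec = get_gpt_layer_local_spec(num_experts=None, qk_layernorm=True)
#   >>> spec.module.__name__
#   'TransformerLayer'
#   >>> spec.submodules.self_attention.submodules.linear_qkv.__name__
#   'ColumnParallelLinear'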
def get_gpt_decoder_block_spec(
config: TransformerConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""GPT block spec."""
if use_transformer_engine:
layer_norm_impl = TENorm
else:
layer_norm_impl = LNImpl
# Layer specs.
dense_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
moe_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
# Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
# 0 stands for dense layers, 1 stands for expert layers.
# For an integer N: creates a pattern with an expert layer every N layers.
# For a list: the list is used directly as the per-layer pattern
# (e.g. [1, 0, 1] for alternating expert/dense layers).
if isinstance(config.moe_layer_freq, int):
moe_layer_pattern = [
1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers)
]
elif isinstance(config.moe_layer_freq, list):
moe_layer_pattern = config.moe_layer_freq
assert len(moe_layer_pattern) == config.num_layers, (
f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
f"expected {config.num_layers}, "
f"current moe layer pattern: {config.moe_layer_freq}"
)
else:
raise ValueError(
f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
)
# Create the layer specs for the model.
layer_specs = []
for layer_number in range(config.num_layers):
if moe_layer_pattern[layer_number] == 1:
layer_specs.append(moe_layer_spec)
elif moe_layer_pattern[layer_number] == 0:
layer_specs.append(dense_layer_spec)
else:
raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")
# Slice the layer specs to only include the layers that are built in this pipeline stage.
# Note: MCore layer_number starts at 1
offset = TransformerLayer._get_layer_offset(config)
num_layers_to_build = get_num_layers_to_build(config)
layer_specs = layer_specs[offset : offset + num_layers_to_build]
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl)
return block_spec
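# Minimal sketch of the moe_layer_freq pattern logic in get_gpt_decoder_block_spec
# (illustrative, hypothetical values): an integer N marks every N-th layer, starting
# from layer 0, as an MoE layer, while a list is used verbatim as the pattern.
if __name__ == "__main__":
    _num_layers, _moe_layer_freq = 8, 2
    _pattern = [1 if (i % _moe_layer_freq == 0) else 0 for i in range(_num_layers)]
    assert _pattern == [1, 0, 1, 0, 1, 0, 1, 0]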
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import OrderedDict
from typing import Dict, Literal, Optional
from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
class GPTModel(LanguageModule):
"""GPT Transformer language model.
Args:
config (TransformerConfig):
Transformer config
transformer_layer_spec (ModuleSpec):
Specifies module to use for transformer layers
vocab_size (int):
Vocabulary size
max_sequence_length (int):
maximum size of sequence. This is used for positional embedding
pre_process (bool, optional):
Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional):
Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional):
Defaults to False.
parallel_output (bool, optional):
Do not gather the outputs, keep them split across tensor
parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional):
When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal['learned_absolute', 'rope', 'none'], optional):
Position embedding type. Defaults to 'learned_absolute'.
rotary_percent (float, optional):
Percent of rotary dimension to use for rotary position embeddings.
Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional):
Base period for rotary position embeddings. Ignored unless
position_embedding_type is 'rope'.
Defaults to 10000.
scatter_embedding_sequence_parallel (bool, optional):
Whether embeddings should be scattered across sequence parallel
region or not. Defaults to True.
seq_len_interpolation_factor (Optional[float], optional):
scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super().__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
if has_config_logger_enabled(self.config):
log_config_to_disk(
self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cos_sin(
inference_params.max_sequence_length
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
"""Sharded state dict implementation for GPTModel backward-compatibility
(removing extra state).
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the GPTModel
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
output_layer_extra_state_key = f'{prefix}output_layer._extra_state'
# Old GPT checkpoints only stored the output layer weight key. So we remove the
# _extra_state key but check that it doesn't contain any data anyway
output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
assert not (
output_extra_state and output_extra_state.data
), f'Expected output layer extra state to be empty, got: {output_extra_state}'
return sharded_state_dict
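# Illustrative construction outline (comment only, not executable as-is): it assumes an
# initialized torch.distributed process group plus megatron.core.parallel_state, and the
# hyperparameters below are hypothetical.
#   config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=4)
#   spec = get_gpt_layer_local_spec()   # from megatron.core.models.gpt.gpt_layer_specs
#   model = GPTModel(config=config, transformer_layer_spec=spec, vocab_size=32000,
#                    max_sequence_length=2048, position_embedding_type='rope')
#   # forward() takes input_ids / position_ids of shape [b, s] plus an attention mask and
#   # returns logits of shape [b, s, vocab] (vocab possibly split across tensor-parallel
#   # ranks), or a [b, s] loss tensor when labels are provided.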