Commit 160bf237 authored by wangxj

Update 0.12

parent b01809dd
Pipeline #2448 failed with stages
File mode changed from 100755 to 100644
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import logging
import math
from typing import Callable
import torch
from torch import Tensor, nn
from megatron.core.inference_params import InferenceParams
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
logger = logging.getLogger(__name__)
__all__ = ['RelativePositionEmbedding']
class RelativePositionEmbedding(nn.Module):
"""Relative Position Embedding for language model.
Args:
"""
def __init__(
self,
bidirectional: bool,
init_method: Callable,
num_attention_heads: int,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
) -> None:
super().__init__()
self.bidirectional = bidirectional
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.relative_attention_bias = torch.nn.Embedding(
self.relative_attention_num_buckets, num_attention_heads
)
init_method(self.relative_attention_bias.weight)
def _relative_position_bucket(
self, relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""
Adapted from HuggingFace T5 Model:
https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
src/transformers/models/t5/modeling_t5.py#L397
Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e. the
distance in tokens from the attending position to the attended-to position.
If bidirectional=False, then positive relative positions are invalid. We use
smaller buckets for small absolute relative_position and larger buckets for
larger absolute relative_positions. All relative positions >=max_distance map
to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the
model has been trained on.
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position,
containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger
# bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
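    # Example (illustrative): with bidirectional=True, num_buckets=32, max_distance=128,
    # each direction gets 16 buckets; relative positions [-3, 0, 2, 200] map to buckets
    # [3, 0, 18, 31]: small offsets get exact buckets, while offsets at or beyond
    # max_distance all share the last bucket of their direction.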
def _compute_bias(self, query_length, key_length):
"""
Adapted from HuggingFace T5 Model
https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
src/transformers/models/t5/modeling_t5.py#L444C9-L444C21
Compute binned relative position bias
Args:
query_length (int): The length of the query sequence
(e.g., the input sequence in attention).
key_length (int): The length of the key sequence
(e.g., the sequence to compare against in attention).
Returns:
torch.Tensor: A tensor representing the relative position bias, with shape
(1, num_heads, query_length, key_length).
"""
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape(query_length,key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=self.bidirectional,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(
relative_position_bucket
) # shape(query_length,key_length,num_heads)
values = values.permute([2, 0, 1]).unsqueeze(
0
) # shape(1, num_heads,query_length,key_length)
return values
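    # Note: the returned bias broadcasts over the batch dimension and is intended to be
    # added to the raw (pre-softmax) attention scores, in the same way as in T5.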
@staticmethod
def get_relative_seq_len(
inference_params: InferenceParams,
transformer: TransformerBlock,
transformer_input: Tensor,
transformer_config: TransformerConfig,
) -> float:
"""Function to get the rotary sequence length.
Args:
inference_params : Used during Inference time
transformer (TransformerBlock): The transformer block (decoder/encoder) used
by the model
transformer_input (Tensor): Input tensor to the transformer
transformer_config (TransformerConfig): Transformer config used by the model
Returns:
float: The rotary sequence length
"""
if inference_params is not None:
relative_seq_len = inference_params.max_sequence_length
else:
if transformer.input_tensor is not None:
relative_seq_len = transformer.input_tensor.size(0)
else:
relative_seq_len = transformer_input.size(0)
if transformer_config.sequence_parallel:
relative_seq_len *= transformer_config.tensor_model_parallel_size
return relative_seq_len
def forward(self, query_seq_length, key_seq_length):
"""
Args:
Returns:
"""
return self._compute_bias(query_seq_length, key_seq_length)
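# Minimal usage sketch (illustrative; the values below are assumptions):
#   rpe = RelativePositionEmbedding(
#       bidirectional=False,
#       init_method=torch.nn.init.normal_,
#       num_attention_heads=16,
#   )
#   bias = rpe(query_seq_length=1024, key_seq_length=1024)
#   # bias.shape == (1, 16, 1024, 1024); add to attention scores before softmax.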
......@@ -17,23 +17,24 @@ from megatron.core.utils import is_te_min_version
logger = logging.getLogger(__name__)
# Prefer fused RoPE from Apex as we need the `transpose_output_memory` argument for the bshd trick.
# See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2469.
try:
    from apex.transformer.functional import fused_apply_rotary_pos_emb
except ImportError:
    try:
        from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb
    except:
        fused_apply_rotary_pos_emb = None

try:
    from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb_thd
except ImportError:
    try:
        from apex.transformer.functional import fused_apply_rotary_pos_emb_thd
    except ImportError:
        fused_apply_rotary_pos_emb_thd = None
try:
......@@ -188,8 +189,10 @@ def apply_rotary_pos_emb(
if config.apply_rope_fusion:
if cu_seqlens is None:
assert fused_apply_rotary_pos_emb is not None, "apply_rope_fusion is not available."
return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
else:
assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
if not is_te_min_version("1.11.0", check_equality=False):
......
......@@ -46,7 +46,8 @@ class RotaryEmbedding(nn.Module):
for longer sequences. The value must be a float larger than 1.0. Defaults to None
rotary_base (int, optional): Base period for rotary position embeddings. Defaults to
10000.
rope_scaling (bool, optional): Apply rope scaling as used in llama 3.x.
rope_scaling_factor (float, optional): rope scaling factor in llama 3.x. Defaults to 8.
use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly
on the GPU. Defaults to False
"""
......@@ -59,6 +60,7 @@ class RotaryEmbedding(nn.Module):
seq_len_interpolation_factor: float = None,
rotary_base: int = 10000,
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
use_cpu_initialization: bool = False,
) -> None:
super().__init__()
......@@ -75,7 +77,7 @@ class RotaryEmbedding(nn.Module):
)
if rope_scaling:
self.inv_freq = self._apply_scaling(self.inv_freq, factor=rope_scaling_factor)
def _apply_scaling(
self,
......@@ -200,7 +202,7 @@ class RotaryEmbedding(nn.Module):
elif inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
if transformer is not None and transformer.input_tensor is not None:
rotary_seq_len = transformer.input_tensor.size(0)
else:
rotary_seq_len = transformer_input.size(0)
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import warnings
from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
MLASelfAttentionSubmodules,
......@@ -21,17 +21,19 @@ from megatron.core.transformer.transformer_block import (
get_num_layers_to_build,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
get_transformer_layer_offset,
)
from megatron.core.utils import is_te_min_version
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelGroupedLinear,
TERowParallelLinear,
)
......@@ -47,8 +49,6 @@ try:
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
......@@ -60,7 +60,8 @@ def get_gpt_layer_with_transformer_engine_spec(
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
......@@ -69,13 +70,24 @@ def get_gpt_layer_with_transformer_engine_spec(
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with TE modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = get_mlp_module_spec(
use_te=True,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
if multi_latent_attention:
......@@ -89,13 +101,21 @@ def get_gpt_layer_with_transformer_engine_spec(
submodules=MLASelfAttentionSubmodules(
linear_q_proj=TEColumnParallelLinear,
linear_q_down_proj=TEColumnParallelLinear,
linear_q_up_proj=(
TELayerNormColumnParallelLinear
if qk_layernorm
else TEColumnParallelLinear
),
linear_kv_down_proj=TEColumnParallelLinear,
linear_kv_up_proj=(
TELayerNormColumnParallelLinear
if qk_layernorm
else TEColumnParallelLinear
),
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
kv_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
......@@ -138,6 +158,8 @@ def get_gpt_layer_local_spec(
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec for an implementation using only modules in Megatron-Core.
......@@ -146,13 +168,24 @@ def get_gpt_layer_local_spec(
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with Megatron-Core modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = get_mlp_module_spec(
use_te=False,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
if multi_latent_attention:
......@@ -213,63 +246,54 @@ def _get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
):
    warnings.warn(
        """This private function is on a deprecation track. Please switch to `get_mlp_module_spec`
        since it will be removed in a future release."""
    )
    return get_mlp_module_spec(
        use_te=use_te,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        fp8=fp8,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )


def get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Helper function to get module spec for MLP/MoE"""
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )
    if num_experts is None:
        # Dense MLP w/ or w/o TE modules.
        return ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
                linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
            ),
        )
    else:
        # Mixture of experts with modules in megatron core.
        return get_moe_module_spec(
            use_te=use_te,
            num_experts=num_experts,
            moe_grouped_gemm=moe_grouped_gemm,
            moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        )
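# Example usage (illustrative sketch):
#   dense_spec = get_mlp_module_spec(use_te=True)  # plain MLP spec
#   moe_spec = get_mlp_module_spec(use_te=True, num_experts=8, moe_grouped_gemm=True)
#   # the MoE case delegates to get_moe_module_spec.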
def get_gpt_decoder_block_spec(
......@@ -288,7 +312,7 @@ def get_gpt_decoder_block_spec(
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
......@@ -296,6 +320,7 @@ def get_gpt_decoder_block_spec(
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
moe_layer_spec = (
......@@ -304,7 +329,7 @@ def get_gpt_decoder_block_spec(
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
......@@ -312,6 +337,7 @@ def get_gpt_decoder_block_spec(
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
......@@ -347,7 +373,7 @@ def get_gpt_decoder_block_spec(
# Slice the layer specs to only include the layers that are built in this pipeline stage.
# Note: MCore layer_number starts at 1
offset = get_transformer_layer_offset(config)
num_layers_to_build = get_num_layers_to_build(config)
layer_specs = layer_specs[offset : offset + num_layers_to_build]
......
......@@ -3,6 +3,7 @@
from collections import OrderedDict
from typing import Dict, Literal, Optional
import torch
from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
......@@ -50,6 +51,8 @@ class GPTModel(LanguageModule):
Base period for rotary position embeddings. Ignored unless
position_embedding_type is 'rope'.
Defaults to 10000.
rope_scaling (bool, optional): Toggle RoPE scaling.
rope_scaling_factor (float): RoPE scaling factor. Default 8.
scatter_embedding_sequence_parallel (bool, optional):
Whether embeddings should be scattered across sequence parallel
region or not. Defaults to True.
......@@ -73,6 +76,7 @@ class GPTModel(LanguageModule):
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
......@@ -118,9 +122,13 @@ class GPTModel(LanguageModule):
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
rope_scaling_factor=rope_scaling_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Cache for RoPE tensors which do not change between iterations.
self.rotary_pos_emb_cache = {}
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
......@@ -224,10 +232,11 @@ class GPTModel(LanguageModule):
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_params:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
......@@ -238,6 +247,18 @@ class GPTModel(LanguageModule):
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None
# Run decoder.
hidden_states = self.decoder(
......@@ -248,6 +269,7 @@ class GPTModel(LanguageModule):
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import warnings
from typing import Optional
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.utils import get_te_version, is_te_min_version
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TEColumnParallelLinear,
TERowParallelGroupedLinear,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
def get_moe_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MoE"""
assert num_experts is not None
mlp = MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
)
# experts spec
if moe_grouped_gemm:
## use GroupedMLP
if use_te and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm:
## use TEGroupedLinear
expert_module = TEGroupedMLP
expert_submodule = MLPSubmodules(
linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear
)
else:
## use legacy GroupedMLP
expert_module = GroupedMLP
expert_submodule = None
warnings.warn(
'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. '
'Please update the TransformerEngine to version>=1.7.0 and use TEGroupedMLP.'
)
else:
## use SequentialMLP
expert_module = SequentialMLP
if use_te and not is_te_min_version("1.7.0.dev0"):
warnings.warn(
"Only transformer-engine>=1.7.0 supports MoE experts, "
f"but your version is {get_te_version()}. Use local linear implementation instead."
)
expert_submodule = MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
)
else:
expert_submodule = mlp
experts = ModuleSpec(module=expert_module, submodules=expert_submodule)
# shared experts spec
shared_experts = ModuleSpec(module=SharedExpertMLP, params={"gate": False}, submodules=mlp)
# MoE module spec
moe_module_spec = ModuleSpec(
module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts)
)
return moe_module_spec
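# Example usage (illustrative sketch):
#   spec = get_moe_module_spec(use_te=True, num_experts=8, moe_grouped_gemm=True)
#   # -> MoELayer spec using TEGroupedMLP experts when TE grouped GEMM is available,
#   #    falling back to the legacy GroupedMLP otherwise, plus a SharedExpertMLP spec.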
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from .module import HuggingFaceModule, build_hf_model
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from transformers import AutoModel
from megatron.core.models.huggingface import HuggingFaceModule
class ClipHuggingFaceModel(HuggingFaceModule):
"""
Wrapper for CLIP HuggingFace models
"""
def __init__(self, config):
super().__init__(config)
self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path)
def forward(self, *args, **kwargs):
"""Forward function"""
x = self.model(*args, **kwargs)
x = x['last_hidden_state']
return x
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from transformers import AutoConfig, AutoModel
from megatron.core.transformer.module import MegatronModule
class HuggingFaceModule(MegatronModule):
"""
Basic module for huggingface
"""
def __init__(self, config):
super().__init__(config=config)
def set_input_tensor(self, input_tensor):
"""Dummy function for set_input_tensor"""
self.input_tensor = input_tensor
class AutoHuggingFaceModel(HuggingFaceModule):
"""
Wrapper for HuggingFace AutoModel
"""
def __init__(self, config):
super().__init__(config)
self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path)
def forward(self, *args, **kwargs):
"""Forward function"""
return self.model(*args, **kwargs)
def build_hf_model(config):
"""Builds huggingface wrapper model given config"""
hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path)
if "qwen" in hf_config.model_type:
from megatron.core.models.huggingface.qwen_model import QwenHuggingFaceModel
model = QwenHuggingFaceModel(config)
elif "vit" in hf_config.model_type:
from megatron.core.models.huggingface.clip_model import ClipHuggingFaceModel
model = ClipHuggingFaceModel(config)
else:
raise NotImplementedError(f"Huggingface model type {hf_config.model_type} is not supported")
return model
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from transformers.models.qwen2 import Qwen2ForCausalLM
from megatron.core.models.huggingface import HuggingFaceModule
class QwenHuggingFaceModel(HuggingFaceModule):
"""
Wrapper for Qwen LM HuggingFace models
"""
def __init__(self, config):
super().__init__(config)
self.model = Qwen2ForCausalLM.from_pretrained(config.huggingface_model_name_or_path)
def forward(self, *args, **kwargs):
"""Forward function"""
combined_embeddings = kwargs['decoder_input'].permute(1, 0, 2)
x = self.model(
position_ids=None, # TODO: I guess we're just assuming no custom pos ids
attention_mask=kwargs['attention_mask'],
inputs_embeds=combined_embeddings,
labels=kwargs['labels'],
)
if kwargs['labels'] is not None:
x = x["loss"]
else:
x = x["logits"]
return x
def embedding(self, input_ids, position_ids=None):
"""Function to run process tokens with input embeddings"""
return self.model.get_input_embeddings()(input_ids).transpose(1, 0).contiguous()