Commit 7c19b3a8 authored by wangsen's avatar wangsen
Browse files

Initial commit

parents
Pipeline #1721 failed with stages
in 0 seconds
import logging
from typing import Optional, Tuple
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
class LanguageModule(MegatronModule):
"""Base language module that has common helper functions used across GPT, BERT etc.
Args:
config (TransformerConfig): Input transformer config for the model
"""
def __init__(self, config: TransformerConfig) -> None:
super().__init__(config=config)
def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
"""Computes the language model loss (Cross entropy across vocabulary)
Args:
labels (Tensor): The labels of dimension [batch size, seq length]
logits (Tensor): The final logits returned by the output layer of the transformer model
Returns:
Tensor: Loss tensor of dimensions [batch size, sequence_length]
"""
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
if self.config.cross_entropy_loss_fusion:
loss = fused_vocab_parallel_cross_entropy(logits, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
def setup_embeddings_and_output_layer(self) -> None:
"""Sets up embedding layer in first stage and output layer in last stage.
This function initalizes word embeddings in the final stage when we are
using pipeline parallelism and sharing word embeddings, and sets up param
attributes on the embedding and output layers.
"""
# Set `is_embedding_or_output_parameter` attribute.
if self.pre_process:
self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
if self.post_process and self.output_layer.weight is not None:
self.output_layer.weight.is_embedding_or_output_parameter = True
if not self.share_embeddings_and_output_weights:
return
if self.pre_process and self.post_process:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self.shared_embedding_or_output_weight().zero_out_wgrad = True
return
if self.pre_process and not self.post_process:
assert parallel_state.is_pipeline_first_stage()
self.shared_embedding_or_output_weight().shared_embedding = True
if self.post_process and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.output_layer.weight.data.fill_(0)
self.output_layer.weight.shared = True
self.output_layer.weight.shared_embedding = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_output_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule.embedding_warning_printed = True
def shared_embedding_or_output_weight(self) -> Tensor:
"""Gets the emedding weight or output logit weights when share embedding and output weights set to True.
Returns:
Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight
"""
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.output_layer.weight
return None
def sharded_state_dict(
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
""" Sharded state dict implementation that handles the output layer weights tying.
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the LanguageModel
"""
assert not sharded_offsets, "Unexpected sharded offsets"
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
output_layer_weight_key = f'{prefix}output_layer.weight'
output_layer_bias_key = f'{prefix}output_layer.bias'
if self.share_embeddings_and_output_weights:
self.tie_embeddings_and_output_weights_state_dict(
sharded_state_dict, output_layer_weight_key, first_stage_word_emb_key
)
elif self.post_process:
# Make sure the output layer follows the embeddings padding logic
sharded_state_dict[output_layer_weight_key].allow_shape_mismatch = True
# Regardless of sharing the output weights with embeddings, we must handle the bias padding
if self.post_process and output_layer_bias_key in sharded_state_dict:
sharded_state_dict[output_layer_bias_key].allow_shape_mismatch = True
return sharded_state_dict
def tie_embeddings_and_output_weights_state_dict(
self,
sharded_state_dict: ShardedStateDict,
output_layer_weight_key: str,
first_stage_word_emb_key: str,
) -> None:
"""Ties the embedding and output weights in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie
output_layer_weight_key (str): key of the output layer weight in the state dict.
This entry will be replaced with a tied version
first_stage_word_emb_key (str): this must be the same as the
ShardedTensor.key of the first stage word embeddings.
Returns: None, acts in-place
"""
if not self.post_process:
# No output layer
assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
return
if self.pre_process:
# Output layer is equivalent to the embedding already
return
# Replace the default output layer with a one sharing the weights with the embedding
del sharded_state_dict[output_layer_weight_key]
tensor = self.shared_embedding_or_output_weight()
last_stage_word_emb_replica_id = (
1, # copy of first stage embedding
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
tensor=tensor,
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron Vision Module."""
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
# Note: This is only a stub at the moment. This will be expanded in follow-up changes.
class VisionModule(MegatronModule):
"""Base vision module that has common helper functions used across CLIP, ViT, etc.
Args:
config (TransformerConfig): Input transformer config for the model
"""
def __init__(self, config: TransformerConfig) -> None:
super().__init__(config=config)
from .gpt_model import GPTModel
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
try:
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
except ImportError:
TEDotProductAttention = None
TELayerNormColumnParallelLinear = None
TENorm = None
TERowParallelLinear = None
#print("Do not support transformer_engine")
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_gpt_layer_with_transformer_engine_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
mlp = _get_mlp_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
k_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
# Use this spec for an implementation using only modules in megatron core
def get_gpt_layer_local_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
mlp = _get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=FusedLayerNorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
k_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=FusedLayerNorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
# Helper function to get module spec for MLP/MoE
def _get_mlp_module_spec(
use_te: bool = True, num_experts: int = None, moe_grouped_gemm: bool = False
) -> ModuleSpec:
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
else:
# Mixture of experts with modules in megatron core.
return ModuleSpec(
module=MoELayer,
submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,)
if not moe_grouped_gemm
else None,
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import logging
from typing import Dict, Literal, Optional, Tuple, Union
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import AttnMaskType, ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
class GPTModel(LanguageModule):
"""GPT Transformer language model.
Args:
config (TransformerConfig): Transformer config
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
vocab_size (int): Vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional): Defaults to False.
parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional): When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal[learned_absolute,rope], optional): Position embedding type.. Defaults to 'learned_absolute'.
rotary_percent (float, optional): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 10000.
seq_len_interpolation_factor (Optional[float], optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super().__init__(config=config)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 2 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
)
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs stored
# in gradient buffer to calculate the weight gradients for the embedding final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
""" Sharded state dict implementation for GPTModel backward-compatibility (removing extra state).
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the GPTModel
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
output_layer_extra_state_key = f'{prefix}output_layer._extra_state'
# Old GPT checkpoints only stored the output layer weight key. So we remove the _extra_state key
# but check that it doesn't contain any data anyway
output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
assert not (
output_extra_state and output_extra_state.data
), f'Expected output layer extra state to be empty, got: {output_extra_state}'
return sharded_state_dict
from .mamba_model import MambaModel
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import Mamba
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
mamba_stack_spec = ModuleSpec(
module=MambaStack,
submodules=MambaStackSubmodules(
mamba_layer=ModuleSpec(
module=MambaLayer, submodules=MambaLayerSubmodules(norm=TENorm, mixer=Mamba,),
),
# Started with spec from gpt_layer_specs.py (with MLP removed)
# Using the TE spec because we had problems getting the non-TE spec
# working
attention_layer=ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
),
),
# Started with spec from gpt_layer_specs.py
# Using the TE spec because we had problems getting the non-TE spec
# working
mlp_layer=ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
),
),
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Literal, Optional
from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
class MambaModel(LanguageModule):
"""Mamba language model.
Args:
config (TransformerConfig): Transformer config
mamba_stack_spec (ModuleSpec): Specifies the modules to use for the various layer types
vocab_size (int): Vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True.
hybrid_attention_ratio (float, optional): The target ratio of attention layers to total layers
hybrid_mlp_ratio (float, optional): The target ratio of mlp layers to total layers
hybrid_override_pattern (str, optional): The hybrid layer pattern to override with
post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional): Defaults to False.
parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional): When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal[learned_absolute,rope,none], optional): Position embedding type. Defaults to 'none'.
rotary_percent (float, optional): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 10000.
seq_len_interpolation_factor (Optional[float], optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
mamba_stack_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
hybrid_attention_ratio: float = 0.0,
hybrid_mlp_ratio: float = 0.0,
hybrid_override_pattern: str = None,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
# Mamba with no attention has no need for position embeddings, so none is default
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'none',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super().__init__(config=config)
self.mamba_stack_spec: ModuleSpec = mamba_stack_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.hybrid_attention_ratio = hybrid_attention_ratio
self.hybrid_mlp_ratio = hybrid_mlp_ratio
self.hybrid_override_pattern = hybrid_override_pattern
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
)
self.decoder = build_module(
mamba_stack_spec,
self.config,
pre_process=self.pre_process,
hybrid_attention_ratio=self.hybrid_attention_ratio,
hybrid_mlp_ratio=self.hybrid_mlp_ratio,
hybrid_override_pattern=self.hybrid_override_pattern,
post_process=self.post_process,
dtype=config.params_dtype,
)
# Output
if post_process:
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
) -> Tensor:
"""Forward function of the Mamba model. This function passes the input tensors
through the embedding layer, and then the decoder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# The following assert will currently fail when running inference.
# Commented out for now.
# TODO (duncan/rwaleffe): (1) confirm that the externally-generated
# attention mask is not needed and is ignored by the model in
# inference mode, (2) reduce the size of the externally-generated
# attention mask to prevent CPU OOM (as we did for training), (3)
# force the attention mask passed to the model in inference mode to
# be None, so this assert will succeed.
# assert attention_mask is None, "The attention mask is ignored and should be set to None"
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from collections import namedtuple
from functools import partial
from typing import List
import torch
from megatron.core import InferenceParams, parallel_state
from megatron.core.models.gpt import GPTModel
from megatron.core.models.vision.clip_vit_model import CLIPViTModel
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
# Note: This is under development and may be missing features.
class LLaVAModel(MegatronModule):
"""LLaVA multi-modal model.
Args:
language_transformer_config (TransformerConfig): Transformer config for the language model.
language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the language model.
language_vocab_size (int): Language model vocabulary size.
language_max_sequence_length (int): Language model maximum sequence length. This is used for positional embedding.
vision_transformer_config (TransformerConfig): Transformer config for the vision model.
vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the vision model.
drop_vision_class_token (bool): Drop vision class token(s) before input to the language model.
vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to language model inputs.
vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision projection.
vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP.
allow_missing_vision_projection_checkpoint (bool): Allow vision projection weights to be missing when loading a checkpoint. Default False.
parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. This is typically True for training and False for inference.
language_position_embedding_type (str): Position embedding type to use in the language model. Default learned absolute.
language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings in the language model. Defaults to 1.0.
"""
def __init__(
self,
language_transformer_config: TransformerConfig,
language_transformer_layer_spec: ModuleSpec,
language_vocab_size: int,
language_max_sequence_length: int,
vision_transformer_config: TransformerConfig,
vision_transformer_layer_spec: ModuleSpec,
drop_vision_class_token: bool,
vision_projection_config: TransformerConfig,
vision_projection_layer_spec: ModuleSpec,
vision_projection_type: str = "mlp",
allow_missing_vision_projection_checkpoint: bool = False,
parallel_output: bool = True,
language_position_embedding_type: str = 'learned_absolute',
language_rotary_percent: float = 1.0,
) -> None:
super().__init__(config=language_transformer_config)
logging.getLogger(__name__).warning(
"LLaVA model is under development and may be missing features."
)
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
raise NotImplementedError("pipeline parallelism is not supported in this model yet.")
self.language_model = GPTModel(
config=language_transformer_config,
transformer_layer_spec=language_transformer_layer_spec,
vocab_size=language_vocab_size,
max_sequence_length=language_max_sequence_length,
parallel_output=parallel_output,
position_embedding_type=language_position_embedding_type,
rotary_percent=language_rotary_percent,
)
self.vision_model = CLIPViTModel(vision_transformer_config, vision_transformer_layer_spec)
self._drop_vision_class_token = drop_vision_class_token
# Map (intermediate) vision model outputs to the language model input dimension.
self.vision_projection = MultimodalProjector(
vision_projection_config,
vision_projection_layer_spec,
vision_projection_type,
vision_transformer_config.hidden_size, # input size to the projection.
)
# This allows ignoring missing weights for the vision projection during checkpoint loading.
# This should be disabled by default but can be enabled if your checkpoint contains pretrained
# vision and language models but not the projection from vision model outputs to language model inputs.
if allow_missing_vision_projection_checkpoint:
vision_projection_param_names = [
f"vision_projection.{name}" for name in self.vision_projection.state_dict().keys()
]
self.vision_projection.register_load_state_dict_post_hook(
partial(_load_state_dict_hook_ignore_param_names, vision_projection_param_names)
)
def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
"""Sets input tensor to the model.
NOTE: Pipeline parallelism is not supported in this model yet. This is just a placeholder implementation.
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
self.vision_model.set_input_tensor(input_tensor)
def freeze(
self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool
):
"""Freeze model modules.
Make specific modules non-trainable by setting requires_grad to False for the module's parameters.
Args:
freeze_language_model (bool): Freeze the language model module.
freeze_vision_model (bool): Freeze the vision model module.
freeze_vision_projection (bool): Freeze the vision projection module.
"""
modules = []
if freeze_language_model:
modules.append(self.language_model)
if freeze_vision_model:
modules.append(self.vision_model)
if freeze_vision_projection:
modules.append(self.vision_projection)
for module in modules:
for param in module.parameters():
param.requires_grad = False
def forward(
self,
images: torch.Tensor,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor = None,
inference_params: InferenceParams = None,
) -> torch.Tensor:
"""Forward function of the LLaVA model.
Args:
images (torch.Tensor): input image of shape [batch, img_h, img_w].
input_ids (torch.Tensor): input text ids [batch, text_seq_len].
position_ids (torch.Tensor): input text position ids [batch, text_seq_len].
attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, combined_seq_len].
labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
inference_params (InferenceParams): Inference-time parameters including KV cache.
Returns:
output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size].
"""
language_embeddings = self.language_model.embedding(
input_ids=input_ids, position_ids=position_ids
) # [text_seq_len, b, h_language]
# If running inference, we can skip image token computation if they were computed already earlier for this sample.
if (
inference_params is not None
and "image_tokens_count" in inference_params.key_value_memory_dict
):
combined_embeddings = language_embeddings
else:
image_embeddings = self.vision_model(images) # [b, img_seq_len, h_vision]
if self._drop_vision_class_token:
image_embeddings = image_embeddings[:, self.vision_model.class_token_len :, :]
image_embeddings = image_embeddings.permute(1, 0, 2) # [img_seq_len, b, h_vision]
# map vision model output size to language model input size.
image_embeddings = self.vision_projection(
image_embeddings
) # [img_seq_len, b, h_vision]
# If running inference, the language model KV cache will be updated for image token positions.
# Here we store the image tokens sequence length, which can be used as an offset to the KV cache later.
if inference_params is not None:
inference_params.key_value_memory_dict[
"image_tokens_count"
] = image_embeddings.shape[0]
combined_embeddings = torch.cat(
[image_embeddings, language_embeddings], dim=0
) # [combined_seq_len, b, h_language]
# Embedding is computed above so we can discard input and position ids.
input_ids = None
position_ids = None
# Note: This returns loss if labels are provided, otherwise logits.
output = self.language_model(
input_ids,
position_ids,
attention_mask,
decoder_input=combined_embeddings,
labels=labels,
inference_params=inference_params,
)
return output
def _load_state_dict_hook_ignore_param_names(
param_names: List[str], module: torch.nn.Module, incompatible_keys: namedtuple
):
"""Hook to ignore missing keys during checkpoint loading.
By default, this should not be used to avoid accidentally missing weights in checkpoint loading.
Example use case: Use this for the vision projection if you want to load a checkpoint that contains vision and language model weights
but not the vision projection weights.
Args:
param_names (list of str): Parameter names allowed to be missing when calling load_state_dict.
module (torch.nn.Module): The torch module this hook applies to. Unused here but required by the torch API.
incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys, which collect the missing and unexpected
keys when calling load_state_dict on this torch module, respectively.
"""
for param_name in param_names:
incompatible_keys.missing_keys.remove(param_name)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- RetroConfig: configuration dataclass for RetroModel.
- RetroModel: The Retro model.
- get_retro_decoder_block_spec: Get spec for Retro decoder transformer block.
"""
from .config import RetroConfig
from .decoder_spec import get_retro_decoder_block_spec
from .model import RetroModel
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment