Commit 4b097dee authored by liangjing's avatar liangjing
Browse files

update to core_v0.9

parent 3aca1415
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelGroupedLinear,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm
warnings.warn('Apex is not installed. Falling back to Torch LayerNorm')
LNImpl = WrappedTorchLayerNorm
def get_gpt_layer_with_transformer_engine_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
fp8: Optional[str] = None,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Flag to decide the linear layer spec for MoE. Defaults to None.
Returns:
ModuleSpec: Module specification with TE modules
"""
mlp = _get_mlp_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
# TENorm significantly harms convergence when used
# for QKLayerNorm; we instead use the Apex implementation.
q_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
k_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def get_gpt_layer_local_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec for an implementation using only modules in Megatron-Core.
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
Returns:
ModuleSpec: Module specification with Megatron-Core modules
"""
mlp = _get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=LNImpl if qk_layernorm else IdentityOp,
k_layernorm=LNImpl if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def _get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None,
) -> ModuleSpec:
"""Helper function to get module spec for MLP/MoE"""
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
else:
# Mixture of experts with modules in megatron core.
if use_te and moe_grouped_gemm:
linear_fc1 = TEColumnParallelGroupedLinear
linear_fc2 = TERowParallelGroupedLinear
elif use_te and fp8:
linear_fc1 = TEColumnParallelLinear
linear_fc2 = TERowParallelLinear
else:
linear_fc1 = ColumnParallelLinear
linear_fc2 = RowParallelLinear
use_te_grouped_gemm = use_te and TEColumnParallelGroupedLinear is not None
return ModuleSpec(
module=MoELayer,
submodules=(
MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
if not moe_grouped_gemm or use_te_grouped_gemm
else None
),
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import logging
from typing import Literal, Optional
from collections import OrderedDict
from typing import Dict, Literal, Optional
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.gpt.gpt_embedding import GPTEmbedding
from megatron.core.transformer.enums import AttnMaskType, ModelType
from megatron.core.transformer.module import MegatronModule
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
class GPTModel(MegatronModule):
"""Transformer language model.
Arguments:
config (TransformerConfig): transformer config
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool): Include embedding layer (used with pipeline parallelism)
post_process (bool): Include an output layer (used with pipeline parallelism)
parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are
shared. Defaults to False.
position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
Defaults is 'learned_absolute'.
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
seq_len_interpolation_factor (float): scale of linearly interpolating RoPE for longer sequences.
class GPTModel(LanguageModule):
"""GPT Transformer language model.
Args:
config (TransformerConfig):
Transformer config
transformer_layer_spec (ModuleSpec):
Specifies module to use for transformer layers
vocab_size (int):
Vocabulary size
max_sequence_length (int):
maximum size of sequence. This is used for positional embedding
pre_process (bool, optional):
Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional):
Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional):
Defaults to False.
parallel_output (bool, optional):
Do not gather the outputs, keep them split across tensor
parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional):
When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal[learned_absolute,rope], optional):
Position embedding type.. Defaults to 'learned_absolute'.
rotary_percent (float, optional):
Percent of rotary dimension to use for rotary position embeddings.
Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional):
Base period for rotary position embeddings. Ignored unless
position_embedding_type is 'rope'.
Defaults to 10000.
seq_len_interpolation_factor (Optional[float], optional):
scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
......@@ -54,13 +66,17 @@ class GPTModel(MegatronModule):
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
):
super(GPTModel, self).__init__(config=config)
) -> None:
super().__init__(config=config)
self.config: TransformerConfig = config
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
......@@ -74,35 +90,53 @@ class GPTModel(MegatronModule):
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# Embeddings.
# These 2 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
if self.pre_process:
self.embedding = GPTEmbedding(
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
add_position_embedding=(self.position_embedding_type == 'learned_absolute'),
position_embedding_type=position_embedding_type,
)
# Rotary Position Embeddings
if self.position_embedding_type == 'rope':
rotary_dim = self.config.kv_channels
if rotary_percent < 1.0:
rotary_dim = int(rotary_dim * rotary_percent)
self.rotary_pos_emb = RotaryEmbedding(rotary_dim, seq_len_interpolation_factor)
else:
self.rotary_pos_emb = None
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
self_attn_mask_type=AttnMaskType.causal,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
......@@ -113,20 +147,32 @@ class GPTModel(MegatronModule):
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
if has_config_logger_enabled(self.config):
log_config_to_disk(
self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
self.initialize_last_stage_with_word_embeddings()
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt'
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
......@@ -136,8 +182,16 @@ class GPTModel(MegatronModule):
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params=None,
):
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
......@@ -151,21 +205,12 @@ class GPTModel(MegatronModule):
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.rotary_pos_emb is not None:
if inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
if self.decoder.input_tensor is not None:
rotary_seq_len = self.decoder.input_tensor.size(0)
else:
rotary_seq_len = decoder_input.size(0)
# Decoder input is split along sequence dimension, but RoPE is applied in tensor parallel region
if self.config.sequence_parallel:
rotary_seq_len *= self.config.tensor_model_parallel_size
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run decoder.
......@@ -174,6 +219,8 @@ class GPTModel(MegatronModule):
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
if not self.post_process:
......@@ -185,124 +232,48 @@ class GPTModel(MegatronModule):
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
loss = tensor_parallel.vocab_parallel_cross_entropy(logits.float(), labels)
loss = self.compute_language_model_loss(labels, logits)
# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
def shared_embedding_or_output_weight(self):
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.output_layer.weight
return None
def initialize_last_stage_with_word_embeddings(self):
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism and sharing word
# embeddings. Nothing to do if we aren't sharing weights or aren't using
# pipeline parallelism.
if not self.share_embeddings_and_output_weights or (self.pre_process and self.post_process):
return
if self.post_process and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.output_layer.weight.data.fill_(0)
self.output_layer.weight.shared = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_output_weight()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(GPTModel, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
GPTModel.embedding_warning_printed = True
def sharded_state_dict(self, prefix=''):
sharded_state_dict = {}
if self.pre_process:
embedding_prefix = f'{prefix}embedding.'
embedding_sharded_state_dict = self.embedding.sharded_state_dict(
prefix=embedding_prefix
)
sharded_state_dict.update(embedding_sharded_state_dict)
decoder_prefix = f'{prefix}decoder.'
decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix)
sharded_state_dict.update(decoder_sharded_state_dict)
if self.post_process:
output_layer_prefix = f'{prefix}output_layer.'
output_layer_key = f'{output_layer_prefix}weight'
if self.share_embeddings_and_output_weights:
if not self.pre_process:
# when sharing embeddings with last stage, we need to use the weights from the first stage
# on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight
tensor = self.shared_embedding_or_output_weight()
first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
dp_rank = parallel_state.get_data_parallel_rank()
dp_size = parallel_state.get_data_parallel_world_size()
last_stage_word_emb_replica_id = (
dp_rank + dp_size
) # copy of first stage embedding
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=tensor,
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_key] = sharded_output_layer_tensor
else:
output_layer_state_dict = self.output_layer.state_dict(
prefix=output_layer_prefix, keep_vars=True
)
output_layer_tensor = output_layer_state_dict[output_layer_key]
# independent output layer
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=output_layer_tensor,
key=output_layer_key,
replica_id=parallel_state.get_data_parallel_rank(),
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_key] = sharded_output_layer_tensor
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
"""Sharded state dict implementation for GPTModel backward-compatibility
(removing extra state).
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the GPTModel
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
output_layer_extra_state_key = f'{prefix}output_layer._extra_state'
# Old GPT checkpoints only stored the output layer weight key. So we remove the
# _extra_state key but check that it doesn't contain any data anyway
output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
assert not (
output_extra_state and output_extra_state.data
), f'Expected output layer extra state to be empty, got: {output_extra_state}'
return sharded_state_dict
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .mamba_model import MambaModel
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
mamba_stack_spec = ModuleSpec(
module=MambaStack,
submodules=MambaStackSubmodules(
mamba_layer=ModuleSpec(
module=MambaLayer,
submodules=MambaLayerSubmodules(
mixer=ModuleSpec(
module=MambaMixer,
submodules=MambaMixerSubmodules(
in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear
),
),
mamba_bda=get_bias_dropout_add,
),
),
# Started with spec from gpt_layer_specs.py (with MLP removed)
# Using the TE spec because we had problems getting the non-TE spec
# working
attention_layer=ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
),
),
# Started with spec from gpt_layer_specs.py
# Using the TE spec because we had problems getting the non-TE spec
# working
mlp_layer=ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
),
),
),
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Literal, Optional
from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
class MambaModel(LanguageModule):
"""Mamba language model.
Args:
config (TransformerConfig): Transformer config
mamba_stack_spec (ModuleSpec): Specifies the modules to use for the various layer types
vocab_size (int): Vocabulary size
max_sequence_length (int): maximum size of sequence.
This is used for positional embedding
pre_process (bool, optional): Include embedding layer
(used with pipeline parallelism). Defaults to True.
mamba_ssm_ngroups (int, optional): Specifies the number of groups to use.
The default value is 8, as in the NVIDIA Mamba2 (pure and hybrid) 8b.
However, in the original Mamba2 paper, the checkpoints use a setting of 1.
Defaults to 8.
hybrid_attention_ratio (float, optional): The target ratio of attention
layers to total layers
hybrid_mlp_ratio (float, optional): The target ratio of mlp layers to total layers
hybrid_override_pattern (str, optional): The hybrid layer pattern to override with
post_process (bool, optional): Include an output layer (used with pipeline parallelism).
Defaults to True.
fp16_lm_cross_entropy (bool, optional): Defaults to False.
parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor
parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional): When True, input embeddings and
output logit weights are shared. Defaults to False.
position_embedding_type (Literal[learned_absolute,rope,none], optional): Position
embedding type. Defaults to 'none'.
rotary_percent (float, optional): Percent of rotary dimension to use for rotary position
embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless
position_embedding_type is 'rope'. Defaults to 10000.
seq_len_interpolation_factor (Optional[float], optional): scale of linearly
interpolating RoPE for longer sequences. The value must be a float larger than 1.0.
Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
mamba_stack_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
mamba_ssm_ngroups: int = 8,
pre_process: bool = True,
hybrid_attention_ratio: float = 0.0,
hybrid_mlp_ratio: float = 0.0,
hybrid_override_pattern: str = None,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
# Mamba with no attention has no need for position embeddings, so none is default
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'none',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super().__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.mamba_stack_spec: ModuleSpec = mamba_stack_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.mamba_ssm_ngroups = mamba_ssm_ngroups
self.pre_process = pre_process
self.hybrid_attention_ratio = hybrid_attention_ratio
self.hybrid_mlp_ratio = hybrid_mlp_ratio
self.hybrid_override_pattern = hybrid_override_pattern
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
use_cpu_initialization=self.config.use_cpu_initialization,
)
self.decoder = build_module(
mamba_stack_spec,
self.config,
mamba_ssm_ngroups=self.mamba_ssm_ngroups,
pre_process=self.pre_process,
hybrid_attention_ratio=self.hybrid_attention_ratio,
hybrid_mlp_ratio=self.hybrid_mlp_ratio,
hybrid_override_pattern=self.hybrid_override_pattern,
post_process=self.post_process,
dtype=config.params_dtype,
)
# Output
if post_process:
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
) -> Tensor:
"""Forward function of the Mamba model. This function passes the input tensors
through the embedding layer, and then the decoder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# The following assert will currently fail when running inference.
# Commented out for now.
# TODO (duncan/rwaleffe): (1) confirm that the externally-generated
# attention mask is not needed and is ignored by the model in
# inference mode, (2) reduce the size of the externally-generated
# attention mask to prevent CPU OOM (as we did for training), (3)
# force the attention mask passed to the model in inference mode to
# be None, so this assert will succeed.
# assert attention_mask is None, "The attention mask is ignored and should be set to None"
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from collections import namedtuple
from functools import partial
from typing import List, Optional
import torch
from megatron.core import InferenceParams
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.gpt import GPTModel
from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_image_sequence_length
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
IMAGE_TOKEN_INDEX = -200 # ID for images in the input sequence.
IGNORE_INDEX = -100 # ID for labels that should be ignored.
# Note: This is under development and may be missing features.
class LLaVAModel(MegatronModule):
"""LLaVA multi-modal model.
Args:
language_transformer_config (TransformerConfig): Transformer config for the language model.
language_transformer_layer_spec (ModuleSpec): Language model spec.
language_vocab_size (int): Language model vocabulary size.
language_max_sequence_length (int): Language model maximum sequence length.
vision_transformer_config (TransformerConfig): Transformer config for the vision model.
vision_transformer_layer_spec (ModuleSpec): Vision model spec.
drop_vision_class_token (bool): Drop vision class token(s) before the language model.
vision_projection_config (TransformerConfig): Vision projection config.
vision_projection_layer_spec (ModuleSpec): Vision projection spec.
vision_projection_type (str): Type of the vision projection. Default: 2-layer MLP.
allow_missing_vision_projection_checkpoint (bool): Allow vision projection weights to be
missing when loading a checkpoint. Default False.
parallel_output (bool): Keep outputs split across tensor parallel ranks.
This is typically True for training and False for inference.
language_position_embedding_type (str): Language model position embedding type.
language_rotary_percent (float): RoPE percent. Defaults to 1.0.
pre_process (bool): Include embedding layer in the decoder (used with pipeline parallel).
post_process (bool): Include output layer in the decoder (used with pipeline parallel).
add_encoder (bool): Construct the encoder (used with pipeline parallel).
When we use pipelining, the encoder will live on only the first stage
add_decoder (bool): Construct the decoder (used with pipeline parallel).
When we use pipelining, the decoder will live on every stage after the first one.
img_h (int): Input image height.
img_w (int): Input image width.
patch_dim (int): The size of each image patch side.
language_rotary_base (int): RoPE base.
"""
def __init__(
self,
language_transformer_config: TransformerConfig,
language_transformer_layer_spec: ModuleSpec,
language_vocab_size: int,
language_max_sequence_length: int,
vision_transformer_config: TransformerConfig,
vision_transformer_layer_spec: ModuleSpec,
drop_vision_class_token: bool,
vision_projection_config: TransformerConfig,
vision_projection_layer_spec: ModuleSpec,
vision_projection_type: str = "mlp",
allow_missing_vision_projection_checkpoint: bool = False,
parallel_output: bool = True,
language_position_embedding_type: str = 'learned_absolute',
language_rotary_percent: float = 1.0,
pre_process: bool = True,
post_process: bool = True,
add_encoder: bool = True,
add_decoder: bool = True,
img_h: int = 336,
img_w: int = 336,
patch_dim: int = 14,
language_rotary_base: int = 10000,
) -> None:
super().__init__(config=language_transformer_config)
if has_config_logger_enabled(language_transformer_config):
log_config_to_disk(language_transformer_config, locals(), prefix=type(self).__name__)
logging.getLogger(__name__).warning(
"LLaVA model is under active development. "
"It may be missing features and its methods may change."
)
self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = add_encoder
self.add_decoder = add_decoder
self.encoder_hidden_state = None
self.vision_model = None
self.vision_projection = None
self.language_model = None
# This attribute is needed to check if an all-reduce is required
# on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.
self.share_embeddings_and_output_weights = False
if self.add_decoder:
self.language_model = GPTModel(
config=language_transformer_config,
transformer_layer_spec=language_transformer_layer_spec,
vocab_size=language_vocab_size,
max_sequence_length=language_max_sequence_length,
parallel_output=parallel_output,
position_embedding_type=language_position_embedding_type,
rotary_percent=language_rotary_percent,
pre_process=self.pre_process,
post_process=self.post_process,
rotary_base=language_rotary_base,
)
self.share_embeddings_and_output_weights = (
self.language_model.share_embeddings_and_output_weights
)
self._language_max_sequence_length = language_max_sequence_length
class_token_len = 1
if self.add_encoder:
self.vision_model = CLIPViTModel(
vision_transformer_config,
vision_transformer_layer_spec,
img_h=img_h,
img_w=img_w,
class_token_len=class_token_len,
patch_dim=patch_dim,
)
self._drop_vision_class_token = drop_vision_class_token
# Map (intermediate) vision model outputs to the language model input dimension.
self.vision_projection = MultimodalProjector(
vision_projection_config,
vision_projection_layer_spec,
vision_projection_type,
vision_transformer_config.hidden_size, # input size to the projection.
)
# Ignore missing weights for the vision projection during checkpoint loading.
# This should be disabled by default but can be enabled if your checkpoint contains
# pretrained vision and language models but not the projection from vision model
# outputs to language model inputs.
if allow_missing_vision_projection_checkpoint:
vision_projection_param_names = [
f"vision_projection.{name}"
for name in self.vision_projection.state_dict().keys()
]
self.vision_projection.register_load_state_dict_post_hook(
partial(_load_state_dict_hook_ignore_param_names, vision_projection_param_names)
)
self._img_seq_len = get_image_sequence_length(
img_h, img_w, patch_dim, not drop_vision_class_token, class_token_len
)
def shared_embedding_or_output_weight(self):
"""This is a convenience method to surface the language model's word embeddings, which is
necessary for `finalize_model_grads._allreduce_word_embedding_grads`."""
if self.add_decoder:
return self.language_model.shared_embedding_or_output_weight()
return None
def set_input_tensor(self, input_tensor) -> None:
"""Set model chunk input tensor."""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for llava'
if self.add_encoder and self.add_decoder:
self.vision_model.set_input_tensor(input_tensor[0])
elif self.add_encoder:
self.vision_model.set_input_tensor(input_tensor[0])
elif self.pre_process:
self.encoder_hidden_state = input_tensor[0]
else:
self.language_model.set_input_tensor(input_tensor[0])
def freeze(
self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool
):
"""Freeze model modules.
Make specific modules non-trainable by setting requires_grad to False.
Args:
freeze_language_model (bool): Freeze the language model module.
freeze_vision_model (bool): Freeze the vision model module.
freeze_vision_projection (bool): Freeze the vision projection module.
"""
modules = []
if freeze_language_model and self.language_model is not None:
modules.append(self.language_model)
if freeze_vision_model and self.vision_model is not None:
modules.append(self.vision_model)
if freeze_vision_projection and self.vision_projection is not None:
modules.append(self.vision_projection)
for module in modules:
for param in module.parameters():
param.requires_grad = False
def _preprocess_data(
self,
image_embeddings,
language_embeddings,
input_ids,
loss_mask,
labels,
use_inference_kv_cache,
image_token_index,
num_image_tiles,
):
"""Preprocess input data before input to language model.
This function is adopted from
https://github.com/huggingface/transformers/blob/85817d98fb60977c97e3014196a462b732d2ed1a/src/transformers/models/llava_next/modeling_llava_next.py#L409
for our input data conventions.
image_token_index = -200 indicates the image position in the input_ids = [0, 1, -200, 2, 3]
and labels = [1, -200, 2, 3, 4], for example.
We want to replace the image position (-200) with image_embeddings and return the following:
- final_embeddings = [0, 1, image_embeddings, 2, 3],
- final_labels = [1, -100, 2, 3, 4]
- final_loss_mask = [1, 0, 0, 1, 1]
This function handles samples without images (text-only sample). It also handles samples
with images that are split into multiples tiles.
If pipeline parallelism is not used, then self.pre_process and self.post_process
are both True and we update both input embeddings, labels and loss masks (if available).
If pipeline parallelism is used, then we do the following
- the first language model chunk has self.pre_process = True and
self.post_process = False. We update input embeddings.
- the middle language model chunk(s) has self.pre_process = False and
self.post_process = False. We don't need to update anything.
- the last language model chunk has self.pre_process = False and
self.post_process = True. We update labels and loss mask.
TODO: This function should adjust the attention mask too.
Currently, we assume the language model uses a causal mask.
Returns:
final_embedding (torch.Tensor): image and text embeddings [combined_seq_len, b, h].
final_labels (torch.Tensor): labels for image and text positions [b, combined_seq_len].
final_loss_mask (torch.Tensor): loss mask [b, combined_seq_len].
"""
assert self.add_decoder, "input text preprocessing is only needed for the language model"
# No pre- or postprocessing needed.
# With pipeline parallel > 2, this means a chunk in the middle of the model.
if not self.pre_process and not self.post_process:
return language_embeddings, loss_mask, labels
# If using the inference KV cache, the image tokens are already computed.
if use_inference_kv_cache:
return language_embeddings, loss_mask, labels
img_seq_len = self._img_seq_len
batch_size, text_seq_len = input_ids.shape
has_labels = labels is not None
if has_labels:
assert (
labels.shape == loss_mask.shape
), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}"
# Create indices for new text and label positions.
with torch.no_grad():
image_token_mask = input_ids == image_token_index
num_images_per_sample = torch.sum(image_token_mask, dim=-1)
# Number of tiles per sample.
num_image_tiles_batch = num_image_tiles.split(num_images_per_sample.tolist(), dim=0)
num_image_tiles_batch = torch.tensor(
[x.sum() for x in num_image_tiles_batch], device=input_ids.device
)
# Sequence length for each sample is the image sequence length multiplied by
# the number of tiles for that image, minus image token indices,
# plus text sequence length.
seq_lens = num_image_tiles_batch * img_seq_len - num_images_per_sample + text_seq_len
max_seq_len = seq_lens.max()
batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
# New position ids for the text tokens, shifted by the image sequence length.
# E.g. for input_ids = [-200, 1, 2, 3] and img_seq_len = 576, we get
# new_position_ids = [576, 577, 578, 579]. text_position_ids are then [577, 578, 579].
image_token_mask_lens = image_token_mask.int().clone()
# -1 is for the removed image token index.
image_token_mask_lens[image_token_mask] = num_image_tiles * img_seq_len - 1
# +1 is needed here for the cumulative sum. -1 is adjusting for zero-based indexing.
new_position_ids = torch.cumsum((image_token_mask_lens + 1), dim=-1) - 1
text_position_ids = new_position_ids[batch_indices, non_image_indices]
# Labels are shifted to left by one.
# So, shift text position ids and non-image indices to left by one.
if has_labels:
label_text_position_ids = text_position_ids - 1
valid_label_text_position_ids = label_text_position_ids >= 0
label_text_position_ids = label_text_position_ids[valid_label_text_position_ids]
label_batch_indices = batch_indices[valid_label_text_position_ids]
label_non_image_indices = non_image_indices - 1
valid_label_non_image_indices = label_non_image_indices >= 0
label_non_image_indices = label_non_image_indices[valid_label_non_image_indices]
# Create a mask for the image embedding positions.
images_mask = torch.full(
(batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device
)
# No images in the text positions.
images_mask[batch_indices, text_position_ids] = False
# Samples can have different amount of images tokens.
# new_position_ids[:, -1] gives the last text position id for each sample.
# Padding is needed when the number of image tokens differs.
first_padding_idx = new_position_ids[:, -1] + 1
images_mask[
torch.arange(max_seq_len, device=first_padding_idx.device).repeat(batch_size, 1)
>= first_padding_idx.unsqueeze(1)
] = False
# Create the final input embedding (if this is the first language model stage).
final_embedding = None
if self.pre_process:
embed_dim = language_embeddings.shape[-1]
final_embedding = torch.zeros(
batch_size,
max_seq_len,
embed_dim,
dtype=language_embeddings.dtype,
device=language_embeddings.device,
)
# Put text embeddings to the text positions in the result tensor.
final_embedding[batch_indices, text_position_ids] = language_embeddings[
batch_indices, non_image_indices
]
# Put image embeddings to image positions.
final_embedding[images_mask] = image_embeddings.reshape(-1, embed_dim).contiguous()
# Create the final labels and loss mask (if this is the last language model stage).
final_labels, final_loss_mask = None, None
if has_labels:
final_labels = torch.full(
(batch_size, max_seq_len), IGNORE_INDEX, dtype=labels.dtype, device=labels.device
)
final_loss_mask = torch.full(
(batch_size, max_seq_len), 0, dtype=loss_mask.dtype, device=loss_mask.device
)
# Put text labels and loss mask to the text positions.
final_labels[label_batch_indices, label_text_position_ids] = labels[
label_batch_indices, label_non_image_indices
]
final_loss_mask[batch_indices, text_position_ids] = loss_mask[
batch_indices, non_image_indices
]
# For labels, pick the last label index that got dropped by the shift to left.
label_extra_text_position_ids = seq_lens - 1
batch_range = torch.arange(len(label_extra_text_position_ids))
final_labels[batch_range, label_extra_text_position_ids] = labels[batch_range, -1]
# Loss mask the image positions.
final_loss_mask[images_mask] = 0
# Loss mask last text position just before an image
# so that text token does not need to predict the first image token.
batch_image_indices, image_indices = torch.where(image_token_mask)
# Indices just before image tokens. If it's -1, skip it.
before_image_indices = image_indices - 1
valid = before_image_indices >= 0
valid_batch_image_indices = batch_image_indices[valid]
valid_before_image_indices = before_image_indices[valid]
# Map those indices those position ids.
valid_before_image_indices = new_position_ids[
valid_batch_image_indices, valid_before_image_indices
]
final_loss_mask[valid_batch_image_indices, valid_before_image_indices] = 0
if final_embedding is not None and has_labels:
assert (
final_embedding.shape[:2] == final_labels.shape == final_loss_mask.shape
), "unexpected shapes after data preprocessing"
if final_embedding is not None:
final_embedding = final_embedding.transpose(1, 0).contiguous()
# Truncate if exceeding the language model's max sequence length.
if (
final_embedding is not None
and final_embedding.shape[0] > self._language_max_sequence_length
):
final_embedding = final_embedding[: self._language_max_sequence_length]
if has_labels and final_labels.shape[1] > self._language_max_sequence_length:
final_labels = final_labels[:, : self._language_max_sequence_length]
final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length]
return final_embedding, final_labels, final_loss_mask
def forward(
self,
images: torch.Tensor,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
num_image_tiles: Optional[List[int]] = None,
image_token_index: Optional[int] = IMAGE_TOKEN_INDEX,
) -> torch.Tensor:
"""Forward function of the LLaVA model.
Args:
images (torch.Tensor): input images of shape [num_tiles, img_h, img_w].
num_tiles means the number of image tiles in this batch.
num_tiles = 0 if the batch doesn't contain images.
input_ids (torch.Tensor): input text ids [batch, text_seq_len].
position_ids (torch.Tensor): input text position ids [batch, text_seq_len].
attention_mask (torch.Tensor): Language model attention mask
[batch, 1, combined_seq_len, combined_seq_len].
labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len].
inference_params (InferenceParams): Inference-time parameters including KV cache.
num_image_tiles (list of int): Number of tiles per image. Default 1 tile per image.
image_token_index (int): ID for input images.
Returns:
output (torch.Tensor): Loss of shape [b, s] if labels are provided,
otherwise logits of shape [b, s, vocab_size].
loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s].
"""
use_inference_kv_cache = (
inference_params is not None
and "image_tokens_count" in inference_params.key_value_memory_dict
)
has_images = images.shape[0] > 0
# If running inference, we can skip image token computation
# if they were computed already earlier for this sample.
if use_inference_kv_cache:
image_embeddings = None
elif self.add_encoder and not has_images:
# If no images provided, use an empty image embeddings tensor.
image_embeddings = torch.tensor([], dtype=images.dtype, device=images.device)
elif self.add_encoder and has_images:
image_embeddings = self.vision_model(images) # [num_tiles, img_seq_len, h_vision]
if self._drop_vision_class_token:
image_embeddings = image_embeddings[:, self.vision_model.class_token_len :, :]
# contiguous() required as `permute` can sparsify the tensor and this breaks pipelining
image_embeddings = image_embeddings.permute(
1, 0, 2
).contiguous() # [img_seq_len, num_tiles, h_vision]
# map vision model output size to language model input size.
image_embeddings = self.vision_projection(
image_embeddings
) # [img_seq_len, num_tiles, h_language]
# TODO: Support batched inference.
# In inference, the language model KV cache will be updated for image token positions.
# Store the image tokens sequence length to be used as an offset to the KV cache later.
if inference_params is not None:
inference_params.key_value_memory_dict["image_tokens_count"] = (
image_embeddings.shape[0] * image_embeddings.shape[1]
)
else:
image_embeddings = self.encoder_hidden_state
if not self.add_decoder:
return image_embeddings, loss_mask
language_embeddings = None
if self.pre_process:
input_ids_text = input_ids.clone()
input_ids_text[input_ids_text == image_token_index] = 0
# Note: This adds absolute position embedding but not RoPE.
# Each image is counted as one position.
# RoPE is added in language_model forward. Each image embedding is one position.
language_embeddings = self.language_model.embedding(
input_ids=input_ids_text, position_ids=position_ids
) # [text_seq_len, b, h_language]
language_embeddings = language_embeddings.transpose(
1, 0
).contiguous() # [b, text_seq_len, h_language]
# Assume 1 tile per image if the number of tiles is not provided.
if num_image_tiles is None:
num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device)
# Preprocess input, labels and loss mask.
combined_embeddings, new_labels, new_loss_mask = self._preprocess_data(
image_embeddings,
language_embeddings,
input_ids,
loss_mask,
labels,
use_inference_kv_cache,
image_token_index,
num_image_tiles,
) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]
output = self.language_model(
input_ids=None,
position_ids=None,
attention_mask=attention_mask,
decoder_input=combined_embeddings,
labels=new_labels,
inference_params=inference_params,
)
if labels is None or loss_mask is None:
return output
return output, new_loss_mask
def _load_state_dict_hook_ignore_param_names(
param_names: List[str], module: torch.nn.Module, incompatible_keys: namedtuple
):
"""Hook to ignore missing keys during checkpoint loading.
By default, this should not be used to avoid accidentally missing weights in checkpoint loading.
Example use case: Use this if you want to load a checkpoint that contains vision and language
model weights but not the vision projection weights.
Args:
param_names (list str): Parameter names allowed to be missing when calling load_state_dict.
module (torch.nn.Module): The torch module this hook applies to. Required by the torch API.
incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys,
which collect the missing and unexpected keys, respectively.
"""
for param_name in param_names:
if param_name in incompatible_keys.missing_keys:
logging.getLogger(__name__).warning(
f"{param_name} being removed from incompatible_keys.missing_keys in LlavaModel"
)
incompatible_keys.missing_keys.remove(param_name)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm
warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
LNImpl = WrappedTorchLayerNorm
def decoder_model_with_transformer_engine_default_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
"""LLava decoder TE spec (uses Transformer Engine components)."""
mlp = _get_mlp_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
k_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def decoder_model_with_local_default_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
"""LLava decoder local spec."""
mlp = _get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- RetroConfig: configuration dataclass for RetroModel.
- RetroModel: The Retro model.
- get_retro_decoder_block_spec: Get spec for Retro decoder transformer block.
"""
from .config import RetroConfig
from .decoder_spec import get_retro_decoder_block_spec
from .model import RetroModel
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Base class for decoder and encoder attention modules."""
from megatron.core.models.retro.config import RetroConfig
from megatron.core.transformer.attention import CrossAttention, CrossAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
class BaseRetroCrossAttention(MegatronModule):
"""Base class for Retro cross attention, for both encoder & decoder layers.
This class collects the retro arguments below (i.e., num neighbors, chunk
length, and retrieve length) for use in Retro's custom cross attention
operators.
Args:
config (RetroConfig): Retro config.
submodules (CrossAttentionSubmodules): Cross attention submodules.
layer_number (int): Layer number within transformer block.
attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
"""
def __init__(
self,
config: RetroConfig,
submodules: CrossAttentionSubmodules,
layer_number: int = 1,
attn_mask_type: AttnMaskType = AttnMaskType.padding,
):
super().__init__(config=config)
self.attn = CrossAttention(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
)
self.retro_num_neighbors = config.retro_num_neighbors
self.retro_chunk_length = config.retro_chunk_length
self.retro_retrieved_length = config.retro_retrieved_length
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Configuration dataclass for a RetroModel."""
import os
from dataclasses import dataclass
from megatron.core.transformer import TransformerConfig
from megatron.core.utils import is_te_min_version
@dataclass
class RetroConfig(TransformerConfig):
"""Configuration object for Retro models."""
# Retro.
retro_project_dir: str = None
"""Retro project directory, which contains the preprocessed data for for pretraining. This
directory is built during preprocessing (see tools/retro/README.md), and contains
subdirectories for the chunk database and pretraining neighbors.
"""
retro_block_size: int = None
"""Number of records to load per data file, as saved during preprocessing. Block processing is
used for efficient data preprocessing.
"""
retro_chunk_length: int = None
"""Chunk length used for performing chunked- cross-attention (CCA)."""
retro_encoder_num_layers: int = 2
"""Number of layers to use for the retrieval encoder."""
retro_encoder_hidden_dropout: float = 0.1
"""Hidden dropout for retrieval encoder."""
retro_encoder_attention_dropout: float = 0.1
"""Attention dropout for retrieval encoder."""
retro_neighbor_dirs: dict = None
"""Directory names of saved neighbor id files for train, valid, and test datasets."""
retro_num_neighbors: int = 2
"""Number of neighbors to retrieve during pretraining."""
retro_num_retrieved_chunks: int = 2
"""Number of chunks to retrieve from the retrieval database."""
retro_retrieved_length: int = None
"""Cached value of retro_num_retrieved_chunks * retro_chunk_length (i.e., the total number of
retrieved tokens; neighbor + continuation).
"""
retro_split_preprocessing: str = None
"""Data split used during data preprocessing."""
retro_verify_neighbor_count: bool = True
"""Verify that len(GPT dataset) == len(saved neighbors)."""
# pylint: disable=line-too-long
def __post_init__(self) -> None:
"""Validate Retro config."""
super().__post_init__()
# Validate Transformer Engine version.
if is_te_min_version("1.3"):
try:
assert os.getenv("NVTE_FLASH_ATTN") == "0"
assert os.getenv("NVTE_FUSED_ATTN") == "0"
except Exception as e:
raise Exception(
"When using Transformer Engine >= 1.3, environment vars NVTE_FLASH_ATTN and NVTE_FUSED_ATTN most both be defined and set to '0'. Currently, NVTE_FLASH_ATTN == %s, NVTE_FUSED_ATTN == %s."
% (
os.getenv("NVTE_FLASH_ATTN", "[unset]"),
os.getenv("NVTE_FUSED_ATTN", "[unset]"),
)
)
# Preprocessing split should be defined.
assert self.retro_split_preprocessing is not None
# Pre-compute retrieved length.
self.retro_retrieved_length = self.retro_num_retrieved_chunks * self.retro_chunk_length
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro's cross attention modules for the decoder block."""
from functools import partial
from typing import Callable
import numpy as np
import torch
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.retro.base_attention import BaseRetroCrossAttention
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.utils import get_all_true_mask
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_block import TransformerBlock
class RetroDecoderCrossAttention(BaseRetroCrossAttention):
"""Retro decoder's chunked cross attention operator.
See this paper for more details: https://arxiv.org/abs/2112.04426.
Neighboring chunks retrieved from the chunk database are used here for
chunked-cross attention.
** Note about 'encoder_block_spec' **
Retro is an encoder-decoder model that uses its encoder for encoding
neighboring chunks that are retrieved from a chunk database. These
encoded neighbors are then used in the decoder stack for performing
chunked-cross attention (see paper link above).
In contrast to the T5 model, the encoder and decoder are computationally
intertwined, since the input to the encoder is the output of the self-
attention of the first decoder layer. As such, the encoder block itself
is instantiated within the first Retro decoder layer, in order to receive
the self-attention's output. (Note, that only the first decoder layer
instantiates an encoder block, and the remaining decoder layers use the
encoder output from the first decoder layer.)
Args:
config (RetroConfig): Retro config.
submodules (CrossAttentionSubmodules): Cross attention submodules.
layer_number (int): Layer number within transformer block.
attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
encoder_block_spec (ModuleSpec): The first Retro decoder layer is provided with a transformer block spec to construct the neighbor encoder.
"""
def __init__(
self,
config: RetroConfig,
submodules: CrossAttentionSubmodules,
layer_number: int = 1,
attn_mask_type: AttnMaskType = AttnMaskType.padding,
encoder_block_spec: ModuleSpec = None,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
)
if encoder_block_spec:
self.encoder = TransformerBlock(
config=config, spec=encoder_block_spec, pre_process=True, post_process=False
)
# self._encoder_key = 'encoder' # ... necessary?
else:
self.encoder = None
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Tensor = None,
inference_params: InferenceParams = None,
# rotary_pos_emb: Tensor = None, # ... unsupported for retro.
) -> dict:
"""Cross attention for Retro decoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
m : Number of tokens per chunk.
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
Args:
hidden_states (Tensor): Transformer layer hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Tensor): Neighbor embeddings if first decoder layer, else encoder output.
inference_params (InferenceParams): Inference params.
Returns:
A dict consisting of the attention output and context, along with other scalars necessary for performing the downstream bias-dropout-add.
"""
# hidden_states: [ ns, bs, d ]
# key_value_states: [ r, k*bs*l, d ]
ns, bs, d = hidden_states.shape
l = int(np.ceil(ns / self.retro_chunk_length))
# Retrieve neighbors.
if self.encoder:
# Sequence length remainder.
first_ns = ns % self.retro_chunk_length
# Case 1: Sequence length not divisible by chunk length.
if first_ns > 0:
# Split sequence into first partial chunk & remaining chunks.
first_chunk, rest_chunk = hidden_states[:first_ns], hidden_states[first_ns:]
# Pad partial chunk with zeros.
first_chunk = torch.nn.functional.pad(
first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0
)
# Concatenate padded chunk with remaining chunks.
chunked_output = torch.cat((first_chunk, rest_chunk), dim=0) # [ l*m, bs, d ]
# Case 2: Sequence length is divisible by chunk length.
else:
chunked_output = hidden_states # [ l*m, bs, d ]
# Chunk & permute hidden states.
# - hidden_states: [ l*m, bs, d ]
# - chunked_output: [ m, bs*l, d ]
chunked_output = (
chunked_output.reshape(l, self.retro_chunk_length, bs, d)
.permute(1, 2, 0, 3)
.reshape(self.retro_chunk_length, bs * l, d)
.contiguous()
)
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
chunked_output_mask = get_all_true_mask(
size=(1, 1, chunked_output.shape[0], key_value_states.shape[0]),
device=chunked_output.device,
)
# Encode neighbors. (Note: 'key_value_states' re-assigned here.)
key_value_states = self.encoder(
hidden_states=key_value_states,
attention_mask=attention_mask,
context=chunked_output,
context_mask=chunked_output_mask,
inference_params=inference_params,
) # [ r, k*bs*l, d ]
key_value_states = key_value_states.reshape(
self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d
) # [ r*k, bs*l, d ]
# Attend starting at last token of first chunk.
pad = (ns - 1) % self.retro_chunk_length
attending_chunks = hidden_states[pad:]
# Pad attending tokens to sequence length.
padded_chunks = torch.nn.functional.pad(
attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0
)
# Permute attending chunks.
# - padded_chunks: [ l*m, bs, d ]
# - padded_chunked_output: [ m, bs*l, d ] (matches 'chunked_output' above)
padded_chunked_output = padded_chunks.reshape(l, self.retro_chunk_length, bs, d).permute(
1, 2, 0, 3
)
padded_chunked_output = padded_chunked_output.reshape(
self.retro_chunk_length, bs * l, d
).contiguous()
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
padded_chunked_output_mask = get_all_true_mask(
size=(1, 1, padded_chunked_output.shape[0], key_value_states.shape[0]),
device=padded_chunked_output.device,
)
# Attend to encoded neighbors.
attention_output, attention_bias = self.attn(
hidden_states=padded_chunked_output,
attention_mask=padded_chunked_output_mask,
key_value_states=key_value_states,
)
# Return dimensions for bias-dropout step.
return {
"ns": ns,
"bs": bs,
"d": d,
"l": l,
"pad": pad,
"attention_output": attention_output, # [ m, bs*l, d ]
"attention_bias": attention_bias, # [ d ]
"context": key_value_states, # [ r*k, bs*l, d ]
}
class RetroDecoderBiasDropoutAdd(MegatronModule):
"""Retro decoder's bias-dropout-add operator.
This operator takes care of reshaping and permuting the output from the
chunk dimension to the sequence dimension.
Args:
config (RetroConfig): Retro config.
"""
def __init__(self, config: RetroConfig):
super().__init__(config=config)
self.retro_chunk_length = config.retro_chunk_length
@classmethod
def _forward(
cls,
x_with_bias: dict,
residual: Tensor,
prob: float,
retro_chunk_length: int,
bias_dropout_add: Callable,
) -> Tensor:
"""Per-chunk bias-dropout-add.
Args:
x_with_bias (dict): Attention output and bias, along with other Retro relevant parameters.
residual (Tensor): Transformer layer residual.
prob (float): Dropout probability.
retro_chunk_length (int): Retro chunk length (e.g., 64).
bias_dropout_add (Callable): Bias-dropout-add function.
Returns:
Output of bias-dropout-add.
"""
# Extract input dict.
ns = x_with_bias["ns"]
bs = x_with_bias["bs"]
d = x_with_bias["d"]
l = x_with_bias["l"]
pad = x_with_bias["pad"]
attention_output = x_with_bias["attention_output"] # [ m, bs*l, d ]
attention_bias = x_with_bias["attention_bias"] # [ d ]
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
# Bias-dropout-add.
x = bias_dropout_add(
(
attention_output,
None if attention_bias is None else attention_bias.expand_as(attention_output),
),
torch.zeros_like(attention_output),
prob,
)
# Permute chunks back to sequence dimension.
# 1. [ m, bs*l, d ]
# 2. [ m, bs, l, d ]
# 3. [ l, m, bs, d ]
# 4. [ m*l, bs, d ] == [ ns, bs, d ]
x = (
x.reshape(retro_chunk_length, bs, l, d)
.permute(2, 0, 1, 3)
.reshape(retro_chunk_length * l, bs, d)
)
# Prepend zeros for non-attending tokens.
x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0)[
:ns
] # [ ns, bs, d ]
# Add residual. [ ns, bs, d ]
x = x + residual
# Output. [ ns, bs, d ]
return x
def forward(self, training: bool, fused: bool) -> partial:
"""Retro decoder bias-dropout-add.
Args:
training (bool): If training, then apply dropout.
fused (bool): Fuse bias-dropout-add.
Returns:
The partial function for performing bias-dropout-add.
"""
return partial(
self._forward,
retro_chunk_length=self.retro_chunk_length,
bias_dropout_add=get_bias_dropout_add(training, fused),
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Specs for Retro decoder."""
import typing
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.decoder_attention import (
RetroDecoderBiasDropoutAdd,
RetroDecoderCrossAttention,
)
from megatron.core.models.retro.encoder_spec import get_retro_encoder_block_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm
warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
LNImpl = WrappedTorchLayerNorm
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
def get_retro_decoder_layer_te_spec(
encoder_block_spec: typing.Union[ModuleSpec, TransformerBlockSubmodules, None] = None
) -> ModuleSpec:
"""Retro decoder TE spec (uses Transformer Engine components).
A Retro decoder layer uses custom attention and bias-dropout-add operators
to perform chunked-cross attention. Additionally, the first Retro decoder
layer instantiates an entire encoder transformer block. As such, the decoder
cross attention module takes an optional encoder block spec, which is only
provided for the first Retro decoder layer.
Args:
encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided for
the first Retro decoder layer.
Returns:
A module spec with Transformer Engine modules.
"""
spec = get_gpt_layer_with_transformer_engine_spec()
spec.submodules.pre_cross_attn_layernorm = TENorm
spec.submodules.cross_attention = ModuleSpec(
module=RetroDecoderCrossAttention,
params={"encoder_block_spec": encoder_block_spec},
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd)
return spec
def get_retro_decoder_layer_local_spec(
encoder_block_spec: typing.Optional[ModuleSpec] = None,
) -> ModuleSpec:
"""Retro decoder local spec (uses Megatron-Core components).
A Retro decoder layer uses custom attention and bias-dropout-add operators
to perform chunked-cross attention. Additionally, the first Retro decoder
layer instantiates an entire encoder transformer block. As such, the decoder
cross attention module takes an optional encoder block spec, which is only
provided for the first Retro decoder layer.
Args:
encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided
for the first Retro decoder layer.
Returns:
A module spec with local modules.
"""
spec = get_gpt_layer_local_spec()
spec.submodules.pre_cross_attn_layernorm = LNImpl
spec.submodules.cross_attention = ModuleSpec(
module=RetroDecoderCrossAttention,
params={"encoder_block_spec": encoder_block_spec},
submodules=CrossAttentionSubmodules(
linear_q=ColumnParallelLinear,
linear_kv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd)
return spec
def get_retro_decoder_block_spec(
config: RetroConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""Retro decoder block spec.
Retro decoder block implementation details:
- The retro decoder block consists of interleaved GPT layers
and customized Retro decoder layers.
- The Retro decoder layers are spaced three layers apart,
and start on layer 6 or 9 (depending on the total number of layers).
- The first decoder layer instantiates an encoder block,
and it therefore passes in an encoder_block_spec.
Args:
config (RetroConfig): Retro config.
use_transformer_engine (bool): If True, use Transformer Engine (instead of local modules.
Returns:
Transformer block submodules for the given spec.
"""
# Num layers.
assert (
parallel_state.get_pipeline_model_parallel_world_size() == 1
), "retro does not currently support pipeline parallelism."
assert (
parallel_state.get_virtual_pipeline_model_parallel_world_size() is None
), "retro does not currently support virtual pipeline parallelism."
num_layers = get_num_layers_to_build(config)
# Retro layer numbers.
retro_layer_start = 6 if num_layers <= 15 else 9
retro_layer_numbers = list(range(retro_layer_start, num_layers + 1, 3))
# Layer specs.
gpt_layer_spec = (
get_gpt_layer_with_transformer_engine_spec()
if use_transformer_engine
else get_gpt_layer_local_spec()
)
get_retro_decoder_layer_spec = (
get_retro_decoder_layer_te_spec
if use_transformer_engine
else get_retro_decoder_layer_local_spec
)
retro_layer_spec = get_retro_decoder_layer_spec()
retro_layer_spec_with_retriever = get_retro_decoder_layer_spec(
get_retro_encoder_block_spec(config, use_transformer_engine)
)
layer_specs = []
for layer_number in range(1, num_layers + 1):
if layer_number == retro_layer_numbers[0]:
layer_specs.append(retro_layer_spec_with_retriever)
elif layer_number in retro_layer_numbers:
layer_specs.append(retro_layer_spec)
else:
layer_specs.append(gpt_layer_spec)
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs)
return block_spec
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro's cross attention modules for the encoder block."""
from functools import partial
from typing import Callable, List, Optional, Tuple, Type
import torch
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.retro.base_attention import BaseRetroCrossAttention
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.utils import get_all_true_mask
from megatron.core.transformer.module import MegatronModule
class RetroEncoderCrossAttention(BaseRetroCrossAttention):
"""Retro encoder's cross attention operator.
See this paper for more details: https://arxiv.org/abs/2112.04426.
Neighboring chunks are retrieved from the chunk database, encoded, and
used by the decoder layers for chunked cross attention.
Args:
config (RetroConfig): Retro config.
submodules (CrossAttentionSubmodules): Cross attention submodules.
layer_number (int): Layer number within transformer block.
attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
"""
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Tensor = None,
inference_params: InferenceParams = None,
# rotary_pos_emb: Tensor = None, # unsupported for retro.
) -> List[Tuple[Tensor, Optional[Tensor], Tensor]]:
"""Cross attention for Retro encoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
Args:
hidden_states (Tensor): Transformer layer hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Tensor): Neighbor embeddings.
inference_params (InferenceParams): Inference params.
Returns:
List of tuples, where each tuple is (attention_output, attention_bias, residual).
"""
# Input shape. [ r, bs*l*k, d ]
ns, bs, d = hidden_states.shape
# Reshape sequence into neighboring chunks.
# - hidden_states: [ r, bs*l*k, d ]
# - chunked_outputs: [ r, bs*l, k, d ]
chunked_outputs = hidden_states.reshape(
self.retro_retrieved_length, -1, self.retro_num_neighbors, d
)
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
chunked_output_mask = get_all_true_mask(
size=(1, 1, chunked_outputs.shape[0], key_value_states.shape[0]),
device=chunked_outputs.device,
)
# Per-chunk attention.
attention_output_tuples = []
for k in range(self.retro_num_neighbors):
# Attend to current neighboring chunks.
# - chunked_output: [ r, bs*l, d ]
# - key_value_states: [ m, bs*l, d ]
# - attention_output: [ r, bs*l, d ]
# - attention_bias: [ d ]
chunked_output = chunked_outputs[:, :, k].contiguous()
attention_output, attention_bias = self.attn(
hidden_states=chunked_output, # Q (neighbor embedding)
attention_mask=chunked_output_mask,
key_value_states=key_value_states, # K, V (hidden act)
)
# Residual connection. [ r, bs*l, d ]
residual = chunked_output
# Collect tensors.
attention_output_tuples.append((attention_output, attention_bias, residual))
# Output. (List[Tuple[( [ r, bs*l, d ], [ d ] )]])
return attention_output_tuples
class RetroEncoderBiasDropoutAdd(MegatronModule):
"""Retro encoder's bias-dropout-add operator.
This operator applies bias-dropout-add individually on each neighboring
chunk that is retrieved from the chunk database.
Args:
config (RetroConfig): Retro config.
"""
def __init__(self, config: RetroConfig):
super().__init__(config=config)
self.retro_num_neighbors = config.retro_num_neighbors
@classmethod
def _forward(
cls,
x_with_bias: List[Tuple[Tensor, Optional[Tensor], Tensor]],
residual: Tensor,
prob: float,
retro_num_neighbors: int,
bias_dropout_add: Callable,
) -> Tensor:
"""Per-chunk bias-dropout-add.
Args:
x_with_bias (dict): Attention output and bias tuple.
residual (Tensor): Transformer layer residual.
prob (float): Dropout probability.
retro_num_neighbors (int): Number of retrieved neighbor chunks (e.g., 2).
bias_dropout_add (Callable): Bias-dropout-add function.
Returns:
Output of bias-dropout-add.
"""
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
# Per-neighbor bias-dropout-add.
# - attention_output: [ r, bs*l, d ]
# - attention_bias: [ d ]
# - residual: [ r, bs*l, d ]
# - output: [ r, bs*l, d ]
outputs = [
bias_dropout_add(
(
attention_output,
None if attention_bias is None else attention_bias.expand_as(residual),
),
residual,
prob,
)
for attention_output, attention_bias, residual in x_with_bias
]
# Concatenate outputs (to shape [r, k*bs*l, d]; see notation above).
r, _, d = outputs[0].shape
output = torch.stack(outputs, dim=1).reshape(r, -1, d)
# Output. [ r, k*bs*l, d ]
return output
def forward(self, training: bool, fused: bool) -> partial:
"""Retro decoder bias-dropout-add.
Args:
training (bool): If training, then apply dropout.
fused (bool): Fuse bias-dropout-add.
Returns:
A partial function for performing bias-dropout-add.
"""
return partial(
self._forward,
retro_num_neighbors=self.retro_num_neighbors,
bias_dropout_add=get_bias_dropout_add(training, fused),
)
class RetroEncoderLayerNorm(MegatronModule):
"""Retro encoder's layernorm operator.
This operator applies layernorm individually on each neighboring chunk that
is retrieved from the chunk database, and then concatenates the chunks into
a single tensor.
Args:
config (RetroConfig): Retro config.
submodules (Type): Layer norm class. (Named 'submodules' to fit external interface.)
"""
def __init__(self, config: RetroConfig, submodules: Type, **kwargs: dict):
super().__init__(config=config)
norm_class = submodules
self.norm = norm_class(config=config, **kwargs)
self.retro_num_neighbors = config.retro_num_neighbors
def forward(self, input: Tensor) -> Tensor:
"""Per-chunk layer norm.
Args:
input (Tensor): Input chunks, concatenated into a single tensor.
Returns:
Output of the layer norm.
"""
# Input shape: [ r, k*bs*l, d ]. (see notation above in attention module)
# Split input into 'num_neighbors' tensors.
chunk_size = input.shape[1] // self.retro_num_neighbors
inputs = torch.split(input, chunk_size, dim=1)
# Norm.
outputs = [self.norm(inp.contiguous()) for inp in inputs]
# Concatenate layer norms (to shape [r, k*bs*l, d]; see notation above).
r, _, d = inputs[0].shape
output = torch.stack(outputs, dim=1).reshape(r, -1, d)
# Output. [ r, k*bs*l, d ]
return output
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Specs for Retro encoder."""
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.encoder_attention import (
RetroEncoderBiasDropoutAdd,
RetroEncoderCrossAttention,
RetroEncoderLayerNorm,
)
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm
warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
LNImpl = WrappedTorchLayerNorm
def get_retro_encoder_layer_te_spec() -> ModuleSpec:
"""Retro encoder TE spec (uses Transformer Engine components).
A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm
operators to encode neighboring chunks that are retrieved from the chunk
database. Each operator is responsible for iterating the retrieved chunks
and processing them individually.
Returns:
A module spec if Transformer Engine modules.
"""
spec = get_gpt_layer_with_transformer_engine_spec()
spec.submodules.pre_cross_attn_layernorm = TENorm
spec.submodules.cross_attention = ModuleSpec(
module=RetroEncoderCrossAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=TENorm)
spec.submodules.mlp = ModuleSpec(
module=MLP,
submodules=MLPSubmodules(linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear),
)
return spec
def get_retro_encoder_layer_local_spec() -> ModuleSpec:
"""Retro encoder local spec (uses Megatron-Core components).
A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm
operators to encode neighboring chunks that are retrieved from the chunk
database. Each operator is responsible for iterating the retrieved chunks
and processing them individually.
Returns:
A module spec if local modules.
"""
spec = get_gpt_layer_local_spec()
spec.submodules.pre_cross_attn_layernorm = LNImpl
spec.submodules.cross_attention = ModuleSpec(
module=RetroEncoderCrossAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=CrossAttentionSubmodules(
linear_q=ColumnParallelLinear,
linear_kv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=LNImpl)
spec.submodules.mlp = ModuleSpec(
module=MLP,
submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear),
)
spec.submodules.sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
} # pre_mlp_layernorm doesn't need remapping
return spec
def get_retro_encoder_block_spec(
config: RetroConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""Retro encoder block spec.
The retro encoder block consists of one customized Retro encoder layer
(layer 1), and all of the following layers are standard GPT layers.
Args:
config (RetroConfig): Retro config.
use_transformer_engine (bool): If True, use Transformer Engine (instead of local modules).
Returns:
Transformer block submodules for the given spec.
"""
# Num layers.
num_layers = config.retro_encoder_num_layers
retro_layer_numbers = [1]
# Layer specs.
gpt_layer_spec = (
get_gpt_layer_with_transformer_engine_spec()
if use_transformer_engine
else get_gpt_layer_local_spec()
)
get_retro_encoder_layer_spec = (
get_retro_encoder_layer_te_spec
if use_transformer_engine
else get_retro_encoder_layer_local_spec
)
retro_layer_spec = get_retro_encoder_layer_spec()
for spec in (gpt_layer_spec, retro_layer_spec):
spec.params["hidden_dropout"] = config.retro_encoder_hidden_dropout
spec.submodules.self_attention.params["attn_mask_type"] = AttnMaskType.padding
spec.submodules.self_attention.submodules.core_attention = ModuleSpec(
module=TEDotProductAttention if use_transformer_engine else DotProductAttention,
params={"attention_dropout": config.retro_encoder_attention_dropout},
)
layer_specs = []
for layer_number in range(1, num_layers + 1):
if layer_number in retro_layer_numbers:
layer_specs.append(retro_layer_spec)
else:
layer_specs.append(gpt_layer_spec)
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs)
return block_spec
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro Model."""
from typing import Dict, Optional
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.gpt import GPTModel
class RetroModel(GPTModel):
"""Retro Model.
A Retro model mostly re-uses the GPTModel interface, with the only difference
being the embedding of the 'context' this is used by Retro for processing
neighbor tokens. This embedded context is then forwarded to the Transformer
Block.
"""
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
context_input_ids: Tensor = None,
context_position_ids: Tensor = None,
context_mask: Tensor = None,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
) -> Tensor:
"""RetroModel forward method.
Foward input tokens & mask, along with neighbor tokens & mask, through
the Retro model..
Args:
input_ids (Tensor): Input token IDs.
position_ids (Tensor): Input position IDs.
attention_mask (Tensor): Input attention mask.
context_input_ids (Tensor): Context (i.e., neighbor) token IDs.
context_position_ids (Tensor): Context (i.e., neighbor) position IDs.
context_mask (Tensor): Context (i.e., neighbor) attention mask.
decoder_input (Tensor): When using pipeline parallelism, input_ids and position_ids will only be used on the first stage, and for all other stages decoder_input will be provided via communication from the previous stage.
labels (Tensor): The labels of dimension [batch size, seq length].
inference_params (InferenceParams): Parameters for inference.
Returns:
Output tensor of forward pass.
"""
# Argument shapes:
# Notation:
# ns : Sequence length.
# bs : Batch size.
# d : Hidden size.
# l : Number of chunks per sample (i.e., seq_length/chunk_length).
# k : Number of neighbors.
# r : Number of retrieved tokens (neighbors + continuation).
# - input_ids: [ bs, ns ]
# - context_ids: [ k*bs*l, r ]
# - context: [ r, k*bs*l, d ]
# - output: [ ns, bs, d ]
# Context embedding (e.g., for Retro neighbor tokens).
if context_input_ids is not None:
context = self.embedding(context_input_ids, context_position_ids)
else:
context = None
# Call GPTModel.forward, and pass in embedded context.
return super().forward(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
decoder_input=decoder_input,
labels=labels,
inference_params=inference_params,
extra_block_kwargs={"context": context, "context_mask": context_mask},
)
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
"""Get sharded state dict.
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): Offsets of local shard within global tensor.
metadata (Optional[Dict]): Shard metadata.
Returns:
A <ShardedStateDict> ?
"""
metadata = metadata or {}
metadata['non_homogeneous_layers'] = True
return super().sharded_state_dict(prefix, sharded_offsets, metadata)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
import torch
def get_config_path(project_dir: str) -> str:
"""Config copy stored within retro project dir."""
return os.path.join(project_dir, "config.json")
def get_gpt_data_dir(project_dir: str) -> str:
"""Get project-relative directory of GPT bin/idx datasets."""
return os.path.join(project_dir, "data")
# ** Note ** : Retro's compatibility between cross attention and Flash/Fused
# Attention is currently a work in progress. We default to returning None for
# now.
# def get_all_true_mask(size, device):
# return torch.full(size=size, fill_value=True, dtype=torch.bool, device=device)
def get_all_true_mask(size, device):
return None
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional, Union
import torch
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.extensions.transformer_engine import TENorm
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
# Note: This is under development and is missing features like position embedding interpolation.
class CLIPViTModel(VisionModule):
"""CLIP ViT vision model.
Args:
transformer_config (TransformerConfig): Transformer config.
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.
ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre.
add_class_token (bool, optional): Include a class token. Defaults to True.
class_token_len (int): Class token length. Defaults to 1 but 8 may be faster.
patch_dim (int): Image patch size.
img_h (int): Input image height.
img_w (int): Input image width.
"""
def __init__(
self,
transformer_config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
ln_pre_impl: Union[ModuleSpec, type] = TENorm,
add_class_token: bool = True,
class_token_len: int = 1,
patch_dim: int = 14,
img_h: int = 336,
img_w: int = 336,
) -> None:
super().__init__(config=transformer_config)
if has_config_logger_enabled(transformer_config):
log_config_to_disk(transformer_config, locals(), prefix=type(self).__name__)
self.class_token_len = class_token_len
self.visual_hidden_size = transformer_config.hidden_size
self.patch_dim = patch_dim
self.img_h = img_h
self.img_w = img_w
assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
self.num_patches_per_dim_h = self.img_h // self.patch_dim
self.num_patches_per_dim_w = self.img_w // self.patch_dim
self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
self.add_class_token = add_class_token
self.class_token_len = class_token_len
self.seq_length = self.num_patches + (self.class_token_len if self.add_class_token else 0)
self.conv1 = torch.nn.Conv2d(
in_channels=3,
out_channels=self.visual_hidden_size,
kernel_size=self.patch_dim,
stride=self.patch_dim,
bias=False,
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
self.position_embeddings = torch.nn.Embedding(self.seq_length, self.visual_hidden_size)
self.add_class_token = add_class_token
if self.add_class_token:
self.class_token = torch.nn.Parameter(
torch.randn(1, self.class_token_len, self.visual_hidden_size)
)
self.ln_pre = build_module(
ln_pre_impl,
config=transformer_config,
hidden_size=self.visual_hidden_size,
eps=transformer_config.layernorm_epsilon,
)
self.model_type = ModelType.encoder_or_decoder
# Transformer layers.
# TODO: Follow-up changes will make pre and post_process configurable.
# They are needed for supporting pipeline parallelism.
# Note: a final layer norm and/or linear layer present in some implementations
# are omitted here. They can be added separately where needed.
self.decoder = TransformerBlock(
config=transformer_config,
spec=transformer_layer_spec,
pre_process=True,
post_process=False,
)
def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
"""Sets input tensor to the model.
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
self.decoder.set_input_tensor(input_tensor)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward function of the CLIP ViT Model. This function passes the input tensors
through the embedding layer and then the transformer.
Args:
x (torch.Tensor): input data of shape [batch, img_h, img_w]
attention_mask (torch.Tensor with dtype=bool): Attention mask to use.
Returns:
x (torch.Tensor): output after final transformer block of shape [b, s, h].
"""
x = self.conv1(x) # shape = [batch, hidden_size, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # [batch, hidden_size, grid ** 2]
x = x.permute(0, 2, 1) # [batch, grid ** 2, hidden_size]
if self.add_class_token:
class_token = self.class_token.expand(
x.shape[0], -1, -1
) # [batch, class_token_len, hidden_size]
x = torch.cat(
[class_token, x], dim=1
) # [batch, grid ** 2 + class_token_len, hidden_size]
assert x.shape[1] == self.seq_length, f"{x.shape[1]} != {self.seq_length}"
x = x + self.position_embeddings(self.position_ids)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # [b, s, h] -> [s, b, h]
x = x.contiguous()
# contiguous() call required as `permute` can sparsify the tensor and this breaks pipelining
x = self.decoder(x, attention_mask)
x = x.permute(1, 0, 2) # [s, b, h] -> [b, s, h]
x = x.contiguous()
return x
def get_image_sequence_length(img_h, img_w, patch_dim, add_class_token, class_token_len):
"""Get image sequence length given image size, patch size, and class token."""
num_patches_per_dim_h = img_h // patch_dim
num_patches_per_dim_w = img_w // patch_dim
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
return num_patches + (class_token_len if add_class_token else 0)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core import tensor_parallel
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
class MultimodalProjector(MegatronModule):
"""
MultimodalProjector will take the encoded input with input_size hidden state and project
it into the hidden size of the language model for multimodal training. When projector is
type affine linear_fc1 from submodules is used.
Args:
transformer_config (TransformerConfig): Transformer config
submodules (MLPSubmodules): Specifies MLP submodules for mlp type projector
projector_type (str): Projector type
input_size (int): Input size from feature encoder
"""
def __init__(
self,
config: TransformerConfig,
submodules: MLPSubmodules,
projector_type: str,
input_size: int,
):
super().__init__(config=config)
self.projector_type = projector_type
assert submodules is not None, "MLPSubmodules must be provided"
if self.projector_type == "mlp":
self.encoder = MLP(config=config, submodules=submodules, input_size=input_size)
elif self.projector_type == "affine":
self.encoder = build_module(
submodules.linear_fc1,
input_size,
config.hidden_size,
config=config,
init_method=config.init_method,
gather_output=True,
bias=config.add_bias_linear,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name=None,
)
else:
raise Exception(f"Unsupported multimodal projection type {self.projector_type}")
def forward(self, hidden_states):
# Run encoder.
encoder_output, encoder_output_bias = self.encoder(hidden_states)
if encoder_output_bias is not None:
encoder_output = encoder_output + encoder_output_bias
# the encoder produces "viewed" tensor. This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
encoder_output = make_viewless_tensor(
inp=encoder_output, requires_grad=True, keep_graph=True
)
return encoder_output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment