Commit 0816dd4a authored by libo11

Initial commit
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
try:
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
except ImportError:
TEDotProductAttention = None
TELayerNormColumnParallelLinear = None
TENorm = None
TERowParallelLinear = None
#print("Do not support transformer_engine")
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_gpt_layer_with_transformer_engine_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
mlp = _get_mlp_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
k_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
# Use this spec for an implementation using only modules in megatron core
def get_gpt_layer_local_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
mlp = _get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=FusedLayerNorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
k_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=FusedLayerNorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
# Helper function to get module spec for MLP/MoE
def _get_mlp_module_spec(
use_te: bool = True, num_experts: int = None, moe_grouped_gemm: bool = False
) -> ModuleSpec:
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
else:
# Mixture of experts with modules in megatron core.
return ModuleSpec(
module=MoELayer,
submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,)
if not moe_grouped_gemm
else None,
)
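
# --- Illustrative usage sketch (not part of the original spec file; values are assumptions) ---
# Shows how the helpers above might be combined: prefer the Transformer Engine spec when TE is
# importable (the try/except at the top leaves the TE symbols as None otherwise), and fall back
# to the Megatron-core-only spec when it is not.
def _example_get_gpt_layer_spec(num_experts=None, moe_grouped_gemm=False, qk_layernorm=False):
    if TELayerNormColumnParallelLinear is not None:
        return get_gpt_layer_with_transformer_engine_spec(
            num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, qk_layernorm=qk_layernorm
        )
    return get_gpt_layer_local_spec(
        num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, qk_layernorm=qk_layernorm
    )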
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import logging
from typing import Dict, Literal, Optional, Tuple, Union
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import AttnMaskType, ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
class GPTModel(LanguageModule):
"""GPT Transformer language model.
Args:
config (TransformerConfig): Transformer config
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
vocab_size (int): Vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional): Defaults to False.
parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional): When True, input embeddings and output logit weights are shared. Defaults to False.
        position_embedding_type (Literal['learned_absolute', 'rope', 'none'], optional): Position embedding type. Defaults to 'learned_absolute'.
rotary_percent (float, optional): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 10000.
seq_len_interpolation_factor (Optional[float], optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super().__init__(config=config)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 2 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
)
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs stored
# in gradient buffer to calculate the weight gradients for the embedding final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
""" Sharded state dict implementation for GPTModel backward-compatibility (removing extra state).
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the GPTModel
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
output_layer_extra_state_key = f'{prefix}output_layer._extra_state'
# Old GPT checkpoints only stored the output layer weight key. So we remove the _extra_state key
# but check that it doesn't contain any data anyway
output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
assert not (
output_extra_state and output_extra_state.data
), f'Expected output layer extra state to be empty, got: {output_extra_state}'
return sharded_state_dict
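
# --- Illustrative usage sketch (hyperparameters are placeholders, not part of the original file) ---
# Builds a small GPTModel from a TransformerConfig and a transformer layer spec. The call assumes
# torch.distributed and Megatron's tensor/pipeline parallel state have already been initialized
# (e.g., via parallel_state.initialize_model_parallel()).
def _example_build_gpt_model(transformer_layer_spec: ModuleSpec) -> GPTModel:
    config = TransformerConfig(
        num_layers=2,  # placeholder depth
        hidden_size=256,  # placeholder width
        num_attention_heads=8,  # placeholder head count
    )
    return GPTModel(
        config=config,
        transformer_layer_spec=transformer_layer_spec,
        vocab_size=32000,  # placeholder vocabulary size
        max_sequence_length=1024,  # placeholder sequence length
        position_embedding_type='rope',
    )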
from .mamba_model import MambaModel
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import Mamba
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
mamba_stack_spec = ModuleSpec(
module=MambaStack,
submodules=MambaStackSubmodules(
mamba_layer=ModuleSpec(
module=MambaLayer, submodules=MambaLayerSubmodules(norm=TENorm, mixer=Mamba,),
),
# Started with spec from gpt_layer_specs.py (with MLP removed)
# Using the TE spec because we had problems getting the non-TE spec
# working
attention_layer=ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
),
),
# Started with spec from gpt_layer_specs.py
# Using the TE spec because we had problems getting the non-TE spec
# working
mlp_layer=ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
),
),
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Literal, Optional
from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
class MambaModel(LanguageModule):
"""Mamba language model.
Args:
config (TransformerConfig): Transformer config
mamba_stack_spec (ModuleSpec): Specifies the modules to use for the various layer types
vocab_size (int): Vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True.
hybrid_attention_ratio (float, optional): The target ratio of attention layers to total layers
hybrid_mlp_ratio (float, optional): The target ratio of mlp layers to total layers
hybrid_override_pattern (str, optional): The hybrid layer pattern to override with
post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional): Defaults to False.
parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional): When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal[learned_absolute,rope,none], optional): Position embedding type. Defaults to 'none'.
rotary_percent (float, optional): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 10000.
seq_len_interpolation_factor (Optional[float], optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
mamba_stack_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
hybrid_attention_ratio: float = 0.0,
hybrid_mlp_ratio: float = 0.0,
hybrid_override_pattern: str = None,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
# Mamba with no attention has no need for position embeddings, so none is default
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'none',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super().__init__(config=config)
self.mamba_stack_spec: ModuleSpec = mamba_stack_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.hybrid_attention_ratio = hybrid_attention_ratio
self.hybrid_mlp_ratio = hybrid_mlp_ratio
self.hybrid_override_pattern = hybrid_override_pattern
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
)
self.decoder = build_module(
mamba_stack_spec,
self.config,
pre_process=self.pre_process,
hybrid_attention_ratio=self.hybrid_attention_ratio,
hybrid_mlp_ratio=self.hybrid_mlp_ratio,
hybrid_override_pattern=self.hybrid_override_pattern,
post_process=self.post_process,
dtype=config.params_dtype,
)
# Output
if post_process:
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
) -> Tensor:
"""Forward function of the Mamba model. This function passes the input tensors
through the embedding layer, and then the decoder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# The following assert will currently fail when running inference.
# Commented out for now.
# TODO (duncan/rwaleffe): (1) confirm that the externally-generated
# attention mask is not needed and is ignored by the model in
# inference mode, (2) reduce the size of the externally-generated
# attention mask to prevent CPU OOM (as we did for training), (3)
# force the attention mask passed to the model in inference mode to
# be None, so this assert will succeed.
# assert attention_mask is None, "The attention mask is ignored and should be set to None"
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
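
# --- Illustrative usage sketch (hyperparameters are placeholders, not part of the original file) ---
# Builds a hybrid Mamba/attention model from a Mamba stack spec such as the mamba_stack_spec
# defined earlier in this commit. Assumes Megatron's parallel state is already initialized and,
# for that particular spec, that Transformer Engine is installed.
def _example_build_mamba_model(mamba_stack_spec: ModuleSpec) -> MambaModel:
    config = TransformerConfig(
        num_layers=4,  # placeholder: total layers in the hybrid stack
        hidden_size=256,  # placeholder width
        num_attention_heads=8,  # placeholder head count
    )
    return MambaModel(
        config=config,
        mamba_stack_spec=mamba_stack_spec,
        vocab_size=32000,  # placeholder vocabulary size
        max_sequence_length=1024,  # placeholder sequence length
        hybrid_attention_ratio=0.25,  # placeholder: fraction of layers that use self-attention
        hybrid_mlp_ratio=0.25,  # placeholder: fraction of layers that use an MLP
    )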
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from collections import namedtuple
from functools import partial
from typing import List
import torch
from megatron.core import InferenceParams, parallel_state
from megatron.core.models.gpt import GPTModel
from megatron.core.models.vision.clip_vit_model import CLIPViTModel
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
# Note: This is under development and may be missing features.
class LLaVAModel(MegatronModule):
"""LLaVA multi-modal model.
Args:
language_transformer_config (TransformerConfig): Transformer config for the language model.
language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the language model.
language_vocab_size (int): Language model vocabulary size.
language_max_sequence_length (int): Language model maximum sequence length. This is used for positional embedding.
vision_transformer_config (TransformerConfig): Transformer config for the vision model.
vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the vision model.
drop_vision_class_token (bool): Drop vision class token(s) before input to the language model.
vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to language model inputs.
vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision projection.
vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP.
allow_missing_vision_projection_checkpoint (bool): Allow vision projection weights to be missing when loading a checkpoint. Default False.
parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. This is typically True for training and False for inference.
language_position_embedding_type (str): Position embedding type to use in the language model. Default learned absolute.
language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings in the language model. Defaults to 1.0.
"""
def __init__(
self,
language_transformer_config: TransformerConfig,
language_transformer_layer_spec: ModuleSpec,
language_vocab_size: int,
language_max_sequence_length: int,
vision_transformer_config: TransformerConfig,
vision_transformer_layer_spec: ModuleSpec,
drop_vision_class_token: bool,
vision_projection_config: TransformerConfig,
vision_projection_layer_spec: ModuleSpec,
vision_projection_type: str = "mlp",
allow_missing_vision_projection_checkpoint: bool = False,
parallel_output: bool = True,
language_position_embedding_type: str = 'learned_absolute',
language_rotary_percent: float = 1.0,
) -> None:
super().__init__(config=language_transformer_config)
logging.getLogger(__name__).warning(
"LLaVA model is under development and may be missing features."
)
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
raise NotImplementedError("pipeline parallelism is not supported in this model yet.")
self.language_model = GPTModel(
config=language_transformer_config,
transformer_layer_spec=language_transformer_layer_spec,
vocab_size=language_vocab_size,
max_sequence_length=language_max_sequence_length,
parallel_output=parallel_output,
position_embedding_type=language_position_embedding_type,
rotary_percent=language_rotary_percent,
)
self.vision_model = CLIPViTModel(vision_transformer_config, vision_transformer_layer_spec)
self._drop_vision_class_token = drop_vision_class_token
# Map (intermediate) vision model outputs to the language model input dimension.
self.vision_projection = MultimodalProjector(
vision_projection_config,
vision_projection_layer_spec,
vision_projection_type,
vision_transformer_config.hidden_size, # input size to the projection.
)
# This allows ignoring missing weights for the vision projection during checkpoint loading.
# This should be disabled by default but can be enabled if your checkpoint contains pretrained
# vision and language models but not the projection from vision model outputs to language model inputs.
if allow_missing_vision_projection_checkpoint:
vision_projection_param_names = [
f"vision_projection.{name}" for name in self.vision_projection.state_dict().keys()
]
self.vision_projection.register_load_state_dict_post_hook(
partial(_load_state_dict_hook_ignore_param_names, vision_projection_param_names)
)
def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
"""Sets input tensor to the model.
NOTE: Pipeline parallelism is not supported in this model yet. This is just a placeholder implementation.
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
self.vision_model.set_input_tensor(input_tensor)
def freeze(
self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool
):
"""Freeze model modules.
Make specific modules non-trainable by setting requires_grad to False for the module's parameters.
Args:
freeze_language_model (bool): Freeze the language model module.
freeze_vision_model (bool): Freeze the vision model module.
freeze_vision_projection (bool): Freeze the vision projection module.
"""
modules = []
if freeze_language_model:
modules.append(self.language_model)
if freeze_vision_model:
modules.append(self.vision_model)
if freeze_vision_projection:
modules.append(self.vision_projection)
for module in modules:
for param in module.parameters():
param.requires_grad = False
def forward(
self,
images: torch.Tensor,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor = None,
inference_params: InferenceParams = None,
) -> torch.Tensor:
"""Forward function of the LLaVA model.
Args:
images (torch.Tensor): input image of shape [batch, img_h, img_w].
input_ids (torch.Tensor): input text ids [batch, text_seq_len].
position_ids (torch.Tensor): input text position ids [batch, text_seq_len].
attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, combined_seq_len].
labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
inference_params (InferenceParams): Inference-time parameters including KV cache.
Returns:
output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size].
"""
language_embeddings = self.language_model.embedding(
input_ids=input_ids, position_ids=position_ids
) # [text_seq_len, b, h_language]
# If running inference, we can skip image token computation if they were computed already earlier for this sample.
if (
inference_params is not None
and "image_tokens_count" in inference_params.key_value_memory_dict
):
combined_embeddings = language_embeddings
else:
image_embeddings = self.vision_model(images) # [b, img_seq_len, h_vision]
if self._drop_vision_class_token:
image_embeddings = image_embeddings[:, self.vision_model.class_token_len :, :]
image_embeddings = image_embeddings.permute(1, 0, 2) # [img_seq_len, b, h_vision]
# map vision model output size to language model input size.
image_embeddings = self.vision_projection(
image_embeddings
            ) # [img_seq_len, b, h_language]
# If running inference, the language model KV cache will be updated for image token positions.
# Here we store the image tokens sequence length, which can be used as an offset to the KV cache later.
if inference_params is not None:
inference_params.key_value_memory_dict[
"image_tokens_count"
] = image_embeddings.shape[0]
combined_embeddings = torch.cat(
[image_embeddings, language_embeddings], dim=0
) # [combined_seq_len, b, h_language]
# Embedding is computed above so we can discard input and position ids.
input_ids = None
position_ids = None
# Note: This returns loss if labels are provided, otherwise logits.
output = self.language_model(
input_ids,
position_ids,
attention_mask,
decoder_input=combined_embeddings,
labels=labels,
inference_params=inference_params,
)
return output
def _load_state_dict_hook_ignore_param_names(
param_names: List[str], module: torch.nn.Module, incompatible_keys: namedtuple
):
"""Hook to ignore missing keys during checkpoint loading.
By default, this should not be used to avoid accidentally missing weights in checkpoint loading.
Example use case: Use this for the vision projection if you want to load a checkpoint that contains vision and language model weights
but not the vision projection weights.
Args:
param_names (list of str): Parameter names allowed to be missing when calling load_state_dict.
module (torch.nn.Module): The torch module this hook applies to. Unused here but required by the torch API.
incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys, which collect the missing and unexpected
keys when calling load_state_dict on this torch module, respectively.
"""
for param_name in param_names:
incompatible_keys.missing_keys.remove(param_name)
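
# --- Illustrative usage sketch (assumes `model` is an already-constructed LLaVAModel) ---
# A common LLaVA-style pretraining recipe keeps the pretrained language and vision backbones
# frozen and trains only the vision projection; freeze() implements exactly this toggling.
def _example_freeze_for_projection_pretraining(model: LLaVAModel) -> None:
    model.freeze(
        freeze_language_model=True,
        freeze_vision_model=True,
        freeze_vision_projection=False,
    )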
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- RetroConfig: configuration dataclass for RetroModel.
- RetroModel: The Retro model.
- get_retro_decoder_block_spec: Get spec for Retro decoder transformer block.
"""
from .config import RetroConfig
from .decoder_spec import get_retro_decoder_block_spec
from .model import RetroModel
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Base class for decoder and encoder attention modules."""
from megatron.core.models.retro.config import RetroConfig
from megatron.core.transformer.attention import CrossAttention, CrossAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
class BaseRetroCrossAttention(MegatronModule):
"""Base class for Retro cross attention, for both encoder & decoder layers.
This class collects the retro arguments below (i.e., num neighbors, chunk
length, and retrieve length) for use in Retro's custom cross attention
operators.
Args:
config (RetroConfig): Retro config.
submodules (CrossAttentionSubmodules): Cross attention submodules.
layer_number (int): Layer number within transformer block.
attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
"""
def __init__(
self,
config: RetroConfig,
submodules: CrossAttentionSubmodules,
layer_number: int = 1,
attn_mask_type: AttnMaskType = AttnMaskType.padding,
):
super().__init__(config=config)
self.attn = CrossAttention(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
)
self.retro_num_neighbors = config.retro_num_neighbors
self.retro_chunk_length = config.retro_chunk_length
self.retro_retrieved_length = config.retro_retrieved_length
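
# --- Illustrative sketch (hypothetical subclass, not part of the original file) ---
# Concrete Retro attention operators subclass BaseRetroCrossAttention, reuse the wrapped
# CrossAttention via self.attn, and read the cached Retro sizes collected above.
class _ExampleRetroCrossAttention(BaseRetroCrossAttention):
    def forward(self, hidden_states, attention_mask, key_value_states=None):
        # self.retro_chunk_length and self.retro_num_neighbors would drive any chunked reshapes
        # performed here before delegating to the wrapped cross attention.
        return self.attn(hidden_states, attention_mask, key_value_states=key_value_states)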
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Configuration dataclass for a RetroModel."""
import os
import types
from dataclasses import dataclass
from importlib.metadata import version
from pkg_resources import packaging
from megatron.core.transformer import TransformerConfig
@dataclass
class RetroConfig(TransformerConfig):
"""Configuration object for Retro models. """
# Retro.
retro_project_dir: str = None
"""Retro project directory, which contains the preprocessed data for for pretraining. This
directory is built during preprocessing (see tools/retro/README.md), and contains
subdirectories for the chunk database and pretraining neighbors.
"""
retro_block_size: int = None
"""Number of records to load per data file, as saved during preprocessing. Block processing is
used for efficient data preprocessing.
"""
retro_chunk_length: int = None
"""Chunk length used for performing chunked- cross-attention (CCA)."""
retro_encoder_num_layers: int = 2
"""Number of layers to use for the retrieval encoder."""
retro_encoder_hidden_dropout: float = 0.1
"""Hidden dropout for retrieval encoder."""
retro_encoder_attention_dropout: float = 0.1
"""Attention dropout for retrieval encoder."""
retro_neighbor_dirs: dict = None
"""Directory names of saved neighbor id files for train, valid, and test datasets."""
retro_num_neighbors: int = 2
"""Number of neighbors to retrieve during pretraining."""
retro_num_retrieved_chunks: int = 2
"""Number of chunks to retrieve from the retrieval database."""
retro_retrieved_length: int = None
"""Cached value of retro_num_retrieved_chunks * retro_chunk_length (i.e., the total number of
retrieved tokens; neighbor + continuation).
"""
retro_split_preprocessing: str = None
"""Data split used during data preprocessing."""
retro_verify_neighbor_count: bool = True
"""Verify that len(GPT dataset) == len(saved neighbors)."""
def __post_init__(self) -> None:
"""Validate Retro config."""
super().__post_init__()
# Validate Transformer Engine version.
te_version = packaging.version.Version(version("transformer-engine"))
if te_version >= packaging.version.Version("1.3"):
try:
assert os.getenv("NVTE_FLASH_ATTN") == "0"
assert os.getenv("NVTE_FUSED_ATTN") == "0"
except Exception as e:
raise Exception(
"When using Transformer Engine >= 1.3, environment vars NVTE_FLASH_ATTN and NVTE_FUSED_ATTN most both be defined and set to '0'. Currently, NVTE_FLASH_ATTN == %s, NVTE_FUSED_ATTN == %s."
% (
os.getenv("NVTE_FLASH_ATTN", "[unset]"),
os.getenv("NVTE_FUSED_ATTN", "[unset]"),
)
)
# Preprocessing split should be defined.
assert self.retro_split_preprocessing is not None
# Pre-compute retrieved length.
self.retro_retrieved_length = self.retro_num_retrieved_chunks * self.retro_chunk_length
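
# --- Illustrative usage sketch (values are placeholders, not part of the original file) ---
# Constructing a RetroConfig. With Transformer Engine >= 1.3 installed, __post_init__ above
# requires NVTE_FLASH_ATTN and NVTE_FUSED_ATTN to be "0" before the config is created, and
# retro_retrieved_length is derived from retro_num_retrieved_chunks * retro_chunk_length.
def _example_build_retro_config() -> RetroConfig:
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    return RetroConfig(
        num_layers=12,  # placeholder
        hidden_size=768,  # placeholder
        num_attention_heads=12,  # placeholder
        retro_chunk_length=64,  # placeholder chunk length
        retro_num_retrieved_chunks=2,  # retrieved chunk plus its continuation
        retro_split_preprocessing="98,2,0",  # placeholder; must be non-None (asserted above)
    )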
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro's cross attention modules for the decoder block."""
from functools import partial
from typing import Callable
import numpy as np
import torch
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.retro.base_attention import BaseRetroCrossAttention
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.utils import get_all_true_mask
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_block import TransformerBlock
class RetroDecoderCrossAttention(BaseRetroCrossAttention):
"""Retro decoder's chunked cross attention operator.
See this paper for more details: https://arxiv.org/abs/2112.04426.
Neighboring chunks retrieved from the chunk database are used here for
chunked-cross attention.
** Note about 'encoder_block_spec' **
Retro is an encoder-decoder model that uses its encoder for encoding
neighboring chunks that are retrieved from a chunk database. These
encoded neighbors are then used in the decoder stack for performing
chunked-cross attention (see paper link above).
In contrast to the T5 model, the encoder and decoder are computationally
intertwined, since the input to the encoder is the output of the self-
attention of the first decoder layer. As such, the encoder block itself
is instantiated within the first Retro decoder layer, in order to receive
the self-attention's output. (Note, that only the first decoder layer
instantiates an encoder block, and the remaining decoder layers use the
encoder output from the first decoder layer.)
Args:
config (RetroConfig): Retro config.
submodules (CrossAttentionSubmodules): Cross attention submodules.
layer_number (int): Layer number within transformer block.
attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
encoder_block_spec (ModuleSpec): The first Retro decoder layer is provided with a transformer block spec to construct the neighbor encoder.
"""
def __init__(
self,
config: RetroConfig,
submodules: CrossAttentionSubmodules,
layer_number: int = 1,
attn_mask_type: AttnMaskType = AttnMaskType.padding,
encoder_block_spec: ModuleSpec = None,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
)
if encoder_block_spec:
self.encoder = TransformerBlock(
config=config, spec=encoder_block_spec, pre_process=True, post_process=False,
)
# self._encoder_key = 'encoder' # ... necessary?
else:
self.encoder = None
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Tensor = None,
inference_params: InferenceParams = None,
# rotary_pos_emb: Tensor = None, # ... unsupported for retro.
) -> dict:
"""Cross attention for Retro decoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
m : Number of tokens per chunk.
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
Args:
hidden_states (Tensor): Transformer layer hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Tensor): Neighbor embeddings if first decoder layer, else encoder output.
inference_params (InferenceParams): Inference params.
Returns:
A dict consisting of the attention output and context, along with other scalars necessary for performing the downstream bias-dropout-add.
"""
# hidden_states: [ ns, bs, d ]
# key_value_states: [ r, k*bs*l, d ]
ns, bs, d = hidden_states.shape
l = int(np.ceil(ns / self.retro_chunk_length))
# Retrieve neighbors.
if self.encoder:
# Sequence length remainder.
first_ns = ns % self.retro_chunk_length
# Case 1: Sequence length not divisible by chunk length.
if first_ns > 0:
# Split sequence into first partial chunk & remaining chunks.
first_chunk, rest_chunk = hidden_states[:first_ns], hidden_states[first_ns:]
# Pad partial chunk with zeros.
first_chunk = torch.nn.functional.pad(
first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0,
)
# Concatenate padded chunk with remaining chunks.
chunked_output = torch.cat((first_chunk, rest_chunk), dim=0) # [ l*m, bs, d ]
# Case 2: Sequence length is divisible by chunk length.
else:
chunked_output = hidden_states # [ l*m, bs, d ]
# Chunk & permute hidden states.
# - hidden_states: [ l*m, bs, d ]
# - chunked_output: [ m, bs*l, d ]
chunked_output = (
chunked_output.reshape(l, self.retro_chunk_length, bs, d)
.permute(1, 2, 0, 3)
.reshape(self.retro_chunk_length, bs * l, d)
.contiguous()
)
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
chunked_output_mask = get_all_true_mask(
size=(1, 1, chunked_output.shape[0], key_value_states.shape[0]),
device=chunked_output.device,
)
# Encode neighbors. (Note: 'key_value_states' re-assigned here.)
key_value_states = self.encoder(
hidden_states=key_value_states,
attention_mask=attention_mask,
context=chunked_output,
context_mask=chunked_output_mask,
inference_params=inference_params,
) # [ r, k*bs*l, d ]
key_value_states = key_value_states.reshape(
self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d
) # [ r*k, bs*l, d ]
# Attend starting at last token of first chunk.
pad = (ns - 1) % self.retro_chunk_length
attending_chunks = hidden_states[pad:]
# Pad attending tokens to sequence length.
padded_chunks = torch.nn.functional.pad(
attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0,
)
# Permute attending chunks.
# - padded_chunks: [ l*m, bs, d ]
# - padded_chunked_output: [ m, bs*l, d ] (matches 'chunked_output' above)
padded_chunked_output = padded_chunks.reshape(l, self.retro_chunk_length, bs, d).permute(
1, 2, 0, 3
)
padded_chunked_output = padded_chunked_output.reshape(
self.retro_chunk_length, bs * l, d
).contiguous()
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
padded_chunked_output_mask = get_all_true_mask(
size=(1, 1, padded_chunked_output.shape[0], key_value_states.shape[0]),
device=padded_chunked_output.device,
)
# Attend to encoded neighbors.
attention_output, attention_bias = self.attn(
hidden_states=padded_chunked_output,
attention_mask=padded_chunked_output_mask,
key_value_states=key_value_states,
)
# Return dimensions for bias-dropout step.
return {
"ns": ns,
"bs": bs,
"d": d,
"l": l,
"pad": pad,
"attention_output": attention_output, # [ m, bs*l, d ]
"attention_bias": attention_bias, # [ d ]
"context": key_value_states, # [ r*k, bs*l, d ]
}
class RetroDecoderBiasDropoutAdd(MegatronModule):
"""Retro decoder's bias-dropout-add operator.
This operator takes care of reshaping and permuting the output from the
chunk dimension to the sequence dimension.
Args:
config (RetroConfig): Retro config.
"""
def __init__(
self, config: RetroConfig,
):
super().__init__(config=config)
self.retro_chunk_length = config.retro_chunk_length
@classmethod
def _forward(
cls,
x_with_bias: dict,
residual: Tensor,
prob: float,
retro_chunk_length: int,
bias_dropout_add: Callable,
) -> Tensor:
"""Per-chunk bias-dropout-add.
Args:
x_with_bias (dict): Attention output and bias, along with other Retro relevant parameters.
residual (Tensor): Transformer layer residual.
prob (float): Dropout probability.
retro_chunk_length (int): Retro chunk length (e.g., 64).
bias_dropout_add (Callable): Bias-dropout-add function.
Returns:
Output of bias-dropout-add.
"""
# Extract input dict.
ns = x_with_bias["ns"]
bs = x_with_bias["bs"]
d = x_with_bias["d"]
l = x_with_bias["l"]
pad = x_with_bias["pad"]
attention_output = x_with_bias["attention_output"] # [ m, bs*l, d ]
attention_bias = x_with_bias["attention_bias"] # [ d ]
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
# Bias-dropout-add.
x = bias_dropout_add(
(
attention_output,
None if attention_bias is None else attention_bias.expand_as(attention_output),
),
torch.zeros_like(attention_output),
prob,
)
# Permute chunks back to sequence dimension.
# 1. [ m, bs*l, d ]
# 2. [ m, bs, l, d ]
# 3. [ l, m, bs, d ]
# 4. [ m*l, bs, d ] == [ ns, bs, d ]
x = (
x.reshape(retro_chunk_length, bs, l, d)
.permute(2, 0, 1, 3)
.reshape(retro_chunk_length * l, bs, d)
)
# Prepend zeros for non-attending tokens.
x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0,)[
:ns
] # [ ns, bs, d ]
# Add residual. [ ns, bs, d ]
x = x + residual
# Output. [ ns, bs, d ]
return x
def forward(self, training: bool, fused: bool) -> partial:
"""Retro decoder bias-dropout-add.
Args:
training (bool): If training, then apply dropout.
fused (bool): Fuse bias-dropout-add.
Returns:
The partial function for performing bias-dropout-add.
"""
return partial(
self._forward,
retro_chunk_length=self.retro_chunk_length,
bias_dropout_add=get_bias_dropout_add(training, fused),
)
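
# --- Illustrative usage sketch (placeholder dropout value, not part of the original file) ---
# Shows how the two operators above fit together inside a Retro decoder layer: the cross
# attention returns a dict of attention outputs plus reshaping metadata, and the bias-dropout-add
# module returns a partial that consumes that dict together with the layer residual.
def _example_retro_decoder_step(
    cross_attn: RetroDecoderCrossAttention,
    bda: RetroDecoderBiasDropoutAdd,
    hidden_states: Tensor,
    attention_mask: Tensor,
    key_value_states: Tensor,
    training: bool = True,
) -> Tensor:
    x_with_bias = cross_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        key_value_states=key_value_states,
    )
    bias_dropout_add_fn = bda(training=training, fused=False)
    # The partial reshapes the chunked attention output back to [ns, bs, d] and adds the residual.
    return bias_dropout_add_fn(x_with_bias, hidden_states, 0.1)  # 0.1 is a placeholder dropout prob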