Commit 688448db authored by silencealiang

Update code

parent a02a5490
Pipeline #2503 passed
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import warnings
from typing import Optional

from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.utils import get_te_version, is_te_min_version

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelGroupedLinear,
        TEColumnParallelLinear,
        TERowParallelGroupedLinear,
        TERowParallelLinear,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False


def get_moe_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Helper function to get module spec for MoE"""
    assert num_experts is not None

    mlp = MLPSubmodules(
        linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
        linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
    )

    # experts spec
    if moe_grouped_gemm:
        ## use GroupedMLP
        if use_te and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm:
            ## use TEGroupedLinear
            expert_module = TEGroupedMLP
            expert_submodule = MLPSubmodules(
                linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear
            )
        else:
            ## use legacy GroupedMLP
            expert_module = GroupedMLP
            expert_submodule = None
            warnings.warn(
                'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. '
                'Please update the TransformerEngine to version>=1.7.0 and use TEGroupedMLP.'
            )
    else:
        ## use SequentialMLP
        expert_module = SequentialMLP
        if use_te and not is_te_min_version("1.7.0.dev0"):
            warnings.warn(
                "Only transformer-engine>=1.7.0 supports MoE experts, "
                f"but your version is {get_te_version()}. Use local linear implementation instead."
            )
            expert_submodule = MLPSubmodules(
                linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
            )
        else:
            expert_submodule = mlp

    experts = ModuleSpec(module=expert_module, submodules=expert_submodule)

    # shared experts spec
    shared_experts = ModuleSpec(module=SharedExpertMLP, params={"gate": False}, submodules=mlp)

    # MoE module spec
    moe_module_spec = ModuleSpec(
        module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts)
    )

    return moe_module_spec
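As a usage sketch (not part of this commit), the helper above can be called to build an MoE layer spec; the expert count and GEMM flags below are illustrative assumptions:

example_moe_spec = get_moe_module_spec(
    use_te=True, num_experts=8, moe_grouped_gemm=True, moe_use_legacy_grouped_gemm=False
)
assert example_moe_spec.module is MoELayer
assert example_moe_spec.submodules.shared_experts.module is SharedExpertMLP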
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from .module import HuggingFaceModule, build_hf_model
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from transformers import AutoModel
from megatron.core.models.huggingface import HuggingFaceModule
class ClipHuggingFaceModel(HuggingFaceModule):
"""
Wrapper for CLIP HuggingFace models
"""
def __init__(self, config):
super().__init__(config)
self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path)
def forward(self, *args, **kwargs):
"""Forward function"""
x = self.model(*args, **kwargs)
x = x['last_hidden_state']
return x
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from transformers import AutoConfig, AutoModel
from megatron.core.transformer.module import MegatronModule
class HuggingFaceModule(MegatronModule):
"""
    Base module for HuggingFace model wrappers
"""
def __init__(self, config):
super().__init__(config=config)
def set_input_tensor(self, input_tensor):
"""Dummy function for set_input_tensor"""
self.input_tensor = input_tensor
class AutoHuggingFaceModel(HuggingFaceModule):
"""
Wrapper for HuggingFace AutoModel
"""
def __init__(self, config):
super().__init__(config)
self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path)
def forward(self, *args, **kwargs):
"""Forward function"""
return self.model(*args, **kwargs)
def build_hf_model(config):
"""Builds huggingface wrapper model given config"""
hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path)
if "qwen" in hf_config.model_type:
from megatron.core.models.huggingface.qwen_model import QwenHuggingFaceModel
model = QwenHuggingFaceModel(config)
elif "vit" in hf_config.model_type:
from megatron.core.models.huggingface.clip_model import ClipHuggingFaceModel
model = ClipHuggingFaceModel(config)
else:
raise NotImplementedError(f"Huggingface model type {hf_config.model_type} is not supported")
return model
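A hedged usage sketch for build_hf_model; the config object and model name below are stand-ins (hypothetical), since build_hf_model itself only reads huggingface_model_name_or_path before dispatching on the HuggingFace model_type:

from types import SimpleNamespace

example_config = SimpleNamespace(huggingface_model_name_or_path="Qwen/Qwen2-1.5B")  # hypothetical values
# model = build_hf_model(example_config)  # would download/load HF weights, so not executed here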
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from transformers.models.qwen2 import Qwen2ForCausalLM
from megatron.core.models.huggingface import HuggingFaceModule
class QwenHuggingFaceModel(HuggingFaceModule):
"""
Wrapper for Qwen LM HuggingFace models
"""
def __init__(self, config):
super().__init__(config)
self.model = Qwen2ForCausalLM.from_pretrained(config.huggingface_model_name_or_path)
def forward(self, *args, **kwargs):
"""Forward function"""
combined_embeddings = kwargs['decoder_input'].permute(1, 0, 2)
x = self.model(
position_ids=None, # TODO: I guess we're just assuming no custom pos ids
attention_mask=kwargs['attention_mask'],
inputs_embeds=combined_embeddings,
labels=kwargs['labels'],
)
if kwargs['labels'] is not None:
x = x["loss"]
else:
x = x["logits"]
return x
def embedding(self, input_ids, position_ids=None):
"""Function to run process tokens with input embeddings"""
return self.model.get_input_embeddings()(input_ids).transpose(1, 0).contiguous()
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from megatron.core.extensions.transformer_engine import (
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

mamba_stack_spec = ModuleSpec(
    module=MambaStack,
    submodules=MambaStackSubmodules(
        mamba_layer=ModuleSpec(
            module=MambaLayer,
            submodules=MambaLayerSubmodules(
                mixer=ModuleSpec(
                    module=MambaMixer,
                    submodules=MambaMixerSubmodules(
                        in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear
                    ),
                ),
                mamba_bda=get_bias_dropout_add,
            ),
        ),
        # Started with spec from gpt_layer_specs.py (with MLP removed)
        # Using the TE spec because we had problems getting the non-TE spec
        # working
        attention_layer=ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                self_attention=ModuleSpec(
                    module=SelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=SelfAttentionSubmodules(
                        linear_qkv=TELayerNormColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=TERowParallelLinear,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
            ),
        ),
        # Started with spec from gpt_layer_specs.py
        # Using the TE spec because we had problems getting the non-TE spec
        # working
        mlp_layer=ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                mlp=ModuleSpec(
                    module=MLP,
                    submodules=MLPSubmodules(
                        linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
                    ),
                ),
                mlp_bda=get_bias_dropout_add,
            ),
        ),
    ),
)
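As a brief illustration (an assumption about typical usage, not shown in this commit), specs like mamba_stack_spec are plain ModuleSpec trees, so individual pieces can be inspected or overridden before a model consumes the spec:

mixer_spec = mamba_stack_spec.submodules.mamba_layer.submodules.mixer
assert mixer_spec.module is MambaMixer
assert mixer_spec.submodules.in_proj is TELayerNormColumnParallelLinear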
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Multimodal Sequence Parallel (SP) and Context Parallel (CP) functionality."""
import torch
from megatron.core.packed_seq_params import PackedSeqParams
def get_padding(
seq_len, cp_size, tp_size, has_sp, decoder_tp_comm_overlap=False, decoder_seq_len=None
):
"""Calculate padding needed for SP and/or CP.
Args:
seq_len (int): Model sequence length.
cp_size (int): Context parallel size.
tp_size (int): Tensor parallel size.
has_sp (bool): Model uses sequence parallelism.
decoder_tp_comm_overlap (bool): Decoder (LLM) uses tensor parallel communication overlap.
decoder_seq_len (int): Decoder (LLM) maximum sequence length.
Returns:
padding (int): Padding needed given model configuration.
"""
padding = 0
# TP Comm overlap is performed with combined text+image embeddings.
if has_sp and decoder_tp_comm_overlap:
# If TP Comm Overlap is enabled for combined text+image embedding in LM backbone,
# user needs to provide decoder_seq_len with any potential padding needed for SP+CP
assert (
decoder_seq_len is not None
), "Please provide decoder seq length when using TP comm overlap for LM backbone"
padding = decoder_seq_len - seq_len
elif has_sp or cp_size > 1:
padding_factor = 1
if has_sp and cp_size > 1:
# Padding to multiple of tp_size * cp_size * 2 when using CP + SP.
padding_factor = tp_size * cp_size * 2
elif cp_size > 1:
padding_factor = cp_size * 2
elif has_sp:
padding_factor = tp_size
padding = int((seq_len + padding_factor - 1) // padding_factor * padding_factor) - seq_len
return padding
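# Worked example (illustrative; not part of the original function):
#
# >>> get_padding(seq_len=1000, cp_size=2, tp_size=4, has_sp=True)
# 8
#
# With SP and CP both enabled, the padding factor is tp_size * cp_size * 2 = 16, and the
# smallest multiple of 16 that is >= 1000 is 1008, so 8 padding tokens are needed.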
def get_packed_seq_params(tokens, img_seq_len, padding_needed, cp_size, use_packed_sequence=False):
"""Get PackedSeqParams for CP.
Args:
tokens (torch.Tensor): [batch, seq_len] input tokens.
img_seq_len (int): Image sequence length.
padding_needed (int): Padding to add.
cp_size (int): Context parallel size.
use_packed_sequence (bool): Uses sequence packing.
Returns:
packed_seq_params (PackedSeqParams): Parameters to be sent to Transformer Engine.
"""
batch_size = tokens.shape[0]
# Calculate the valid token seq len that LM backbone should compute on
combined_valid_seqlen = tokens.shape[1] + img_seq_len - padding_needed
cu_seqlens = torch.arange(
0,
(batch_size + 1) * (combined_valid_seqlen),
step=(combined_valid_seqlen),
dtype=torch.int32,
device=tokens.device,
)
# Calculate the total padded token seq len
combined_padded_seqlen = tokens.shape[1] + img_seq_len
cu_seqlens_padded = None
qkv_format = 'sbhd'
if cp_size > 1 and (padding_needed > 0 or use_packed_sequence):
# Provide cu_seqlens_<q/kv>_padded for CP support
cu_seqlens_padded = torch.arange(
0,
(batch_size + 1) * (combined_padded_seqlen),
step=(combined_padded_seqlen),
dtype=torch.int32,
device=tokens.device,
)
# CP with padding mask type requires THD format
qkv_format = 'thd'
packed_seq_params = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
cu_seqlens_q_padded=cu_seqlens_padded,
cu_seqlens_kv_padded=cu_seqlens_padded,
max_seqlen_q=combined_padded_seqlen,
max_seqlen_kv=combined_padded_seqlen,
qkv_format=qkv_format,
)
return packed_seq_params
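A short usage sketch for get_packed_seq_params (the values are illustrative assumptions and run on CPU):

tokens = torch.zeros(2, 1024, dtype=torch.long)  # [batch, seq_len]
params = get_packed_seq_params(
    tokens, img_seq_len=576, padding_needed=8, cp_size=2, use_packed_sequence=False
)
# With cp_size > 1 and padding present, the params carry cu_seqlens_*_padded and switch
# the layout to THD so Transformer Engine can mask the padded tokens under CP.
assert params.qkv_format == 'thd'
assert params.max_seqlen_q == 1024 + 576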
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from collections import namedtuple
from functools import partial
from typing import List, Optional

import torch

from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.gpt import GPTModel
from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_num_image_embeddings
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.models.vision.radio import RADIOViTModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import get_context_parallel_rank, get_context_parallel_world_size
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import log_single_rank

try:
    import transformer_engine  # pylint: disable=unused-import

    from megatron.core.extensions.transformer_engine import TEDotProductAttention
    from megatron.core.utils import is_te_min_version

    HAVE_TE = True

    try:
        import transformer_engine_torch as tex

        HAVE_TEX = True
    except:
        HAVE_TEX = False
except:
    HAVE_TE = False
    if get_context_parallel_world_size() > 1:
        raise RuntimeError("ContextParallelism requires TransformerEngine support, but not found.")

IGNORE_INDEX = -100  # ID for labels that should be ignored.
# Image token index can be tokenizer dependent so the default value does not work in all cases.
DEFAULT_IMAGE_TOKEN_INDEX = -200
IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"


class _get_data_on_this_cp_rank(torch.autograd.Function):
    """Performs sharding for Context Parallelism in THD format.

    In the forward pass, indices are selected for each CP rank and remaining tokens are dropped.
    In the backward pass, this class takes care of managing gradients for dropped tokens on each
    CP rank.
    """

    @staticmethod
    def forward(ctx, batch, packed_seq_params):
        """Context Parallelism forward support for THD format"""
        cp_size = get_context_parallel_world_size()
        cp_rank = get_context_parallel_rank()
        for key, data in batch.items():
            index = tex.thd_get_partitioned_indices(
                packed_seq_params.cu_seqlens_q_padded, data.size(1), cp_size, cp_rank
            )
            if key == "combined_embeddings":
                ctx.decoder_emb_index = index
                ctx.decoder_emb_seqlen = data.size(1)
            batch[key] = data.index_select(1, index)
            batch[key].requires_grad = data.requires_grad

        return batch

    @staticmethod
    def backward(ctx, grad_out, grad_label, grad_loss):
        """Context Parallelism backward support for THD format"""
        seqlen = ctx.decoder_emb_seqlen
        index = ctx.decoder_emb_index
        assert grad_out.size(1) == index.size(
            0
        ), f"Shape mismatch in incoming gradient {grad_out.shape} and \
            index from THD CP sharding {index.shape}"
        grad_in = torch.zeros(
            grad_out.size(0),
            seqlen,
            *grad_out.size()[2:],
            dtype=grad_out.dtype,
            device=grad_out.device,
        )
        grad_in[:, ctx.decoder_emb_index, :] = grad_out

        return (grad_in, None, None, None)


# Note: This is under development and may be missing features.
class LLaVAModel(MegatronModule):
    """LLaVA multi-modal model.

    Args:
        language_transformer_config (TransformerConfig): Transformer config for the language model.
        language_transformer_layer_spec (ModuleSpec): Language model spec.
        language_vocab_size (int): Language model vocabulary size.
        language_max_sequence_length (int): Language model maximum sequence length.
        vision_transformer_config (TransformerConfig): Transformer config for the vision model.
        vision_transformer_layer_spec (ModuleSpec): Vision model spec.
        drop_vision_class_token (bool): Drop vision class token(s) before the language model.
        vision_projection_config (TransformerConfig): Vision projection config.
        vision_projection_layer_spec (ModuleSpec): Vision projection spec.
        vision_projection_type (str): Type of the vision projection. Default: 2-layer MLP.
        allow_missing_vision_projection_checkpoint (bool): Allow vision projection weights to be
            missing when loading a checkpoint. Default False.
        parallel_output (bool): Keep outputs split across tensor parallel ranks.
            This is typically True for training and False for inference.
        share_embeddings_and_output_weights (bool): Input embedding and output layer share weights.
        language_position_embedding_type (str): Language model position embedding type.
        language_rotary_percent (float): RoPE percent. Defaults to 1.0.
        pre_process (bool): Include embedding layer in the decoder (used with pipeline parallel).
        post_process (bool): Include output layer in the decoder (used with pipeline parallel).
        add_encoder (bool): Construct the encoder (used with pipeline parallel).
            When we use pipelining, the encoder will live on only the first stage
        add_decoder (bool): Construct the decoder (used with pipeline parallel).
            When we use pipelining, the decoder will live on every stage after the first one.
        img_h (int): Input image height.
        img_w (int): Input image width.
        patch_dim (int): The size of each image patch side.
        language_rotary_base (int): RoPE base.
        language_rope_scaling (bool): Toggle RoPE scaling.
        language_rope_scaling_factor (float): RoPE scaling factor. Defaults to 8.
        image_token_index (int): Token ID for image token such as <image>.
        pixel_shuffle (bool): Enable pixel shuffle.
        tile_tags (list): Optional tile tags.
    """

    def __init__(
        self,
        language_transformer_config: TransformerConfig,
        language_transformer_layer_spec: ModuleSpec,
        language_vocab_size: int,
        language_max_sequence_length: int,
        vision_transformer_config: TransformerConfig,
        vision_transformer_layer_spec: ModuleSpec,
        drop_vision_class_token: bool,
        vision_projection_config: TransformerConfig,
        vision_projection_layer_spec: ModuleSpec,
        vision_projection_type: str = "mlp",
        allow_missing_vision_projection_checkpoint: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        language_position_embedding_type: str = 'learned_absolute',
        language_rotary_percent: float = 1.0,
        pre_process: bool = True,
        post_process: bool = True,
        add_encoder: bool = True,
        add_decoder: bool = True,
        img_h: int = 336,
        img_w: int = 336,
        patch_dim: int = 14,
        language_rotary_base: int = 10000,
        language_rope_scaling: bool = False,
        language_rope_scaling_factor: float = 8.0,
        image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX,
        pixel_shuffle: bool = False,
        tile_tags: Optional[list] = None,
    ) -> None:
        super().__init__(config=language_transformer_config)

        if has_config_logger_enabled(language_transformer_config):
            log_config_to_disk(language_transformer_config, locals(), prefix=type(self).__name__)

        log_single_rank(
            logging.getLogger(__name__),
            logging.WARNING,
            "LLaVA is work in progress. Features are missing and methods can change.",
        )

        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

        self.encoder_hidden_state = None
        self.vision_model = None
        self.vision_projection = None
        self.language_model = None

        self.sequence_parallel_lm = language_transformer_config.sequence_parallel
        self.tp_comm_overlap_lm = language_transformer_config.tp_comm_overlap
        self.context_parallel_lm = language_transformer_config.context_parallel_size
        if self.sequence_parallel_lm or self.context_parallel_lm > 1:
            assert (
                language_transformer_layer_spec.submodules.self_attention.submodules.core_attention
                == TEDotProductAttention
                and HAVE_TE
            ), "Sequence/Context Parallelism is supported only with TE DotProductAttention."
        if self.context_parallel_lm > 1:
            assert is_te_min_version(
                "1.10.0"
            ), "Context Parallelism in LLaVA requires TE v1.10 or higher"
        self.tensor_model_parallel_size_lm = language_transformer_config.tensor_model_parallel_size

        # This attribute is needed to check if an all-reduce is required
        # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        if self.add_decoder:
            if hasattr(
                language_transformer_config, "language_model_type"
            ) and language_transformer_config.language_model_type.startswith("huggingface"):
                from megatron.core.models.huggingface.module import build_hf_model

                self.language_model = build_hf_model(language_transformer_config)
            else:
                self.language_model = GPTModel(
                    config=language_transformer_config,
                    transformer_layer_spec=language_transformer_layer_spec,
                    vocab_size=language_vocab_size,
                    max_sequence_length=language_max_sequence_length,
                    parallel_output=parallel_output,
                    position_embedding_type=language_position_embedding_type,
                    rotary_percent=language_rotary_percent,
                    pre_process=self.pre_process,
                    post_process=self.post_process,
                    rotary_base=language_rotary_base,
                    rope_scaling=language_rope_scaling,
                    scatter_embedding_sequence_parallel=False,
                )

                self.share_embeddings_and_output_weights = (
                    self.language_model.share_embeddings_and_output_weights
                )
            self._language_max_sequence_length = language_max_sequence_length
            self._language_is_pipeline_parallel = (
                language_transformer_config.pipeline_model_parallel_size > 1
            )

            # Newer Transformer Engine versions add _extra_state keys in state_dict when using FP8.
            # Older models may not have _extra_state and can be ignored.
            self.language_model.register_load_state_dict_post_hook(
                _load_state_dict_hook_ignore_extra_state
            )

        class_token_len = 1
        if self.add_encoder:
            self._drop_vision_class_token = drop_vision_class_token
            add_class_token = True
            if vision_transformer_config.vision_model_type.startswith(
                ("clip", "siglip", "internvit")
            ):
                if vision_transformer_config.vision_model_type == "siglip":
                    class_token_len = 0
                    add_class_token = False
                    error_msg = (
                        "Siglip does not support vision class token, "
                        "set disable-vision-class-token to False."
                    )
                    assert not self._drop_vision_class_token, error_msg
                self.vision_model = CLIPViTModel(
                    vision_transformer_config,
                    vision_transformer_layer_spec,
                    img_h=img_h,
                    img_w=img_w,
                    class_token_len=class_token_len,
                    patch_dim=patch_dim,
                    model_subtype=vision_transformer_config.vision_model_type,
                    add_class_token=add_class_token,
                )
            elif vision_transformer_config.vision_model_type in ("radio"):
                # TODO: should refactor into model code itself?
                class_token_len = 8
                max_img_h = 2048
                max_img_w = 2048
                embedder_bias = False
                use_mask_token = False
                self.vision_model = RADIOViTModel(
                    vision_transformer_config,
                    vision_transformer_layer_spec,
                    img_h=img_h,
                    img_w=img_w,
                    max_img_h=max_img_h,
                    max_img_w=max_img_w,
                    class_token_len=class_token_len,
                    patch_dim=patch_dim,
                    add_class_token=add_class_token,
                    embedder_bias=embedder_bias,
                    use_mask_token=use_mask_token,
                )
            elif vision_transformer_config.vision_model_type.startswith("huggingface"):
                from megatron.core.models.huggingface.module import build_hf_model

                self.vision_model = build_hf_model(vision_transformer_config)
            else:
                raise ValueError(
                    "Vision model "
                    f"{vision_transformer_config.vision_model_type} is not "
                    "supported."
                )

            self.vision_model.register_load_state_dict_post_hook(
                _load_state_dict_hook_ignore_extra_state
            )

            vision_projection_input_size = vision_transformer_config.hidden_size
            vision_projection_input_size *= 4 if pixel_shuffle else 1

            # Map (intermediate) vision model outputs to the language model input dimension.
            self.vision_projection = MultimodalProjector(
                vision_projection_config,
                vision_projection_layer_spec,
                vision_projection_type,
                vision_projection_input_size,
            )
            # Ignore missing weights for the vision projection during checkpoint loading.
            # This should be disabled by default but can be enabled if your checkpoint contains
            # pretrained vision and language models but not the projection from vision model
            # outputs to language model inputs.
            if allow_missing_vision_projection_checkpoint:
                vision_projection_param_names = [
                    f"vision_projection.{name}"
                    for name in self.vision_projection.state_dict().keys()
                ]
                self.vision_projection.register_load_state_dict_post_hook(
                    partial(_load_state_dict_hook_ignore_param_names, vision_projection_param_names)
                )

        self.img_seq_len = get_num_image_embeddings(
            img_h,
            img_w,
            patch_dim,
            vision_transformer_config.vision_model_type,
            drop_vision_class_token,
            class_token_len,
            pixel_shuffle,
            tile_tags is not None,  # Tile tags enabled/disabled.
        )

        self.image_token_index = image_token_index
        self._pixel_shuffle = pixel_shuffle
        self._tile_tags = tile_tags

    def shared_embedding_or_output_weight(self):
        """This is a convenience method to surface the language model's word embeddings, which is
        necessary for `finalize_model_grads._allreduce_word_embedding_grads`."""
        if self.add_decoder:
            return self.language_model.shared_embedding_or_output_weight()
        return None

    def set_input_tensor(self, input_tensor) -> None:
        """Set model chunk input tensor."""
        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        assert len(input_tensor) == 1, 'input_tensor should only be length 1 for llava'

        if self.add_encoder and self.add_decoder:
            self.vision_model.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            self.vision_model.set_input_tensor(input_tensor[0])
        elif self.pre_process:
            self.encoder_hidden_state = input_tensor[0]
        else:
            self.language_model.set_input_tensor(input_tensor[0])

    def freeze(
        self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool
    ):
        """Freeze model modules.

        Make specific modules non-trainable by setting requires_grad to False.

        Args:
            freeze_language_model (bool): Freeze the language model module.
            freeze_vision_model (bool): Freeze the vision model module.
            freeze_vision_projection (bool): Freeze the vision projection module.
        """
        modules = []
        if freeze_language_model and self.language_model is not None:
            modules.append(self.language_model)
        if freeze_vision_model and self.vision_model is not None:
            modules.append(self.vision_model)
        if freeze_vision_projection and self.vision_projection is not None:
            modules.append(self.vision_projection)

        for module in modules:
            for param in module.parameters():
                param.requires_grad = False

    def _preprocess_data(
        self,
        image_embeddings,
        language_embeddings,
        input_ids,
        loss_mask,
        labels,
        use_inference_kv_cache,
        inference_params,
        image_token_index,
        num_image_tiles,
    ):
        """Preprocess input data before input to language model.

        This function is adopted from
        https://github.com/huggingface/transformers/blob/85817d98fb60977c97e3014196a462b732d2ed1a/src/transformers/models/llava_next/modeling_llava_next.py#L409
        for our input data conventions.

        image_token_index = -200 indicates the image position in the input_ids = [0, 1, -200, 2, 3]
        and labels = [1, -200, 2, 3, 4], for example.
        We want to replace the image position (-200) with image_embeddings and return the following:
        - final_embeddings = [0, 1, image_embeddings, 2, 3],
        - final_labels = [1, -100, 2, 3, 4]
        - final_loss_mask = [1, 0, 0, 1, 1]

        This function handles samples without images (text-only sample). It also handles samples
        with images that are split into multiple tiles.

        If pipeline parallelism is not used, then self.pre_process and self.post_process
        are both True and we update both input embeddings, labels and loss masks (if available).

        If pipeline parallelism is used, then we do the following
        - the first language model chunk has self.pre_process = True and
          self.post_process = False. We update input embeddings.
        - the middle language model chunk(s) has self.pre_process = False and
          self.post_process = False. We don't need to update anything.
        - the last language model chunk has self.pre_process = False and
          self.post_process = True. We update labels and loss mask.

        TODO: This function should adjust the attention mask too.
        Currently, we assume the language model uses a causal mask.

        Returns:
            final_embedding (torch.Tensor): image and text embeddings [combined_seq_len, b, h].
            final_labels (torch.Tensor): labels for image and text positions [b, combined_seq_len].
            final_loss_mask (torch.Tensor): loss mask [b, combined_seq_len].
        """
        assert self.add_decoder, "input text preprocessing is only needed for the language model"

        # No pre- or postprocessing needed.
        # With pipeline parallel > 2, this means a chunk in the middle of the model.
        if not self.pre_process and not self.post_process:
            return None, None, None

        # If using the inference KV cache, the image tokens are already computed.
        if use_inference_kv_cache:
            return language_embeddings, loss_mask, labels

        img_seq_len = self.img_seq_len
        batch_size, text_seq_len = input_ids.shape

        has_labels = labels is not None
        if has_labels:
            assert (
                labels.shape == loss_mask.shape
            ), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}"

        # Create indices for new text and label positions.
        with torch.no_grad():
            image_token_mask = input_ids == image_token_index
            num_images_per_sample = torch.sum(image_token_mask, dim=-1)

            # Number of tiles per sample.
            num_image_tiles_batch = num_image_tiles.split(num_images_per_sample.tolist(), dim=0)
            num_image_tiles_batch = torch.tensor(
                [x.sum() for x in num_image_tiles_batch], device=input_ids.device
            )

            # Sequence length for each sample is the image sequence length multiplied by
            # the number of tiles for that image, minus image token indices,
            # plus text sequence length.
            seq_lens = num_image_tiles_batch * img_seq_len - num_images_per_sample + text_seq_len
            max_seq_len = seq_lens.max()
            # Pipeline parallel expects fixed input size. Check if we need to pad.
            if (
                self._language_is_pipeline_parallel
                and max_seq_len < self._language_max_sequence_length
                and inference_params is None
            ):
                max_seq_len = self._language_max_sequence_length

            batch_indices, non_image_indices = torch.where(image_token_mask != True)

            # New position ids for the text tokens, shifted by the image sequence length.
            # E.g. for input_ids = [-200, 1, 2, 3] and img_seq_len = 576, we get
            # new_position_ids = [576, 577, 578, 579]. text_position_ids are then [577, 578, 579].
            image_token_mask_lens = image_token_mask.int().clone()
            # -1 is for the removed image token index.
            image_token_mask_lens[image_token_mask] = num_image_tiles * img_seq_len - 1
            # +1 is needed here for the cumulative sum. -1 is adjusting for zero-based indexing.
            new_position_ids = torch.cumsum((image_token_mask_lens + 1), dim=-1) - 1
            text_position_ids = new_position_ids[batch_indices, non_image_indices]

            label_batch_indices = None  # dummy value to pass formatting
            # Labels are shifted to left by one.
            # So, shift text position ids and non-image indices to left by one.
            label_batch_indices = None
            if has_labels:
                label_text_position_ids = text_position_ids - 1
                valid_label_text_position_ids = label_text_position_ids >= 0
                label_text_position_ids = label_text_position_ids[valid_label_text_position_ids]

                label_batch_indices = batch_indices[valid_label_text_position_ids]

                label_non_image_indices = non_image_indices - 1
                valid_label_non_image_indices = label_non_image_indices >= 0
                label_non_image_indices = label_non_image_indices[valid_label_non_image_indices]

            # Create a mask for the image embedding positions.
            images_mask = torch.full(
                (batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device
            )
            # No images in the text positions.
            images_mask[batch_indices, text_position_ids] = False
            # Samples can have different amounts of image tokens.
            # new_position_ids[:, -1] gives the last text position id for each sample.
            # Padding is needed when the number of image tokens differs.
            first_padding_idx = new_position_ids[:, -1] + 1
            images_mask[
                torch.arange(max_seq_len, device=first_padding_idx.device).repeat(batch_size, 1)
                >= first_padding_idx.unsqueeze(1)
            ] = False

        # Create the final input embedding (if this is the first language model stage).
        final_embedding = None
        if self.pre_process:
            embed_dim = language_embeddings.shape[-1]
            final_embedding = torch.zeros(
                batch_size,
                max_seq_len,
                embed_dim,
                dtype=language_embeddings.dtype,
                device=language_embeddings.device,
            )

            # Put text embeddings to the text positions in the result tensor.
            final_embedding[batch_indices, text_position_ids] = language_embeddings[
                batch_indices, non_image_indices
            ]

            # Put image embeddings to image positions.
            # NOTE: FSDP can hang with text-only samples so we use a workaround to run a dummy image
            # through the vision model and then zero-out the impact of the output here.
            if num_image_tiles.shape[0] == 0 and image_embeddings.shape[0] > 0:
                assert images_mask.sum() == 0 and getattr(
                    self.vision_model, "_is_fsdp_managed_module", False
                ), "expected FSDP and dummy image"
                final_embedding[:1, :1, :1] += 0 * image_embeddings[:1, :1, :1]
            else:
                final_embedding[images_mask] = (
                    image_embeddings.permute(1, 0, 2).reshape(-1, embed_dim).contiguous()
                )

        # Create the final labels and loss mask (if this is the last language model stage).
        final_labels, final_loss_mask = None, None
        if self.post_process and has_labels:
            final_labels = torch.full(
                (batch_size, max_seq_len), IGNORE_INDEX, dtype=labels.dtype, device=labels.device
            )
            final_loss_mask = torch.full(
                (batch_size, max_seq_len), 0, dtype=loss_mask.dtype, device=loss_mask.device
            )

            # Put text labels and loss mask to the text positions.
            final_labels[label_batch_indices, label_text_position_ids] = labels[
                label_batch_indices, label_non_image_indices
            ]

            final_loss_mask[batch_indices, text_position_ids] = loss_mask[
                batch_indices, non_image_indices
            ]

            # For labels, pick the last label index that got dropped by the shift to left.
            label_extra_text_position_ids = seq_lens - 1
            batch_range = torch.arange(len(label_extra_text_position_ids))
            final_labels[batch_range, label_extra_text_position_ids] = labels[batch_range, -1]

            # Loss mask the image positions.
            final_loss_mask[images_mask] = 0

            # Loss mask last text position just before an image
            # so that text token does not need to predict the first image token.
            batch_image_indices, image_indices = torch.where(image_token_mask)
            # Indices just before image tokens. If it's -1, skip it.
            before_image_indices = image_indices - 1
            valid = before_image_indices >= 0
            valid_batch_image_indices = batch_image_indices[valid]
# TP Comm overlap initializes the user buffer shape used for communication valid_before_image_indices = before_image_indices[valid]
# at the beginning of training run and the same shape is expected to be # Map those indices those position ids.
# used throughout the training. valid_before_image_indices = new_position_ids[
# Pad to language_max_sequence_length to use TP Comm overlap. valid_batch_image_indices, valid_before_image_indices
assert ( ]
self._language_max_sequence_length % padding_factor == 0
), f"TP Comm overlap uses language_max_sequence_length \ final_loss_mask[valid_batch_image_indices, valid_before_image_indices] = 0
which needs to be divisible by SP/CP factor {padding_factor}"
padded_seq_len = self._language_max_sequence_length if final_embedding is not None and final_labels is not None:
assert (
assert ( final_embedding.shape[:2] == final_labels.shape == final_loss_mask.shape
packed_seq_params is not None ), "unexpected shapes after data preprocessing"
), "Please provide PackedSeqParams dict when using SP or CP with padding"
valid_seqlens = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1] if final_embedding is not None:
valid_seq_len = max(valid_seqlens) # Truncate if exceeding the language model's max sequence length.
assert ( if final_embedding.shape[1] > self._language_max_sequence_length:
padded_seq_len >= valid_seq_len final_embedding = final_embedding[:, : self._language_max_sequence_length]
), f"Padded Seq Len calculated for model parallelism: {padded_seq_len} \ # Transpose to [s,b,h] only if not using CP because CP Sharding expects seq in dim=1
is shorter than expected valid token len {valid_seq_len} provided." if self.context_parallel_lm == 1:
final_embedding = final_embedding.transpose(1, 0).contiguous()
mp_padding_needed = padded_seq_len - combined_embeddings.shape[seq_dim]
if mp_padding_needed > 0: truncate_labels = (
new_labels = torch.nn.functional.pad( final_labels is not None and final_labels.shape[1] > self._language_max_sequence_length
new_labels, (0, mp_padding_needed), value=IGNORE_INDEX )
) if truncate_labels:
new_loss_mask = torch.nn.functional.pad(new_loss_mask, (0, mp_padding_needed)) final_labels = final_labels[:, : self._language_max_sequence_length]
if self.context_parallel_lm > 1: final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length]
combined_embeddings = torch.nn.functional.pad(
combined_embeddings, (0, 0, 0, mp_padding_needed) return final_embedding, final_labels, final_loss_mask
)
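    # Illustrative note, not part of the original file: the label bookkeeping above writes the
    # label of a text token that sits at combined-sequence position p into final_labels[:, p - 1]
    # (entries that would land at index -1 are dropped via `label_text_position_ids >= 0`), and
    # the last label of each sample, which the shift would otherwise lose, is re-inserted at
    # position seq_lens - 1.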
    def _process_embedding_token_parallel(
        self, combined_embeddings, new_labels, new_loss_mask, packed_seq_params
    ):
        """Processes the input data for model parallelism support.

        When using sequence parallelism (SP) or context parallelism (CP), the sequence is sharded
        across different GPUs. This function performs the sharding and distributes the sequence
        across GPUs for SP and CP.

        Context Parallelism is a feature that helps improve memory efficiency for
        long sequence training by distributing the sequence across CP ranks.
        It requires the token length to be divisible by (CP size * 2) to ensure proper load balance.

        Sequence Parallelism is a feature that helps improve memory efficiency for
        long sequence training by distributing the sequence across TP ranks.
        It requires the token length to be divisible by TP size.

        Returns:
            combined_embeddings (torch.Tensor): image and text embeddings combined and distributed.
            new_labels (torch.Tensor): Distributed labels for image and text positions.
            new_loss_mask (torch.Tensor): Distributed loss mask.
            packed_seq_params (PackedSeqParams): Dict with padded token information.
        """
        # No pre or post processing needed with PP middle chunks.
        if not self.pre_process and not self.post_process:
            return combined_embeddings, new_labels, new_loss_mask, packed_seq_params

        shard_factor = seq_dim = None
        if self.pre_process:
            if self.context_parallel_lm > 1 and self.sequence_parallel_lm:
                shard_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2
                seq_dim = 1
            elif self.context_parallel_lm > 1:
                shard_factor = self.context_parallel_lm * 2
                seq_dim = 1
            elif self.sequence_parallel_lm:
                shard_factor = self.tensor_model_parallel_size_lm
                seq_dim = 0
            assert (
                combined_embeddings.shape[seq_dim] % shard_factor == 0
            ), f"Sequence length should be divisible by {shard_factor} for \
                Sequence/Context parallelism"
            if self.sequence_parallel_lm and self.tp_comm_overlap_lm:
                assert (
                    combined_embeddings.shape[seq_dim] == self._language_max_sequence_length
                ), "TP Comm overlap requires Vision+Text token length \
                    == language_max_sequence_length"

        if self.context_parallel_lm > 1:
            batch = dict()
            if self.pre_process:
                batch["combined_embeddings"] = combined_embeddings
            if self.post_process:
                batch["new_labels"] = new_labels
                batch["new_loss_mask"] = new_loss_mask
            # Distribute sequence across CP ranks
            if packed_seq_params is None or packed_seq_params.qkv_format == 'sbhd':
                from megatron.training.utils import get_batch_on_this_cp_rank

                batch = get_batch_on_this_cp_rank(batch)
            else:
                assert HAVE_TEX and is_te_min_version(
                    "1.10.0"
                ), "Please update Transformer Engine to >= 1.10 to use \
                    Context Parallel with THD format data"
                batch = _get_data_on_this_cp_rank.apply(batch, packed_seq_params)

            if self.pre_process:
                combined_embeddings = batch["combined_embeddings"]  # [B, S/CP, H]
                combined_embeddings = combined_embeddings.transpose(
                    1, 0
                ).contiguous()  # [B,S/CP,H] -> [S/CP,B,H]
            if self.post_process:
                new_labels = batch["new_labels"]
                new_loss_mask = batch["new_loss_mask"]

        if self.sequence_parallel_lm and self.pre_process:
            combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
                combined_embeddings
            )  # [S/(CP*TP),B,H]

        return combined_embeddings, new_labels, new_loss_mask, packed_seq_params
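    # Illustrative note, not part of the original file: a quick check of the divisibility rule
    # enforced in _process_embedding_token_parallel. With tensor_model_parallel_size_lm = 2,
    # context_parallel_lm = 2 and sequence_parallel_lm enabled, shard_factor = 2 * 2 * 2 = 8,
    # so a combined Vision+Text sequence of 4096 tokens passes the assert (4096 % 8 == 0)
    # while 4100 tokens would fail it.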
    def _apply_tile_tagging(self, image_embeddings, num_image_tiles):
        """Apply tile tagging.

        The image embeddings of multiple tiles are prepended with tile tags such as <tile_1>.
        This implements the method used in NVLM https://arxiv.org/pdf/2409.11402.

        Args:
            image_embeddings (torch.Tensor): [img_seq_len, num_tiles, h_language].
            num_image_tiles (torch.Tensor): Number of tiles for each input image [num_images].

        Returns:
            torch.Tensor: Tile tags prepended to image embeddings.
                [tile_seq_len (=5) + img_seq_len, num_tiles, h_language]
        """
        assert (
            num_image_tiles.shape[0] == 1 and len(num_image_tiles) == 1
        ), "multiple input images are not supported yet."

        num_tiles = num_image_tiles[0].item()
        tile_tags = self._tile_tags[: num_tiles - 1] + [self._tile_tags[-1]]

        # [num_tiles, tile_seq_len (=5)]
        tile_tag_input_ids = torch.tensor(
            tile_tags, dtype=torch.int64, device=num_image_tiles.device
        )

        # [tile_seq_len, num_tiles, h_language]
        tile_tag_embeds = self.language_model.embedding(tile_tag_input_ids, position_ids=None)

        # [num_tiles, dim] should be the same.
        assert tile_tag_embeds.shape[1:] == image_embeddings.shape[1:]

        image_embeddings = torch.cat([tile_tag_embeds, image_embeddings])

        return image_embeddings  # [tile_seq_len + img_seq_len, num_tiles, h_language]
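    # Illustrative note, not part of the original file: how the tile tag list is sliced in
    # _apply_tile_tagging. If self._tile_tags held the token ids of, say,
    # [<tile_1>, <tile_2>, <tile_3>, <tile_4>, <tile_global>] (hypothetical tag names) and
    # num_tiles == 3, then tile_tags = self._tile_tags[:2] + [self._tile_tags[-1]], i.e. the
    # first two per-tile tags plus the final global tag, giving one tag sequence per tile.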
    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        inference_params: Optional[InferenceParams] = None,
        num_image_tiles: Optional[List[int]] = None,
        image_token_index: Optional[int] = None,
        runtime_gather_output: Optional[bool] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
    ) -> torch.Tensor:
        """Forward function of the LLaVA model.

        Args:
            images (torch.Tensor): input images of shape [num_tiles, img_h, img_w].
                num_tiles means the number of image tiles in this batch.
                num_tiles = 0 if the batch doesn't contain images.
            input_ids (torch.Tensor): input text ids [batch, text_seq_len].
            position_ids (torch.Tensor): input text position ids [batch, text_seq_len].
            attention_mask (torch.Tensor): Language model attention mask
                [batch, 1, 1, combined_seq_len]. NOTE: attention_mask is typically None and
                attn_mask_type in layer specs determines the attention mask used.
            labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
            loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len].
            inference_params (InferenceParams): Inference-time parameters including KV cache.
            num_image_tiles (list of int): Number of tiles per image. Default 1 tile per image.
            image_token_index (int): ID for input images. Default None means `image_token_index`
                arg in the constructor will be used.
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `parallel_output` arg in the constructor will be used.
            packed_seq_params (PackedSeqParams): 1) If using sequence packing, must contain
                subsample length information. 2) If using SP/CP with padding mask type,
                must contain padded token information.

        Returns:
            output (torch.Tensor): Loss of shape [b, s] if labels are provided,
                otherwise logits of shape [b, s, vocab_size].
            loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s].
        """
        use_inference_kv_cache = (
            inference_params is not None
            and "image_tokens_count" in inference_params.key_value_memory_dict
        )
        has_images = images is not None and images.shape[0] > 0

        # If running inference, we can skip image token computation
        # if they were computed already earlier for this sample.
        if use_inference_kv_cache:
            image_embeddings = None
        elif self.add_encoder and not has_images:
            # If no images provided, use an empty image embeddings tensor.
            image_embeddings = torch.tensor([], dtype=images.dtype, device=images.device).reshape(
                0, 0, 0
            )
        elif self.add_encoder and has_images:
            image_embeddings = self.vision_model(images)  # [num_tiles, img_seq_len, h_vision]
            if self._drop_vision_class_token:
                image_embeddings = image_embeddings[:, self.vision_model.class_token_len :, :]

            if self._pixel_shuffle:
                image_embeddings = pixel_shuffle(
                    image_embeddings
                )  # [num_tiles, img_seq_len_shuffled, h_vision_shuffled]

            # contiguous() required as `permute` can sparsify the tensor and this breaks pipelining
            image_embeddings = image_embeddings.permute(
                1, 0, 2
            ).contiguous()  # [img_seq_len, num_tiles, h_vision]

            # map vision model output size to language model input size.
            image_embeddings = self.vision_projection(
                image_embeddings
            )  # [img_seq_len, num_tiles, h_language]

            # Apply tile tagging if enabled and an image token is present.
            if self._tile_tags is not None and torch.any(input_ids == self.image_token_index):
                image_embeddings = self._apply_tile_tagging(image_embeddings, num_image_tiles)

            # TODO: Support batched inference.
            # In inference, the language model KV cache will be updated for image token positions.
            # Store the image tokens sequence length to be used as an offset to the KV cache later.
            if inference_params is not None:
                inference_params.key_value_memory_dict["image_tokens_count"] = (
                    image_embeddings.shape[0] * image_embeddings.shape[1]
                )
        else:
            image_embeddings = self.encoder_hidden_state

        if not self.add_decoder:
            return image_embeddings, loss_mask

        language_embeddings = None
        if self.pre_process:
            input_ids_text = input_ids.clone()
            input_ids_text[input_ids_text == self.image_token_index] = 0
            # Note: This adds absolute position embedding but not RoPE.
            # Each image is counted as one position.
            # RoPE is added in language_model forward. Each image embedding is one position.
            language_embeddings = self.language_model.embedding(
                input_ids=input_ids_text, position_ids=position_ids
            )  # [text_seq_len, b, h_language]

            language_embeddings = language_embeddings.transpose(
                1, 0
            ).contiguous()  # [b, text_seq_len, h_language]

        # Assume 1 tile per image if the number of tiles is not provided.
        if num_image_tiles is None and images is not None:
            num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device)

        combined_embeddings, new_labels, new_loss_mask = self._preprocess_data(
            image_embeddings,
            language_embeddings,
            input_ids,
            loss_mask,
            labels,
            use_inference_kv_cache,
            inference_params,
            image_token_index if image_token_index is not None else self.image_token_index,
            num_image_tiles,
        )  # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]

        if self.context_parallel_lm > 1 or self.sequence_parallel_lm:
            combined_embeddings, new_labels, new_loss_mask, packed_seq_params = (
                self._process_embedding_token_parallel(
                    combined_embeddings, new_labels, new_loss_mask, packed_seq_params
                )
            )

        output = self.language_model(
            input_ids=None,
            position_ids=None,
            attention_mask=attention_mask,
            decoder_input=combined_embeddings,
            labels=new_labels,
            inference_params=inference_params,
            runtime_gather_output=runtime_gather_output,
            packed_seq_params=packed_seq_params,
        )

        return output, new_loss_mask
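    # Illustrative call sketch, not part of the original file; shapes follow the docstring above
    # and the snippet is kept commented out because it assumes an already constructed
    # `model: LLaVAModel` plus tokenized inputs:
    #
    #     per_token_loss, combined_loss_mask = model(
    #         images=images,              # [num_tiles, img_h, img_w]
    #         input_ids=input_ids,        # [batch, text_seq_len]
    #         position_ids=position_ids,  # [batch, text_seq_len]
    #         attention_mask=None,        # attn_mask_type from the layer spec is used
    #         labels=labels,              # [batch, combined_seq_len]
    #         loss_mask=loss_mask,        # [batch, text_seq_len]
    #     )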
def _load_state_dict_hook_ignore_param_names(
    param_names: List[str], module: torch.nn.Module, incompatible_keys: namedtuple
):
    """Hook to ignore missing keys during checkpoint loading.

    By default, this should not be used to avoid accidentally missing weights in checkpoint loading.

    Example use case: Use this if you want to load a checkpoint that contains vision and language
    model weights but not the vision projection weights.

    Args:
        param_names (list str): Parameter names allowed to be missing when calling load_state_dict.
        module (torch.nn.Module): The torch module this hook applies to. Required by the torch API.
        incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys,
            which collect the missing and unexpected keys, respectively.
    """
    for param_name in param_names:
        if param_name in incompatible_keys.missing_keys:
            logging.getLogger(__name__).warning(
                f"{param_name} being removed from incompatible_keys.missing_keys in LlavaModel"
            )
            incompatible_keys.missing_keys.remove(param_name)
""" def _load_state_dict_hook_ignore_extra_state(
h = w = int(x.shape[1] ** 0.5) # sq module: torch.nn.Module, incompatible_keys: namedtuple
x = x.reshape(x.shape[0], h, w, -1) # [num_tiles, sq, sq, h_vision] ):
"""Hook to ignore Transformer Engine _extra_state used for FP8.
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale This is for backwards-compatibility. Newer TE versions add _extra_state keys to the state dict,
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) while older models might not have those keys. Those keys can be ignored when not using FP8.
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous() Args:
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) module (torch.nn.Module): The torch module this hook applies to. Required by the torch API.
x = x.view( incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys,
n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)) which collect the missing and unexpected keys, respectively.
) """
for name, keys in incompatible_keys._asdict().items():
if version == 2: for key in keys[::-1]:
x = x.permute(0, 2, 1, 3).contiguous() if "extra_state" in key:
logging.getLogger(__name__).warning(
x = x.reshape(x.shape[0], -1, x.shape[-1]) f"_extra_state key {key} being removed from {name}"
)
return x keys.remove(key)
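# Illustrative usage sketch, not part of the original file: both hooks above match the signature
# expected by torch.nn.Module.register_load_state_dict_post_hook, so they would typically be
# attached before calling load_state_dict, e.g.
#
#     from functools import partial
#     model.register_load_state_dict_post_hook(
#         partial(_load_state_dict_hook_ignore_param_names, ["vision_projection.encoder.weight"])
#     )
#     model.register_load_state_dict_post_hook(_load_state_dict_hook_ignore_extra_state)
#
# The parameter name shown is hypothetical.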
# pylint: disable-next=line-too-long
# Based on https://github.com/OpenGVLab/InternVL/blob/c7c5af1a8930b4862afe8ed14672307082ef61fa/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py#L218
# Copyright (c) 2023 OpenGVLab.
def pixel_shuffle(x, scale_factor=0.5, version=2):
"""Pixel shuffle based on InternVL but adapted for our use case.
Args:
x (torch.Tensor): Vision model outputs [num_tiles, img_seq_len, h_vision]
version (int): Implementation version.
Returns:
Shuffled vision model outputs [num_tiles, (sq ** 2) * (scale ** 2), h_vision / (scale ** 2)]
"""
h = w = int(x.shape[1] ** 0.5) # sq
x = x.reshape(x.shape[0], h, w, -1) # [num_tiles, sq, sq, h_vision]
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(
n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))
)
if version == 2:
x = x.permute(0, 2, 1, 3).contiguous()
x = x.reshape(x.shape[0], -1, x.shape[-1])
return x
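# Illustrative shape check, not part of the original file: with the default scale_factor=0.5,
# pixel_shuffle shrinks the patch sequence by 4x and grows the hidden dimension by 4x.
if __name__ == "__main__":
    _example_tiles = torch.randn(2, 256, 64)  # [num_tiles, img_seq_len=16*16, h_vision]
    _shuffled = pixel_shuffle(_example_tiles, scale_factor=0.5)
    assert _shuffled.shape == (2, 64, 256)  # [num_tiles, img_seq_len / 4, h_vision * 4]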
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional

from megatron.core.extensions.transformer_engine import (
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TENorm,
    TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    LNImpl = WrappedTorchNorm


def decoder_model_with_transformer_engine_default_spec(
    num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
    """LLava decoder TE spec (uses Transformer Engine components)."""
    mlp = get_mlp_module_spec(
        use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
    )
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TELayerNormColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    q_layernorm=TENorm if qk_layernorm else IdentityOp,
                    k_layernorm=TENorm if qk_layernorm else IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def decoder_model_with_local_default_spec(
    num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
    """LLava decoder local spec."""
    mlp = get_mlp_module_spec(
        use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
    )
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=LNImpl,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=LNImpl,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )
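# Illustrative usage sketch, not part of the original file: either helper returns a ModuleSpec
# describing one decoder layer; the local (non-TE) variant can be built and inspected directly.
if __name__ == "__main__":
    _spec = decoder_model_with_local_default_spec()
    print(_spec.module)          # TransformerLayer
    print(_spec.submodules.mlp)  # MLP spec produced by get_mlp_module_spec(use_te=False, ...)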
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional, Union

import torch

from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig

try:
    import transformer_engine  # pylint: disable=unused-import

    from megatron.core.extensions.transformer_engine import TENorm

    NORM_IMPL = TENorm
except:
    NORM_IMPL = torch.nn.LayerNorm


# Note: This is under development and is missing features like position embedding interpolation.
class CLIPViTModel(VisionModule):
    """CLIP ViT vision model.

    Args:
        transformer_config (TransformerConfig): Transformer config.
        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.
        ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre.
        add_class_token (bool, optional): Include a class token. Defaults to True.
        class_token_len (int): Class token length. Defaults to 1 but 8 may be faster.
        patch_dim (int): Image patch size.
        img_h (int): Input image height.
        img_w (int): Input image width.
    """

    def __init__(
        self,
        transformer_config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        ln_pre_impl: Union[ModuleSpec, type] = NORM_IMPL,
        ln_post_impl: Union[ModuleSpec, type] = NORM_IMPL,
        add_class_token: bool = True,
        class_token_len: int = 1,
        patch_dim: int = 14,
        img_h: int = 336,
        img_w: int = 336,
        model_subtype: str = "clip",
    ) -> None:
        error_msg = f"CLIPViTModel model subtype {model_subtype} is not supported."
        assert model_subtype in ["clip", "siglip", "internvit"], error_msg

        if model_subtype == "siglip":
            assert class_token_len == 0, "SigLIP does not support class tokens."
            assert not add_class_token, "SigLIP does not support class tokens."

        super().__init__(config=transformer_config)

        if has_config_logger_enabled(transformer_config):
            log_config_to_disk(transformer_config, locals(), prefix=type(self).__name__)

        self.class_token_len = class_token_len
        self.visual_hidden_size = transformer_config.hidden_size
        self.patch_dim = patch_dim
        self.img_h = img_h
        self.img_w = img_w

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w

        self.add_class_token = add_class_token
        self.class_token_len = class_token_len

        self.seq_length = self.num_patches + (self.class_token_len if self.add_class_token else 0)

        self.ln_pre = None
        self.ln_post = None
        if model_subtype == "clip":
            self.ln_pre = build_module(
                ln_pre_impl,
                config=transformer_config,
                hidden_size=self.visual_hidden_size,
                eps=transformer_config.layernorm_epsilon,
            )
            conv_bias = False
            padding = 0
        elif model_subtype == "siglip":
            self.ln_post = build_module(
                ln_post_impl,
                config=transformer_config,
                hidden_size=self.visual_hidden_size,
                eps=transformer_config.layernorm_epsilon,
            )
            conv_bias = True
            padding = "valid"
        elif model_subtype == "internvit":
            conv_bias = True
            padding = 0
        else:
            raise ValueError(f"unsupported vision model type {model_subtype}")

        self.conv1 = torch.nn.Conv2d(
            in_channels=3,
            out_channels=self.visual_hidden_size,
            kernel_size=self.patch_dim,
            stride=self.patch_dim,
            bias=conv_bias,
            padding=padding,
        )

        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

        self.position_embeddings = torch.nn.Embedding(self.seq_length, self.visual_hidden_size)

        self.add_class_token = add_class_token
        if self.add_class_token:
            self.class_token = torch.nn.Parameter(
                torch.randn(1, self.class_token_len, self.visual_hidden_size)
            )

        self.model_type = ModelType.encoder_or_decoder

        # Transformer layers.
        # TODO: Make pre_process and post_process configurable.
        # NOTE: a final layer norm and/or linear layer in some implementations are omitted here.
        # They can be added separately where needed.
        self.decoder = TransformerBlock(
            config=transformer_config,
            spec=transformer_layer_spec,
            pre_process=True,
            post_process=False,
        )

    def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
        """Sets input tensor to the model.

        Args:
            input_tensor (Tensor): Sets the input tensor for the model.
        """
        self.decoder.set_input_tensor(input_tensor)

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward function of the CLIP ViT Model. This function passes the input tensors
        through the embedding layer and then the transformer.

        Args:
            x (torch.Tensor): input data of shape [batch, img_h, img_w]
            attention_mask (torch.Tensor with dtype=bool): Attention mask to use.

        Returns:
            x (torch.Tensor): output after final transformer block of shape [b, s, h].
        """
        x = self.conv1(x)  # shape = [batch, hidden_size, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [batch, hidden_size, grid ** 2]
        x = x.permute(0, 2, 1)  # [batch, grid ** 2, hidden_size]

        if self.add_class_token:
            class_token = self.class_token.expand(
                x.shape[0], -1, -1
            )  # [batch, class_token_len, hidden_size]
            x = torch.cat(
                [class_token, x], dim=1
            )  # [batch, grid ** 2 + class_token_len, hidden_size]

        assert x.shape[1] == self.seq_length, f"{x.shape[1]} != {self.seq_length}"
        x = x + self.position_embeddings(self.position_ids)
        if self.ln_pre:
            x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # [b, s, h] -> [s, b, h]
        # `permute` can make the tensor non-contiguous, breaking pipelining.
        x = x.contiguous()

        x = self.decoder(x, attention_mask)
        x = x.permute(1, 0, 2)  # [s, b, h] -> [b, s, h]
        x = x.contiguous()
        if self.ln_post:
            x = self.ln_post(x)

        return x


def get_num_image_embeddings(
    img_h,
    img_w,
    patch_dim,
    vision_model_type,
    disable_vision_class_token,
    class_token_len,
    pixel_shuffle=False,
    use_tile_tags=False,
):
    """Get the number of image embeddings per image tile."""
    if vision_model_type == "siglip":
        keep_class_token = False
    elif vision_model_type in ("clip", "internvit"):
        keep_class_token = not disable_vision_class_token
    elif vision_model_type.startswith("radio"):
        keep_class_token = not disable_vision_class_token
    elif vision_model_type.startswith("huggingface"):
        # TODO: Temp, what do we do in this situation?
        keep_class_token = True
    else:
        raise ValueError(f"unsupported vision model: {vision_model_type}")

    num_patches_per_dim_h = img_h // patch_dim
    num_patches_per_dim_w = img_w // patch_dim
    num_patches = num_patches_per_dim_h * num_patches_per_dim_w
    num_image_embeddings_per_tile = num_patches + (class_token_len if keep_class_token else 0)

    if pixel_shuffle:
        num_image_embeddings_per_tile = int(num_image_embeddings_per_tile * (0.5**2))

    if use_tile_tags:
        # The length of tile tags tokenized. Currently, the same across tokenizers used.
        num_image_embeddings_per_tile += 5

    return num_image_embeddings_per_tile
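# Illustrative worked example, not part of the original file: a 336x336 image with patch_dim=14
# yields (336 // 14) ** 2 = 576 patches per tile; with the class token dropped, pixel shuffle
# scales that by 0.5 ** 2 to 144 embeddings, and tile tags add 5 more tokens per tile.
if __name__ == "__main__":
    assert (
        get_num_image_embeddings(
            img_h=336,
            img_w=336,
            patch_dim=14,
            vision_model_type="clip",
            disable_vision_class_token=True,
            class_token_len=1,
            pixel_shuffle=True,
            use_tile_tags=True,
        )
        == 149
    )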
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
# RADIO reference code: https://github.com/NVlabs/RADIO
class RADIOViTModel(VisionModule):
"""RADIO ViT vision model.
Args:
transformer_config (TransformerConfig): Transformer config.
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.
ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre.
ln_post_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_post.
use_mask_token (bool, optional): Whether to use RADIO mask token. Defaults to False.
add_class_token (bool, optional): Include a class token. Defaults to True.
class_token_len (int): Class token length. Defaults to 1 but 8 may be faster.
patch_dim (int): Image patch size.
img_h (int): Input image height.
img_w (int): Input image width.
max_img_h (int): Max input image height.
max_img_w (int): Max input image width.
pos_dropout (int): Positional encoding dropout value. Defaults to 0.
has_cpe: (bool): Whether to use conditional positional encoding. Defaults to True.
embedder_bias: (bool): Bias in embedder linear. Defaults to False.
"""
def __init__(
self,
transformer_config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
ln_pre_impl: Union[ModuleSpec, type] = None,
ln_post_impl: Union[ModuleSpec, type] = None,
use_mask_token: bool = False,
add_class_token: bool = True,
class_token_len: int = 8,
patch_dim: int = 16,
img_h: int = 224,
img_w: int = 224,
max_img_h: int = 2048,
max_img_w: int = 2048,
pos_dropout: int = 0,
has_cpe: bool = True,
embedder_bias: bool = False,
) -> None:
super().__init__(config=transformer_config)
if has_config_logger_enabled(transformer_config):
log_config_to_disk(transformer_config, locals(), prefix=type(self).__name__)
self.class_token_len = class_token_len
self.visual_hidden_size = transformer_config.hidden_size
self.patch_dim = patch_dim
self.img_h = img_h
self.img_w = img_w
assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
self.input_dims = (img_h // patch_dim, img_w // patch_dim)
# used for positional embedding
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_num_rows = max_img_h // patch_dim
self.max_num_cols = max_img_w // patch_dim
self.max_num_patches = self.max_num_rows * self.max_num_cols
# TODO: are we actually going to use this anywhere?
self.use_mask_token = use_mask_token
if self.use_mask_token:
self.mask_token = nn.Parameter(torch.zeros(1, self.visual_hidden_size))
self.add_class_token = add_class_token
self.class_token_len = class_token_len
if self.add_class_token:
self.class_token = nn.Parameter(
torch.randn(self.class_token_len, self.visual_hidden_size)
)
self.seq_length = (img_h // self.patch_dim) * (img_w // self.patch_dim) + (
self.class_token_len if self.add_class_token else 0
)
pos_scale = self.visual_hidden_size**-0.5
self.position_embeddings = nn.Parameter(
torch.randn(1, self.max_num_patches, self.visual_hidden_size) * pos_scale
)
self.pos_dropout = pos_dropout
self.has_cpe = has_cpe
# Using non-TE version so we can force gather_output
self.embedder = ColumnParallelLinear(
input_size=3 * self.patch_dim * self.patch_dim,
output_size=self.visual_hidden_size,
bias=embedder_bias,
config=transformer_config,
gather_output=True,
init_method=lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=1.0),
)
self.model_type = ModelType.encoder_or_decoder
self.ln_pre = None
self.ln_post = None
if ln_pre_impl is not None:
self.ln_pre = build_module(
ln_pre_impl,
config=transformer_config,
hidden_size=self.visual_hidden_size,
eps=transformer_config.layernorm_epsilon,
)
if ln_post_impl is not None:
self.ln_post = build_module(
ln_post_impl,
config=transformer_config,
hidden_size=self.visual_hidden_size,
eps=transformer_config.layernorm_epsilon,
)
self.decoder = TransformerBlock(
config=transformer_config,
spec=transformer_layer_spec,
pre_process=True,
post_process=False,
)
def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
"""Sets input tensor to the model.
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
self.decoder.set_input_tensor(input_tensor)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward function of the RADIO ViT Model. This function passes the input tensors
through the embedding layer and then the transformer.
Args:
x (torch.Tensor): input data of shape [batch, img_h, img_w]
attention_mask (torch.Tensor with dtype=bool): Attention mask to use.
Returns:
x (torch.Tensor): output after final transformer block of shape [b, s, h].
"""
input_size = x.shape[2:]
py = x.shape[-2] // self.patch_dim
px = x.shape[-1] // self.patch_dim
x = rearrange(
x,
'b c (py yy) (px xx) -> b (py px) (c yy xx)',
py=py,
yy=self.patch_dim,
px=px,
xx=self.patch_dim,
)
x, _ = self.embedder(x) # [batch, seq_length, hidden_size]
x, _ = self.apply_pos_enc(x, input_size=input_size)
if self.add_class_token:
class_token = self.class_token.expand(
x.shape[0], -1, -1
) # [batch, class_token_len, hidden_size]
x = torch.cat(
[class_token, x], dim=1
) # [batch, seq_length + class_token_len, hidden_size]
assert x.shape[1] == self.seq_length, f"{x.shape[1]} != {self.seq_length}"
if self.ln_pre:
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # [b, s, h] -> [s, b, h]
x = x.contiguous()
x = self.decoder(x, attention_mask=attention_mask)
x = x.permute(1, 0, 2) # [s, b, h] -> [b, s, h]
x = x.contiguous()
if self.ln_post:
x = self.ln_post(x)
return x
def apply_pos_enc(
self,
patches: torch.Tensor,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply positional encoding to patches"""
pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
if self.training and self.pos_dropout > 0:
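# Positional-encoding dropout: with probability pos_dropout, zero out the
# positional encoding for an entire sample (all of its patches at once).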
keeps = (
torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device)
> self.pos_dropout
)
pos_enc_drop = torch.where(keeps, pos_enc, 0)
else:
pos_enc_drop = pos_enc
return patches + pos_enc_drop, pos_enc
def get_pos_enc(
self,
batch_size: int,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
"""Get positional encoding for certain input size"""
if input_size is None:
input_dims = self.input_dims
else:
input_dims = tuple(d // self.patch_dim for d in input_size)
pos_embed = self._get_pos_embeddings(batch_size, input_dims)
if patch_idxs is None:
return pos_embed
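# Only a subset of patches is used: gather their positional embeddings by index.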
exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])
pos_embed = torch.gather(
pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
)
return pos_embed
def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):
"""Get RADIO absolute positional embeddings"""
if (self.max_num_rows, self.max_num_cols) == input_dims:
return self.position_embeddings
pos_embed = self.position_embeddings.reshape(
1, self.max_num_rows, self.max_num_cols, -1
).permute(0, 3, 1, 2)
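# pos_embed is now [1, hidden, max_num_rows, max_num_cols]; window_select crops it
# to the current patch grid when the input is smaller than the maximum.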
def window_select(pos_embed):
if input_dims[0] < pos_embed.shape[-2]:
pos_embed = pos_embed[..., : input_dims[0], :]
if input_dims[1] < pos_embed.shape[-1]:
pos_embed = pos_embed[..., :, : input_dims[1]]
return pos_embed
if self.has_cpe:
if self.training:
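# CPE training augmentation: sample a random scale and aspect ratio per sample,
# build a crop window over the embedding grid, and resample it with grid_sample.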
min_scale = math.sqrt(0.1)
scale = (
torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale)
+ min_scale
)
aspect_min = math.log(3 / 4)
aspect_max = -aspect_min
aspect = torch.exp(
torch.rand(batch_size, 1, 1, device=pos_embed.device)
* (aspect_max - aspect_min)
+ aspect_min
)
scale_x = scale * aspect
scale_y = scale * (1 / aspect)
scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)
lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[
None, None
].expand(batch_size, input_dims[0], -1)
lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[
None, :, None
].expand(batch_size, -1, input_dims[1])
lin_xy = torch.stack([lin_x, lin_y], dim=-1)
grid_xy = lin_xy * scale_xy + pos_xy
# Convert to [-1, 1] range
grid_xy.mul_(2).sub_(1)
pos_embed = F.grid_sample(
pos_embed.float().expand(batch_size, -1, -1, -1),
grid=grid_xy,
mode='bilinear',
padding_mode='zeros',
align_corners=True,
).to(pos_embed.dtype)
else:
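# At inference, interpolate the embedding grid to a square at the larger input
# dimension, then crop it to the requested grid.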
max_dim = max(input_dims)
pos_embed = F.interpolate(
pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear'
).to(pos_embed.dtype)
pos_embed = window_select(pos_embed)
else:
pos_embed = window_select(pos_embed)
if pos_embed.shape[-2:] != input_dims:
pos_embed = F.interpolate(
pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear'
).to(pos_embed.dtype)
pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
return pos_embed
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from megatron.core.extensions.transformer_engine import (
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    LNImpl = WrappedTorchNorm


# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_vit_layer_with_transformer_engine_spec() -> ModuleSpec:
    '''
    Returns ViT layer spec with Transformer Engine layers
    '''
    mlp = _get_mlp_module_spec(use_te=True)
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.no_mask},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TELayerNormColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=IdentityOp,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_vit_layer_with_local_spec() -> ModuleSpec:
    '''
    Returns ViT layer spec with Mcore local layers
    '''
    mlp = _get_mlp_module_spec(use_te=False)
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=LNImpl,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=LNImpl,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


# Helper function to get module spec for MLP/MoE
def _get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
    # Dense MLP w/ or w/o TE modules.
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
            linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
        ),
    )
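# Usage sketch (illustrative, not part of the original file). A ModuleSpec is only a
# declarative description; the modules it names are instantiated later, e.g. when a
# TransformerBlock is built from the spec (as in the RADIO ViT model above).
if __name__ == "__main__":
    layer_spec = get_vit_layer_with_local_spec()
    print(layer_spec.module)          # TransformerLayer class
    print(layer_spec.submodules.mlp)  # ModuleSpec describing the dense MLP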