Commit d444a97a authored by yangzhong's avatar yangzhong
Browse files

首次上传

parents
Pipeline #3020 canceled with stages
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_idx):
"swap between batches"
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
for layer_number in self.key_value_memory_dict.keys():
inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
assert (
len(batch_idx) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_idx]
new_inference_value_memory = inference_value_memory[:, batch_idx]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
)
def __str__(self):
return f"InferenceParams(max_seq_len = {self.max_sequence_length}, max_batch_size = {self.max_batch_size}, sequence_len_offset = {self.sequence_len_offset}, batch_size_offset = {self.batch_size_offset}, key_value_memory_dict = {self.key_value_memory_dict.keys()})"
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.utils import is_torch_min_version
jit_fuser = torch.jit.script
# nvFuser is deprecated in PyTorch JIT starting from 2.2
if is_torch_min_version("2.2.0a0"):
jit_fuser = torch.compile(mode='max-autotune-no-cudagraphs')
# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo
if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
f, recursive=recursive
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable, ContextManager, Optional
import torch
@dataclass
class ModelParallelConfig:
"""Base configuration for Megatron Core
The initialization function has an argument for each parameter.
"""
###################
# Model parallelism
###################
tensor_model_parallel_size: int = 1
"""Intra-layer model parallelism. Splits tensors across GPU ranks."""
pipeline_model_parallel_size: int = 1
"""Inter-layer model parallelism. Splits transformer layers across GPU ranks."""
virtual_pipeline_model_parallel_size: Optional[int] = None
"""Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
arxiv.org/pdf/2104.04473.pdf for more details.
"""
sequence_parallel: bool = False
"""Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models
(https://arxiv.org/abs/2205.05198) for more details.
"""
context_parallel_size: int = 1
"""Splits network input along sequence dimension across GPU ranks."""
hierarchical_context_parallel_sizes: Optional[list[int]] = None
"""Degrees of the hierarchical context parallelism. Users should provide a list to specify
the sizes for different levels. Taking the a2a+p2p cp comm type as example, it contains
groups of two levels, so the first value of the list indicates the group size of the a2a
communication type, and the second value indicates the group size of the p2p communication
type.
"""
expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""
expert_tensor_parallel_size: Optional[int] = None
"""Intra-layer tensor model parallelsm for expert layer. Splits tensors across GPU ranks."""
moe_extended_tp: bool = False
"""NOTE: Deprecated from MCore v0.10. This flag is ignored.
Its functionality is replaced by expert_tensor_parallel_size.
"""
###################
# Initialization
###################
perform_initialization: bool = True
"""If true, weights are initialized. This option can be useful when you know you are going to
load values from a checkpoint.
"""
use_cpu_initialization: bool = False
"""When set to False, we initialize the weights directly on the GPU. CPU initialization is the
same regardless of tensor model parallelism, but GPU initialization is not. Transferring
weights from CPU to GPU can take a significant amount of time for large models.
"""
###################
# Training
###################
fp16: bool = False
"""If true, train with fp16 mixed precision training."""
bf16: bool = False
"""If true, train with bf16 mixed precision training."""
params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights."""
timers: Optional[Callable] = None
"""Timers object to call for various timing functions. See megatron.core.timers.Timers"""
finalize_model_grads_func: Optional[Callable] = None
"""Function that finalizes gradients on all workers. Could include ensuring that grads are
all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
dimensions.
"""
grad_scale_func: Optional[Callable] = None
"""If using loss scaling, this function should take the loss and return the scaled loss. If
None, no function is called on the loss.
"""
no_sync_func: Optional[Callable] = None
"""Function that creates a context that suppresses asynchronous data-parallel communication. If
the model is an instance of core.distributed.DistributedDataParallel, the default is to use
core.distributed.DistributedDataParallel.no_sync.
"""
grad_sync_func: Optional[Callable] = None
"""Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
reduce-scatters). The function should take one argument: an iterable of parameters whose
gradients are to be synchronized.
"""
param_sync_func: Optional[Callable] = None
"""Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
parameter all-gathers). The function should take one argument: an iterable of parameters to
be synchronized.
"""
deterministic_mode: bool = False
"""If true, code that has deterministic execution will be chosen. This usually
means slower execution, but is good for debugging and testing. Defaults to False."""
enable_autocast: bool = False
"""If true runs the forward step function inside torch.autocast context."""
autocast_dtype: Optional[torch.dtype] = None
"""dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype."""
num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
"""If int, set the number of microbatches where not all of the layers will be checkpointed and
recomputed. The rest of the microbatches within the window of maximum outstanding
microbatches will recompute all layers (either full recompute or selective recompute). If
None, the checkpoint and recompute will be left up to the forward_step function.
"""
###################
# Optimizations
###################
gradient_accumulation_fusion: bool = False
"""If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\"
--global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion.
"""
async_tensor_model_parallel_allreduce: bool = False
"""NOTE: Deprecated. This flag is ignored."""
use_te_rng_tracker: bool = False
"""If true, uses RNG state tracker in TransformerEngine if exists.
"""
tp_comm_overlap: bool = False
"""If true, allows overlapping of Linear layer execution with tensor parallel communication
collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
possible during the forward and the backward pass.
"""
tp_comm_bulk_wgrad: bool = True
"""If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if
tp_comm_overlap is False.
"""
tp_comm_bulk_dgrad: bool = True
"""If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if
tp_comm_overlap is False.
"""
tp_comm_overlap_ag: bool = True
"""If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.
Don't care if tp_comm_overlap is False.
"""
tp_comm_overlap_rs: bool = True
"""If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.
Don't care if tp_comm_overlap is False.
"""
tp_comm_overlap_rs_dgrad: bool = False
"""If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the
GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
"""
tp_comm_split_ag: bool = True
"""Deprecated from TransformerEngine v1.6.0.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
splits. Don't care if tp_comm_overlap is False.
"""
tp_comm_atomic_ag: bool = False
"""Deprecated from TransformerEngine v1.6.0.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
both done atomically. Don't care if tp_comm_overlap is False.
"""
tp_comm_split_rs: bool = True
"""Deprecated from TransformerEngine v1.6.0.
If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
"""
tp_comm_atomic_rs: bool = False
"""Deprecated from TransformerEngine v1.6.0.
If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False.
"""
cross_entropy_loss_fusion: bool = False
"""If this is enabled, the fused cross entropy implementation would be used.
Defaults to False.
"""
tp_comm_overlap_disable_qkv: bool = False
"""
If true, the AllGather -> Gemm overlap for QKV gets disabled
"""
tp_comm_overlap_disable_fc1: bool = False
"""
If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled
"""
tp_comm_bootstrap_backend: str = 'nccl'
"""
Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'
"""
###################
# Pipeline Parallel
###################
pipeline_dtype: torch.dtype = None
"""dtype used in p2p communication, usually params_dtype"""
variable_seq_lengths: bool = False
"""Support for variable sequence lengths across microbatches. Setting this communicates the size
of tensors during pipeline parallelism communication, because of this extra overhead it
should only be set if the sequence length varies by microbatch within a global batch.
"""
overlap_p2p_comm: bool = False
"""When True some of the peer to peer communication for pipeline parallelism will overlap with
computation. Must be False if batch_p2p_comm is true.
"""
batch_p2p_comm: bool = True
"""Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
overlap_p2p_comm is True.
"""
batch_p2p_sync: bool = True
"""When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
older version of PyTorch.
"""
use_ring_exchange_p2p: bool = False
"""Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
custom built torch with torch.distributed.ring_exchange.
"""
deallocate_pipeline_outputs: bool = False
"""If True, output data is deallocated after the tensor is sent to the next pipeline stage.
Helps with saving memory, does nothing when pipeline parallel is not used.
"""
defer_embedding_wgrad_compute: bool = False
"""If true, defers the embedding WGRAD GEMMs while pipeline flush is
taking place enabling us to hide pipeline flush latency. Defaults to False.
"""
wgrad_deferral_limit: int = 0
"""This value tunes the number of micro-batches for which the embedding weight gradient compute
needs to be deferred to pipeline flush, this argument is invalid if
`defer_embedding_wgrad_compute` is False.
Defaults to 0, which means all micro-batches are deferred.
"""
pipeline_model_parallel_split_rank: Optional[int] = None
"""If int, rank where encoder and decoder should be split in cases where the model has both an
encoder and decoder (e.g., T5). Ignored if None.
"""
overlap_p2p_comm_warmup_flush: bool = False
"""If true, overlap communication and computation in warm up and flush phase.
Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.
Defaults to False.
"""
microbatch_group_size_per_vp_stage: Optional[int] = None
"""This value specifies the number of micro-batches that are executed
at a time for a given virtual stage (both forward and backward).
Default (in __post_init__() method below) to pipeline_parallel_size
which specifies a depth-first schedule.
Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2,
num_microbatches = 4, we have
rank 0 | 0 1 0 1 2 3 2 3
rank 1 | 0 1 0 1 2 3 2 3
When microbatch_group_size_per_vp_stage=3, num_microbatches = 5,
we have
rank 0 | 0 1 2 0 1 2 3 4 3 4
rank 1 | 0 1 2 0 1 2 3 4 3 4
"""
###################
# CPU Offloading
###################
cpu_offloading: bool = False
"""When set to True, all the activations are offloaded to the CPU asynchronously."""
cpu_offloading_num_layers: int = 0
"""Tells the number of transformer layers for which activations has to be offloaded."""
_cpu_offloading_context: Optional[ContextManager] = (
None
# Used for internal use only, not to be set by a user.
# TODO: Need to move to the 'right' place when possible.
)
"""For internal use only, do not set."""
cpu_offloading_activations: bool = True
"""If True, offloads the activations to CPU."""
cpu_offloading_weights: bool = True
"""If True, offloads the weights to CPU."""
###################
# Timing
###################
barrier_with_L1_time: bool = True
"""If true, use barrier with level 1 time measurements. It is up to the user to make sure
calling barrier with their timers will not result in hangs. This can happen if for example
the user adds a level 1 timer that is not called by all ranks.
"""
def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
details.
"""
if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
raise ValueError("Can not use sequence paralllelism without tensor parallelism")
if self.expert_tensor_parallel_size is None:
self.expert_tensor_parallel_size = self.tensor_model_parallel_size
if self.pipeline_model_parallel_size > 1:
if self.pipeline_dtype is None:
raise ValueError(
"When using pipeline parallelism, pipeline_dtype must be specified"
)
if self.autocast_dtype is None:
self.autocast_dtype = self.params_dtype
if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1:
raise ValueError(
"Cannot defer embedding wgrad compute when pipeline model parallel is not used"
)
if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion:
raise ValueError(
"Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
)
if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0:
raise ValueError(
"Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
)
if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
if self.sequence_parallel is False:
raise ValueError(
"When using expert parallelism and tensor parallelism, "
"sequence parallelism must be used"
)
if self.microbatch_group_size_per_vp_stage is None:
self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size
if self.overlap_p2p_comm_warmup_flush:
if not self.overlap_p2p_comm or self.batch_p2p_comm:
raise ValueError(
"Pipeline parallel communication overlapping in warmup and flush is only "
"compatible with overlap_p2p_comm but not batch_p2p_comm."
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .t5_model import T5Model
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import List, Literal, Optional, Tuple
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
class T5LMHead(MegatronModule):
"""Masked LM head for T5
Args:
config (TransformerConfig): transformer config
parallel_output (bool): wether output logits being distributed or not.
vocab_size (int): vocabulary size
pre_process (bool): Include embedding layer
share_embeddings_and_output_weights (bool): When True, input
embeddings and output logit weights are shared.
"""
def __init__(
self,
config: TransformerConfig,
parallel_output: bool,
vocab_size: int,
pre_process: bool = True,
share_embeddings_and_output_weights: bool = False,
):
super(T5LMHead, self).__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.parallel_output = parallel_output
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
vocab_size,
config=config,
init_method=config.init_method,
bias=share_embeddings_and_output_weights,
skip_bias_add=not share_embeddings_and_output_weights,
gather_output=not self.parallel_output,
skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
)
def forward(self, hidden_states: Tensor, word_embeddings_weight: Tensor) -> Tensor:
"""Forward pass.
Args:
hidden_states (Tensor): output hidden states from decoder
word_embeddings_weight (Tensor): word embedding weight
Returns:
Tensor: logits tensor
"""
logits, _ = self.output_layer(hidden_states, weight=word_embeddings_weight)
return logits
class T5Model(LanguageModule):
"""T5 Language model.
Args:
config (TransformerConfig): transformer config
encoder_config (TransformerConfig): encoder transformer config
transformer_encoder_layer_spec (ModuleSpec): transformer layer
customization specs for encoder
transformer_decoder_layer_spec (ModuleSpec): transformer layer
customization specs for decoder
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool): Include embedding layer (used with pipeline parallelism)
post_process (bool): Include an output layer (used with pipeline parallelism)
fp16_lm_cross_entropy (bool, optional): Defaults to False
parallel_output (bool): Do not gather the outputs,
keep them split across tensor parallel ranks
share_embeddings_and_output_weights (bool): When True,
input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (string): Position embedding type.
Options ['learned_absolute', 'rope'].
Defaults is 'learned_absolute'.
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
seq_len_interpolation_factor (float): scale of linearly interpolating
RoPE for longer sequences. The value must be a float larger than 1.0.
Defaults to None.
add_encoder (bool): Create the encoder (used with pipeline parallelism).
When using pipelining, the encoder will only be created on a subset
of the pipeline ranks.
add_decoder (bool): Include an output layer (used with pipeline parallelism).
As with `add_encoder`, when using this model and pipelining,
the decoder will only be created on a subset of the pipeline ranks.
"""
def __init__(
self,
config: TransformerConfig,
encoder_config: TransformerConfig,
transformer_encoder_layer_spec: ModuleSpec,
transformer_decoder_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[float] = None,
add_encoder: bool = True,
add_decoder: bool = True,
):
super(T5Model, self).__init__(config=config)
self.config: TransformerConfig = config
self.encoder_config: TransformerConfig = encoder_config
self.transformer_encoder_layer_spec: ModuleSpec = transformer_encoder_layer_spec
self.transformer_decoder_layer_spec: ModuleSpec = transformer_decoder_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = add_encoder
self.add_decoder = add_decoder
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
self.encoder_hidden_state = None
self.model_type = ModelType.encoder_and_decoder
# Tells schedules.py that this model has a skip connection
# between the encoder's output and the decoder
# (and hence both the encoder and decoder's tensors are required for correct backprop).
self.xattn_needed = True
# specify the position embeddings as a member
# variable in the T5 class so that they are easy to
# find for `finalize_model_grads._allreduce_position_embedding_grads`
self.position_embeddings = None
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=self.position_embedding_type,
)
if position_embedding_type == "learned_absolute":
self.position_embeddings = self.embedding.position_embeddings
else:
self.position_embeddings = None
# Rotary Position Embeddings
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Transformer encoder
encoder_spec, decoder_spec = (
self.transformer_encoder_layer_spec,
self.transformer_decoder_layer_spec,
)
if self.add_encoder:
self.encoder = TransformerBlock(
config=self.encoder_config,
spec=encoder_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
else:
self.encoder = None
if self.add_decoder:
# Transformer decoder
self.decoder = TransformerBlock(
config=self.config,
spec=decoder_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
else:
self.decoder = None
# Output
if post_process:
self.lm_head = T5LMHead(
config,
parallel_output,
self.vocab_size,
self.pre_process,
self.share_embeddings_and_output_weights,
)
self.output_layer = self.lm_head.output_layer
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
def forward(
self,
encoder_input_ids: Tensor,
decoder_input_ids: Tensor,
encoder_attn_mask: Tensor,
decoder_attn_mask: Tensor,
encoder_decoder_attn_mask: Tensor,
lm_labels: Tensor = None,
encoder_hidden_states: Tensor = None,
output_encoder_hidden_only: bool = False,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
) -> Tensor:
"""Forward pass.
Args:
encoder_input_ids (Tensor): input ids for encoder
decoder_input_ids (Tensor): input ids for decoder
encoder_attn_mask (Tensor): self-attention mask for encoder
decoder_attn_mask (Tensor): self-attention mask for decoder
encoder_decoder_attn_mask (Tensor): cross-attention mask between encoder and decoder
lm_labels (Tensor): labels for decoder output
inference_params (InferenceParams): relevant arguments for inferencing
Returns:
Tensor: loss tensor
"""
## Encoder forward
if encoder_hidden_states is None:
# Encoder position ids
encoder_position_ids = t5_position_ids(encoder_input_ids)
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(
input_ids=encoder_input_ids, position_ids=encoder_position_ids
)
else:
# intermediate stage of pipeline
encoder_input = None
# Rotary positional embeddings
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.encoder, encoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run encoder.
if self.add_encoder:
encoder_hidden_states = self.encoder(
hidden_states=encoder_input,
attention_mask=encoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
else:
encoder_hidden_states = self.encoder_hidden_state
if not self.add_decoder or output_encoder_hidden_only:
return encoder_hidden_states
## Decoder forward
# Decoder position ids
decoder_position_ids = t5_position_ids(decoder_input_ids)
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(
input_ids=decoder_input_ids, position_ids=decoder_position_ids
)
else:
# intermediate stage of pipeline
decoder_input = None ### should it take encoder_hidden_states
# Rotary positional embeddings
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.encoder, encoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run decoder.
decoder_hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=decoder_attn_mask,
context=encoder_hidden_states,
context_mask=encoder_decoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
if self.post_process:
lm_logits = self.lm_head(
decoder_hidden_states, self.shared_embedding_or_output_weight()
)
if lm_labels is None:
# [s b h] => [b s h]
return lm_logits.transpose(0, 1).contiguous()
else:
# [b s] => [s b]
lm_loss = self.compute_language_model_loss(lm_labels, lm_logits)
return lm_loss
else:
return decoder_hidden_states
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert (
len(input_tensor) == 1
), 'input_tensor should only be length 1 for stage with both encoder and decoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert (
len(input_tensor) == 1
), 'input_tensor should only be length 1 for stage with only encoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception('input_tensor must have either length 1 or 2')
else:
raise Exception('Stage must have at least either encoder or decoder')
def shared_embedding_or_output_weight(self) -> Tensor:
"""Function to share the input embeddings and output logit weights."""
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.lm_head.output_layer.weight
return None
def sharded_state_dict(
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
"""Sharded state dict implementation handling duplication of encoder and decoder layers.
Some layers (output, embedding) are shared between the encoder and decoder.
This method sets the replica_id for them to ensure there is only one
layer instance with replica_id (0, 0, 0).
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the T5Model
"""
sharded_sd = super().sharded_state_dict(prefix, sharded_offsets, metadata)
if not parallel_state.is_inside_encoder():
for k, sh_ten in sharded_sd.items():
if not k.startswith(f'{prefix}decoder'):
# Bump replica_id of all the layers shared with the encoder (output, embedding)
sh_ten.replica_id = (sh_ten.replica_id[0] + 1, *sh_ten.replica_id[1:])
return sharded_sd
def t5_extended_attention_mask(attention_mask_list: List[Tensor]) -> List[Tensor]:
"""Creates the extended attention mask
Converts the attention mask of dimension [batch size, seq_len, seq_len]
to [batch size, 1, seq_len, seq_len]
Args:
attention_mask (Tensor): The input attention mask
Returns:
Tensor: The extended binary attention mask
"""
def attn_mask_postprocess(attn_mask):
# [b, 1, s, s]
extended_attention_mask = attn_mask.unsqueeze(1)
return extended_attention_mask
return [
(attn_mask_postprocess(attn_mask) if attn_mask is not None else None)
for attn_mask in attention_mask_list
]
def t5_position_ids(token_ids: Tensor) -> Tensor:
"""Calculate position ids from token ids
Args:
token_ids (Tensor): input tokens
Returns:
Tensor: position ids
"""
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import (
CrossAttention,
CrossAttentionSubmodules,
SelfAttention,
SelfAttentionSubmodules,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
"""T5 encoder TE spec (uses Transformer Engine components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
),
)
def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
"""T5 decoder TE spec (uses Transformer Engine components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_cross_attn_layernorm=TENorm,
cross_attention=ModuleSpec(
module=CrossAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
cross_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
),
)
def encoder_model_with_local_spec() -> ModuleSpec:
"""T5 encoder local spec (uses Megatron-Core components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.arbitrary},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def decoder_model_with_local_spec() -> ModuleSpec:
"""T5 decoder local spec (uses Megatron-Core components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_cross_attn_layernorm=LNImpl,
cross_attention=ModuleSpec(
module=CrossAttention,
params={"attn_mask_type": AttnMaskType.arbitrary},
submodules=CrossAttentionSubmodules(
linear_q=ColumnParallelLinear,
linear_kv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
),
cross_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def get_t5_encoder_with_transformer_engine_block_spec(
num_layers: int,
) -> TransformerBlockSubmodules:
"""T5 encoder block spec for Transformer Engine
Args:
config (TransformerConfig): config, containing number of layers for encoder
"""
layer_spec = encoder_model_with_transformer_engine_default_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm)
return block_spec
def get_t5_decoder_with_transformer_engine_block_spec(
num_layers: int,
) -> TransformerBlockSubmodules:
"""T5 decoder block spec for Transformer Engine
Args:
config (TransformerConfig): config, containing number of layers for decoder
"""
layer_spec = decoder_model_with_transformer_engine_default_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm)
return block_spec
def get_t5_encoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules:
"""T5 encoder block spec for local (uses Megatron-Core components)
Args:
num_layers (int): number of encoder layers
"""
layer_spec = encoder_model_with_local_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm)
return block_spec
def get_t5_decoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules:
"""T5 decoder block spec for local (uses Megatron-Core components)
Args:
num_layers (int): number of decoder layers
"""
layer_spec = decoder_model_with_local_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm)
return block_spec
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
try:
from megatron.core.extensions.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_bert_layer_with_transformer_engine_spec():
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
Returns:
ModuleSpec: Module specification with TE modules
"""
if not HAVE_TE:
raise ImportError(
"Transformer Engine is not installed. Please use local Bert layer spec instead."
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
),
)
def __getattr__(name):
if name == 'bert_layer_with_transformer_engine_spec':
warnings.warn(
"""Attribute bert_layer_specs.bert_layer_with_transformer_engine_spec is on a
deprecation track and will be removed in future releases. Please migrate to
bert_layer_specs.get_bert_layer_with_transformer_engine_spec()."""
)
return get_bert_layer_with_transformer_engine_spec()
# Use this spec for an implementation using only modules in megatron core
bert_layer_local_spec = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
from torch import Tensor
from megatron.core.fusions.fused_layer_norm import HAVE_FUSED_LAYER_NORM, FusedLayerNorm
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import get_linear_layer
if HAVE_FUSED_LAYER_NORM:
LNImpl = FusedLayerNorm
else:
import warnings
warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
from megatron.core.transformer.torch_norm import WrappedTorchNorm as LNImpl
class BertLMHead(MegatronModule):
"""Masked LM head for Bert.
Args:
hidden_size: hidden size
config (TransformerConfig): TransformerConfig object
"""
def __init__(self, hidden_size: int, config: TransformerConfig):
super().__init__(config=config)
# TODO: Should switch this to TE ?
self.dense = get_linear_layer(
hidden_size, hidden_size, config.init_method, config.perform_initialization
)
setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)
self.layer_norm = LNImpl(
config=config, hidden_size=hidden_size, eps=config.layernorm_epsilon
)
self.gelu = torch.nn.functional.gelu
def forward(self, hidden_states: Tensor) -> Tensor:
"""forward pass"""
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import warnings
from typing import Literal, Optional
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.bert.bert_lm_head import BertLMHead
from megatron.core.models.bert.pooler import Pooler
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.dot_product_attention import (
DotProductAttention as MCoreDotProductAttention,
)
from megatron.core.transformer.enums import AttnBackend, AttnMaskType, ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import get_linear_layer
from megatron.core.utils import get_te_version as _get_te_version
from megatron.core.utils import is_te_min_version
def get_te_version():
"""Included for backwards compatibility."""
warnings.warn("`get_te_version` will be deprecated in a future release")
return _get_te_version()
class BertModel(LanguageModule):
"""Transformer language model.
Args:
config (TransformerConfig): transformer config
num_tokentypes (int) : Set to 2 when args.bert_binary_head is True, and 0 otherwise.
Defaults to 0.
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool): Include embedding layer (used with pipeline parallelism)
post_process (bool): Include an output layer (used with pipeline parallelism)
parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel
ranks
share_embeddings_and_output_weights (bool): When True, input embeddings and output logit
weights are shared. Defaults to False.
position_embedding_type (string): Position embedding type.
Options ['learned_absolute', 'rope']. Defaults is 'learned_absolute'.
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
"""
def __init__(
self,
config: TransformerConfig,
num_tokentypes: int,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[float] = None,
add_binary_head=True,
return_embeddings=False,
):
super(BertModel, self).__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
if return_embeddings:
assert self.post_process and self.add_binary_head
self.config: TransformerConfig = config
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
self.add_binary_head = add_binary_head
self.return_embeddings = return_embeddings
# megatron core pipelining currently depends on model type
self.model_type = ModelType.encoder_or_decoder
self.attn_mask_dimensions = self._sanity_check_attention_and_get_attn_mask_dimension()
# Embeddings.
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
num_tokentypes=num_tokentypes,
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Transformer.
self.encoder = TransformerBlock(
config=self.config,
spec=self.transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
# TODO: Make sure you are passing in the mpu_vocab_size properly
self.lm_head = BertLMHead(config.hidden_size, config)
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=True,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
)
self.binary_head = None
if self.add_binary_head:
# TODO: Shoudl switch this to TE ?
self.binary_head = get_linear_layer(
config.hidden_size, 2, config.init_method, config.perform_initialization
)
self.pooler = Pooler(
config.hidden_size, config.init_method, config, config.sequence_parallel
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
# pylint: disable=line-too-long
def _sanity_check_attention_and_get_attn_mask_dimension(self) -> str:
"""We do some checks and return attention mask dimensions for self attention
Transformer engine library underwent a lot of change. So we need to change dimensions of
the attention mask depending on the TE version. We also santiy check some arguments.
1. If we use local version of attention dimension of the mask is [b,1,s,s]
2. If we use transformer engine > 1.10 we support all 3 backends with padding mask and [b,1,s,s]
3. If we use transformer engine >= 1.7 but less than 1.10
a ) Flash and Fused attention uses padding mask with [b,1,1,s]
b ) Unfused attention works with arbitrary mask with [b,1,s,s]
4. If we use transformer engine < 1.7
Flash and fused attention is not supported. Unfused attention will work with padding mask [b,1,s,s]
Default if you dont set any NVTE_ATTN flag will it will just use the fused path for transformer engine version >= 1.7 and unfused path for other
Args:
transformer_layer_spec (ModuleSpec): The transformer layer spec
Returns:
str: A string showing the format of the attn mask dimensions
"""
attention_backend = self.config.attention_backend
attn_mask_dimensions = None
# For local layer spec we just use b1ss
if (
self.transformer_layer_spec.submodules.self_attention.submodules.core_attention
== MCoreDotProductAttention
):
assert attention_backend in [
AttnBackend.local,
AttnBackend.auto,
], f'Expected AttnBackend to be local or auto while using mcore self attention, but found {attention_backend}. Set --attn-backend to local or dont use MCore SelfAttention submodule in layer specs'
attn_mask_dimensions = "b1ss"
else:
attn_mask_type = self.transformer_layer_spec.submodules.self_attention.params[
'attn_mask_type'
]
# For TE >= 1.10 (We always use padding mask and use b11s)
if is_te_min_version("1.10.0"):
attn_mask_dimensions = "b11s"
if attn_mask_type != AttnMaskType.padding:
warnings.warn(
f'For TE versions >= 1.10 , flash/fused/unfused support padding mask. Setting attention mask from {attn_mask_type} to padding'
)
self.transformer_layer_spec.submodules.self_attention.params[
'attn_mask_type'
] = AttnMaskType.padding
# For 1.7 >= TE < 1.10 flash and fused path use padding mask with b11s and unfused path uses arbitrary mask with b1ss
elif is_te_min_version("1.7.0"):
if attention_backend in [AttnBackend.flash, AttnBackend.fused, AttnBackend.auto]:
attn_mask_dimensions = "b11s"
else:
if attn_mask_type != AttnMaskType.arbitrary:
warnings.warn(
f'For TE versions >= 1.7 but < 1.10 , unfused path supports only arbitrary mask. Setting attention mask from {attn_mask_type} to arbitray'
)
self.transformer_layer_spec.submodules.self_attention.params[
'attn_mask_type'
] = AttnMaskType.arbitrary
attn_mask_dimensions = "b1ss"
# For TE < 1.7 we only support unfused attention with b1ss and padding mask
else:
attn_mask_dimensions = "b1ss"
assert not (attention_backend in [AttnBackend.flash, AttnBackend.fused]), (
"Flash and fused attention is not supported with transformer engine version "
"< 1.7. Set --attention-backend to unfused or leave it to be default (auto) or upgrade transformer engine >= 1.7"
)
return attn_mask_dimensions
def bert_extended_attention_mask(self, attention_mask: Tensor) -> Tensor:
"""Creates the extended attention mask
Converts the attention mask of dimension
[batch size, 1, seq len] to [batch size, 1, seq len, seq len]
or [batch size, 1, 1, seq_len] and makes it binary
Args:
attention_mask (Tensor): The input attention mask
Returns:
Tensor: The extended binary attention mask
"""
# We create a 3D attention mask from a 2D tensor mask.
if self.attn_mask_dimensions == "b1ss":
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
else:
# [b, 1, 1, s]
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
# Convert attention mask to binary:
extended_attention_mask = extended_attention_mask < 0.5
return extended_attention_mask
def bert_position_ids(self, token_ids):
"""Position ids for bert model"""
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.encoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
tokentype_ids: Tensor = None,
lm_labels: Tensor = None,
inference_params=None,
):
"""Forward function of BERT model
Forward function of the BERT Model This function passes the input tensors
through the embedding layer, and then the encoder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
extended_attention_mask = self.bert_extended_attention_mask(attention_mask)
if parallel_state.is_pipeline_first_stage():
input_ids = input_ids
position_ids = self.bert_position_ids(input_ids)
else:
position_ids = None
input_ids = None
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(
input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids
)
else:
# intermediate stage of pipeline
# encoder will get hidden_states from encoder.input_tensor
encoder_input = None
# Rotary positional embeddings (Why not move this into BERT/GPTEmberdding ?)
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.encoder, encoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run encoder.
hidden_states = self.encoder(
hidden_states=encoder_input,
attention_mask=extended_attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
if not self.post_process:
return hidden_states
if self.add_binary_head:
pooled_output = self.pooler(hidden_states, 0)
if self.return_embeddings:
embeddings = torch.transpose(hidden_states, 0, 1)
masks = torch.sum(attention_mask, dim=1)
# Collect masked embeddings.
output = torch.zeros(
size=(embeddings.shape[0], embeddings.shape[2]),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
output[i, :] = torch.mean(embedding[1 : mask - 1], dim=0)
return output
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states)
logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
binary_logits = None
if self.binary_head is not None:
binary_logits = self.binary_head(pooled_output)
if lm_labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous(), binary_logits
loss = self.compute_language_model_loss(lm_labels, logits)
return loss, binary_logits
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import get_linear_layer
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Args:
hidden_size (int): The hidden size_
init_method (callable): weight initialization method for the linear layer. bias is set to zero.
config (TransformerConfig): The transformer configuration
sequence_parallel (bool): Using squence parallel ? Defaults to False
"""
def __init__(
self,
hidden_size: int,
init_method: callable,
config: TransformerConfig,
sequence_parallel: bool = False,
):
super(Pooler, self).__init__(config)
# TODO: Shoudl switch this to TE ?
self.dense = get_linear_layer(
hidden_size, hidden_size, init_method, config.perform_initialization
)
self.sequence_parallel = sequence_parallel
def forward(self, hidden_states: Tensor, sequence_index=0):
# hidden_states: [s, b, h]
# sequence_index: index of the token to pool.
# gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes
if self.sequence_parallel:
hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
hidden_states, tensor_parallel_output_grad=False
)
pooled = hidden_states[sequence_index, :, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .rope_utils import apply_rotary_pos_emb
from .rotary_pos_embedding import RotaryEmbedding
from .yarn_rotary_pos_embedding import YarnRotaryEmbedding, _yarn_get_mscale
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment