Commit d444a97a authored by yangzhong

Initial upload

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Configuration dataclass for a RetroModel."""
import os
from dataclasses import dataclass
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.enums import AttnBackend
from megatron.core.utils import is_te_min_version
@dataclass
class RetroConfig(TransformerConfig):
"""Configuration object for Retro models."""
# Retro.
retro_project_dir: str = None
"""Retro project directory, which contains the preprocessed data for for pretraining. This
directory is built during preprocessing (see tools/retro/README.md), and contains
subdirectories for the chunk database and pretraining neighbors.
"""
retro_block_size: int = None
"""Number of records to load per data file, as saved during preprocessing. Block processing is
used for efficient data preprocessing.
"""
retro_chunk_length: int = None
"""Chunk length used for performing chunked- cross-attention (CCA)."""
retro_encoder_num_layers: int = 2
"""Number of layers to use for the retrieval encoder."""
retro_encoder_hidden_dropout: float = 0.1
"""Hidden dropout for retrieval encoder."""
retro_encoder_attention_dropout: float = 0.1
"""Attention dropout for retrieval encoder."""
retro_neighbor_dirs: dict = None
"""Directory names of saved neighbor id files for train, valid, and test datasets."""
retro_num_neighbors: int = 2
"""Number of neighbors to retrieve during pretraining."""
retro_num_retrieved_chunks: int = 2
"""Number of chunks to retrieve from the retrieval database."""
retro_retrieved_length: int = None
"""Cached value of retro_num_retrieved_chunks * retro_chunk_length (i.e., the total number of
retrieved tokens; neighbor + continuation).
"""
retro_split_preprocessing: str = None
"""Data split used during data preprocessing."""
retro_verify_neighbor_count: bool = True
"""Verify that len(GPT dataset) == len(saved neighbors)."""
def __post_init__(self) -> None:
"""Validate Retro config."""
super().__post_init__()
self.attention_backend = AttnBackend.unfused
# Validate Transformer Engine version.
if is_te_min_version("1.3"):
try:
assert os.getenv("NVTE_FLASH_ATTN") == "0"
assert os.getenv("NVTE_FUSED_ATTN") == "0"
except Exception as e:
raise Exception(
"When using Transformer Engine >= 1.3, environment vars NVTE_FLASH_ATTN "
"and NVTE_FUSED_ATTN most both be defined and set to '0'. "
"Currently, NVTE_FLASH_ATTN == %s, NVTE_FUSED_ATTN == %s."
% (
os.getenv("NVTE_FLASH_ATTN", "[unset]"),
os.getenv("NVTE_FUSED_ATTN", "[unset]"),
)
)
# Preprocessing split should be defined.
assert self.retro_split_preprocessing is not None
# Pre-compute retrieved length.
self.retro_retrieved_length = self.retro_num_retrieved_chunks * self.retro_chunk_length
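# Usage sketch (illustrative only, not part of the original module): constructing a
# RetroConfig. All values below are hypothetical; the base TransformerConfig fields
# (num_layers, hidden_size, num_attention_heads) are assumed to behave as in
# Megatron-Core, and __post_init__ additionally requires retro_split_preprocessing to
# be set (and, for Transformer Engine >= 1.3, NVTE_FLASH_ATTN=NVTE_FUSED_ATTN=0).
def _example_retro_config() -> RetroConfig:
    return RetroConfig(
        num_layers=12,
        hidden_size=768,
        num_attention_heads=12,
        retro_project_dir="/path/to/retro/project",  # hypothetical path
        retro_chunk_length=64,
        retro_num_neighbors=2,
        retro_num_retrieved_chunks=2,
        retro_split_preprocessing="98,2,0",  # hypothetical split string
    )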
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro's cross attention modules for the decoder block."""
from functools import partial
from typing import Callable
import numpy as np
import torch
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.retro.base_attention import BaseRetroCrossAttention
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.utils import get_all_true_mask
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_block import TransformerBlock
class RetroDecoderCrossAttention(BaseRetroCrossAttention):
"""Retro decoder's chunked cross attention operator.
See this paper for more details: https://arxiv.org/abs/2112.04426.
Neighboring chunks retrieved from the chunk database are used here for
chunked-cross attention.
** Note about 'encoder_block_spec' **
Retro is an encoder-decoder model that uses its encoder for encoding
neighboring chunks that are retrieved from a chunk database. These
encoded neighbors are then used in the decoder stack for performing
chunked-cross attention (see paper link above).
In contrast to the T5 model, the encoder and decoder are computationally
intertwined, since the input to the encoder is the output of the self-
attention of the first decoder layer. As such, the encoder block itself
is instantiated within the first Retro decoder layer, in order to receive
the self-attention's output. (Note, that only the first decoder layer
instantiates an encoder block, and the remaining decoder layers use the
encoder output from the first decoder layer.)
Args:
config (RetroConfig): Retro config.
submodules (CrossAttentionSubmodules): Cross attention submodules.
layer_number (int): Layer number within transformer block.
attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
encoder_block_spec (ModuleSpec): The first Retro decoder layer is provided with a transformer block spec to construct the neighbor encoder.
"""
def __init__(
self,
config: RetroConfig,
submodules: CrossAttentionSubmodules,
layer_number: int = 1,
attn_mask_type: AttnMaskType = AttnMaskType.padding,
encoder_block_spec: ModuleSpec = None,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
)
if encoder_block_spec:
self.encoder = TransformerBlock(
config=config, spec=encoder_block_spec, pre_process=True, post_process=False
)
# self._encoder_key = 'encoder' # ... necessary?
else:
self.encoder = None
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Tensor = None,
inference_params: InferenceParams = None,
# rotary_pos_emb: Tensor = None, # ... unsupported for retro.
) -> dict:
"""Cross attention for Retro decoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
m : Number of tokens per chunk.
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
Args:
hidden_states (Tensor): Transformer layer hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Tensor): Neighbor embeddings if first decoder layer, else encoder output.
inference_params (InferenceParams): Inference params.
Returns:
A dict consisting of the attention output and context, along with other scalars necessary for performing the downstream bias-dropout-add.
"""
# hidden_states: [ ns, bs, d ]
# key_value_states: [ r, k*bs*l, d ]
ns, bs, d = hidden_states.shape
l = int(np.ceil(ns / self.retro_chunk_length))
# Retrieve neighbors.
if self.encoder:
# Sequence length remainder.
first_ns = ns % self.retro_chunk_length
# Case 1: Sequence length not divisible by chunk length.
if first_ns > 0:
# Split sequence into first partial chunk & remaining chunks.
first_chunk, rest_chunk = hidden_states[:first_ns], hidden_states[first_ns:]
# Pad partial chunk with zeros.
first_chunk = torch.nn.functional.pad(
first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0
)
# Concatenate padded chunk with remaining chunks.
chunked_output = torch.cat((first_chunk, rest_chunk), dim=0) # [ l*m, bs, d ]
# Case 2: Sequence length is divisible by chunk length.
else:
chunked_output = hidden_states # [ l*m, bs, d ]
# Chunk & permute hidden states.
# - hidden_states: [ l*m, bs, d ]
# - chunked_output: [ m, bs*l, d ]
chunked_output = (
chunked_output.reshape(l, self.retro_chunk_length, bs, d)
.permute(1, 2, 0, 3)
.reshape(self.retro_chunk_length, bs * l, d)
.contiguous()
)
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
chunked_output_mask = get_all_true_mask(
size=(1, 1, chunked_output.shape[0], key_value_states.shape[0]),
device=chunked_output.device,
)
# Encode neighbors. (Note: 'key_value_states' re-assigned here.)
key_value_states = self.encoder(
hidden_states=key_value_states,
attention_mask=attention_mask,
context=chunked_output,
context_mask=chunked_output_mask,
inference_params=inference_params,
) # [ r, k*bs*l, d ]
key_value_states = key_value_states.reshape(
self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d
) # [ r*k, bs*l, d ]
# Attend starting at last token of first chunk.
pad = (ns - 1) % self.retro_chunk_length
attending_chunks = hidden_states[pad:]
# Pad attending tokens to sequence length.
padded_chunks = torch.nn.functional.pad(
attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0
)
# Permute attending chunks.
# - padded_chunks: [ l*m, bs, d ]
# - padded_chunked_output: [ m, bs*l, d ] (matches 'chunked_output' above)
padded_chunked_output = padded_chunks.reshape(l, self.retro_chunk_length, bs, d).permute(
1, 2, 0, 3
)
padded_chunked_output = padded_chunked_output.reshape(
self.retro_chunk_length, bs * l, d
).contiguous()
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
padded_chunked_output_mask = get_all_true_mask(
size=(1, 1, padded_chunked_output.shape[0], key_value_states.shape[0]),
device=padded_chunked_output.device,
)
# Attend to encoded neighbors.
attention_output, attention_bias = self.attn(
hidden_states=padded_chunked_output,
attention_mask=padded_chunked_output_mask,
key_value_states=key_value_states,
)
# Return dimensions for bias-dropout step.
return {
"ns": ns,
"bs": bs,
"d": d,
"l": l,
"pad": pad,
"attention_output": attention_output, # [ m, bs*l, d ]
"attention_bias": attention_bias, # [ d ]
"context": key_value_states, # [ r*k, bs*l, d ]
}
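# Worked shape example (illustrative, hypothetical sizes): with retro_chunk_length=64,
# retro_num_neighbors=2 and retro_num_retrieved_chunks=2 (so r = 2 * 64 = 128), an
# input of shape [ns=2048, bs=4, d] gives l = ceil(2048 / 64) = 32 chunks per sample,
# encoded neighbors reshaped to [r*k, bs*l, d] = [256, 128, d], and pad = 2047 % 64 = 63.
def _example_decoder_cca_shapes(ns: int = 2048, bs: int = 4, chunk_length: int = 64,
                                k: int = 2, r: int = 128) -> tuple:
    l = int(np.ceil(ns / chunk_length))  # 32 chunks per sample
    pad = (ns - 1) % chunk_length  # 63: attend starting at the last token of the first chunk
    kv_rows, kv_cols = r * k, bs * l  # 256 x 128 (times hidden size d)
    return l, pad, (kv_rows, kv_cols)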
class RetroDecoderBiasDropoutAdd(MegatronModule):
"""Retro decoder's bias-dropout-add operator.
This operator takes care of reshaping and permuting the output from the
chunk dimension to the sequence dimension.
Args:
config (RetroConfig): Retro config.
"""
def __init__(self, config: RetroConfig):
super().__init__(config=config)
self.retro_chunk_length = config.retro_chunk_length
@classmethod
def _forward(
cls,
x_with_bias: dict,
residual: Tensor,
prob: float,
retro_chunk_length: int,
bias_dropout_add: Callable,
) -> Tensor:
"""Per-chunk bias-dropout-add.
Args:
x_with_bias (dict): Attention output and bias, along with other Retro relevant parameters.
residual (Tensor): Transformer layer residual.
prob (float): Dropout probability.
retro_chunk_length (int): Retro chunk length (e.g., 64).
bias_dropout_add (Callable): Bias-dropout-add function.
Returns:
Output of bias-dropout-add.
"""
# Extract input dict.
ns = x_with_bias["ns"]
bs = x_with_bias["bs"]
d = x_with_bias["d"]
l = x_with_bias["l"]
pad = x_with_bias["pad"]
attention_output = x_with_bias["attention_output"] # [ m, bs*l, d ]
attention_bias = x_with_bias["attention_bias"] # [ d ]
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
# Bias-dropout-add.
x = bias_dropout_add(
(
attention_output,
None if attention_bias is None else attention_bias.expand_as(attention_output),
),
torch.zeros_like(attention_output),
prob,
)
# Permute chunks back to sequence dimension.
# 1. [ m, bs*l, d ]
# 2. [ m, bs, l, d ]
# 3. [ l, m, bs, d ]
# 4. [ m*l, bs, d ] == [ ns, bs, d ]
x = (
x.reshape(retro_chunk_length, bs, l, d)
.permute(2, 0, 1, 3)
.reshape(retro_chunk_length * l, bs, d)
)
# Prepend zeros for non-attending tokens.
x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0)[
:ns
] # [ ns, bs, d ]
# Add residual. [ ns, bs, d ]
x = x + residual
# Output. [ ns, bs, d ]
return x
def forward(self, training: bool, fused: bool) -> partial:
"""Retro decoder bias-dropout-add.
Args:
training (bool): If training, then apply dropout.
fused (bool): Fuse bias-dropout-add.
Returns:
The partial function for performing bias-dropout-add.
"""
return partial(
self._forward,
retro_chunk_length=self.retro_chunk_length,
bias_dropout_add=get_bias_dropout_add(training, fused),
)
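# Usage sketch (hypothetical wiring, not part of the original module): the dict returned
# by RetroDecoderCrossAttention.forward is consumed, together with the layer residual,
# by the partial returned from RetroDecoderBiasDropoutAdd.forward. The dropout prob is a
# placeholder value.
def _example_decoder_bda(bda: RetroDecoderBiasDropoutAdd, x_with_bias: dict, residual: Tensor) -> Tensor:
    fn = bda(training=True, fused=False)  # partial of _forward with chunk length bound
    return fn(x_with_bias, residual, prob=0.1)  # [ ns, bs, d ]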
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Specs for Retro decoder."""
import typing
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.decoder_attention import (
RetroDecoderBiasDropoutAdd,
RetroDecoderCrossAttention,
)
from megatron.core.models.retro.encoder_spec import get_retro_encoder_block_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
def get_retro_decoder_layer_te_spec(
encoder_block_spec: typing.Union[ModuleSpec, TransformerBlockSubmodules, None] = None
) -> ModuleSpec:
"""Retro decoder TE spec (uses Transformer Engine components).
A Retro decoder layer uses custom attention and bias-dropout-add operators
to perform chunked-cross attention. Additionally, the first Retro decoder
layer instantiates an entire encoder transformer block. As such, the decoder
cross attention module takes an optional encoder block spec, which is only
provided for the first Retro decoder layer.
Args:
encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided for
the first Retro decoder layer.
Returns:
A module spec with Transformer Engine modules.
"""
spec = get_gpt_layer_with_transformer_engine_spec()
spec.submodules.pre_cross_attn_layernorm = TENorm
spec.submodules.cross_attention = ModuleSpec(
module=RetroDecoderCrossAttention,
params={"encoder_block_spec": encoder_block_spec},
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd)
return spec
def get_retro_decoder_layer_local_spec(
encoder_block_spec: typing.Optional[ModuleSpec] = None,
) -> ModuleSpec:
"""Retro decoder local spec (uses Megatron-Core components).
A Retro decoder layer uses custom attention and bias-dropout-add operators
to perform chunked-cross attention. Additionally, the first Retro decoder
layer instantiates an entire encoder transformer block. As such, the decoder
cross attention module takes an optional encoder block spec, which is only
provided for the first Retro decoder layer.
Args:
encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided
for the first Retro decoder layer.
Returns:
A module spec with local modules.
"""
spec = get_gpt_layer_local_spec()
spec.submodules.pre_cross_attn_layernorm = LNImpl
spec.submodules.cross_attention = ModuleSpec(
module=RetroDecoderCrossAttention,
params={"encoder_block_spec": encoder_block_spec},
submodules=CrossAttentionSubmodules(
linear_q=ColumnParallelLinear,
linear_kv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd)
return spec
def get_retro_decoder_block_spec(
config: RetroConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""Retro decoder block spec.
Retro decoder block implementation details:
- The retro decoder block consists of interleaved GPT layers
and customized Retro decoder layers.
- The Retro decoder layers are spaced three layers apart,
and start on layer 6 or 9 (depending on the total number of layers).
- The first decoder layer instantiates an encoder block,
and it therefore passes in an encoder_block_spec.
Args:
config (RetroConfig): Retro config.
use_transformer_engine (bool): If True, use Transformer Engine (instead of local modules).
Returns:
Transformer block submodules for the given spec.
"""
# Num layers.
assert (
parallel_state.get_pipeline_model_parallel_world_size() == 1
), "retro does not currently support pipeline parallelism."
assert (
parallel_state.get_virtual_pipeline_model_parallel_world_size() is None
), "retro does not currently support virtual pipeline parallelism."
num_layers = get_num_layers_to_build(config)
# Retro layer numbers.
retro_layer_start = 6 if num_layers <= 15 else 9
retro_layer_numbers = list(range(retro_layer_start, num_layers + 1, 3))
# Layer specs.
gpt_layer_spec = (
get_gpt_layer_with_transformer_engine_spec()
if use_transformer_engine
else get_gpt_layer_local_spec()
)
get_retro_decoder_layer_spec = (
get_retro_decoder_layer_te_spec
if use_transformer_engine
else get_retro_decoder_layer_local_spec
)
retro_layer_spec = get_retro_decoder_layer_spec()
retro_layer_spec_with_retriever = get_retro_decoder_layer_spec(
get_retro_encoder_block_spec(config, use_transformer_engine)
)
layer_specs = []
for layer_number in range(1, num_layers + 1):
if layer_number == retro_layer_numbers[0]:
layer_specs.append(retro_layer_spec_with_retriever)
elif layer_number in retro_layer_numbers:
layer_specs.append(retro_layer_spec)
else:
layer_specs.append(gpt_layer_spec)
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs)
return block_spec
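# Usage sketch (illustrative only): building the interleaved decoder block spec from a
# RetroConfig. Requires megatron parallel state to be initialized (the asserts above query
# the pipeline parallel sizes). With e.g. config.num_layers=12, the Retro cross attention
# layers land on layers 6, 9 and 12, and layer 6 also carries the neighbor encoder block spec.
def _example_decoder_block_spec(config: RetroConfig) -> TransformerBlockSubmodules:
    return get_retro_decoder_block_spec(config, use_transformer_engine=True)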
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro's cross attention modules for the encoder block."""
from functools import partial
from typing import Callable, List, Optional, Tuple, Type
import torch
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.retro.base_attention import BaseRetroCrossAttention
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.utils import get_all_true_mask
from megatron.core.transformer.module import MegatronModule
class RetroEncoderCrossAttention(BaseRetroCrossAttention):
"""Retro encoder's cross attention operator.
See this paper for more details: https://arxiv.org/abs/2112.04426.
Neighboring chunks are retrieved from the chunk database, encoded, and
used by the decoder layers for chunked cross attention.
Args:
config (RetroConfig): Retro config.
submodules (CrossAttentionSubmodules): Cross attention submodules.
layer_number (int): Layer number within transformer block.
attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
"""
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Tensor = None,
inference_params: InferenceParams = None,
# rotary_pos_emb: Tensor = None, # unsupported for retro.
) -> List[Tuple[Tensor, Optional[Tensor], Tensor]]:
"""Cross attention for Retro encoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
Args:
hidden_states (Tensor): Transformer layer hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Tensor): Neighbor embeddings.
inference_params (InferenceParams): Inference params.
Returns:
List of tuples, where each tuple is (attention_output, attention_bias, residual).
"""
# Input shape. [ r, bs*l*k, d ]
ns, bs, d = hidden_states.shape
# Reshape sequence into neighboring chunks.
# - hidden_states: [ r, bs*l*k, d ]
# - chunked_outputs: [ r, bs*l, k, d ]
chunked_outputs = hidden_states.reshape(
self.retro_retrieved_length, -1, self.retro_num_neighbors, d
)
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
chunked_output_mask = get_all_true_mask(
size=(1, 1, chunked_outputs.shape[0], key_value_states.shape[0]),
device=chunked_outputs.device,
)
# Per-chunk attention.
attention_output_tuples = []
for k in range(self.retro_num_neighbors):
# Attend to current neighboring chunks.
# - chunked_output: [ r, bs*l, d ]
# - key_value_states: [ m, bs*l, d ]
# - attention_output: [ r, bs*l, d ]
# - attention_bias: [ d ]
chunked_output = chunked_outputs[:, :, k].contiguous()
attention_output, attention_bias = self.attn(
hidden_states=chunked_output, # Q (neighbor embedding)
attention_mask=chunked_output_mask,
key_value_states=key_value_states, # K, V (hidden act)
)
# Residual connection. [ r, bs*l, d ]
residual = chunked_output
# Collect tensors.
attention_output_tuples.append((attention_output, attention_bias, residual))
# Output. (List of tuples: ( [ r, bs*l, d ], [ d ], [ r, bs*l, d ] ).)
return attention_output_tuples
class RetroEncoderBiasDropoutAdd(MegatronModule):
"""Retro encoder's bias-dropout-add operator.
This operator applies bias-dropout-add individually on each neighboring
chunk that is retrieved from the chunk database.
Args:
config (RetroConfig): Retro config.
"""
def __init__(self, config: RetroConfig):
super().__init__(config=config)
self.retro_num_neighbors = config.retro_num_neighbors
@classmethod
def _forward(
cls,
x_with_bias: List[Tuple[Tensor, Optional[Tensor], Tensor]],
residual: Tensor,
prob: float,
retro_num_neighbors: int,
bias_dropout_add: Callable,
) -> Tensor:
"""Per-chunk bias-dropout-add.
Args:
x_with_bias (List[Tuple]): Per-neighbor (attention output, attention bias, residual) tuples.
residual (Tensor): Transformer layer residual.
prob (float): Dropout probability.
retro_num_neighbors (int): Number of retrieved neighbor chunks (e.g., 2).
bias_dropout_add (Callable): Bias-dropout-add function.
Returns:
Output of bias-dropout-add.
"""
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
# Per-neighbor bias-dropout-add.
# - attention_output: [ r, bs*l, d ]
# - attention_bias: [ d ]
# - residual: [ r, bs*l, d ]
# - output: [ r, bs*l, d ]
outputs = [
bias_dropout_add(
(
attention_output,
None if attention_bias is None else attention_bias.expand_as(residual),
),
residual,
prob,
)
for attention_output, attention_bias, residual in x_with_bias
]
# Concatenate outputs (to shape [r, k*bs*l, d]; see notation above).
r, _, d = outputs[0].shape
output = torch.stack(outputs, dim=1).reshape(r, -1, d)
# Output. [ r, k*bs*l, d ]
return output
def forward(self, training: bool, fused: bool) -> partial:
"""Retro decoder bias-dropout-add.
Args:
training (bool): If training, then apply dropout.
fused (bool): Fuse bias-dropout-add.
Returns:
A partial function for performing bias-dropout-add.
"""
return partial(
self._forward,
retro_num_neighbors=self.retro_num_neighbors,
bias_dropout_add=get_bias_dropout_add(training, fused),
)
class RetroEncoderLayerNorm(MegatronModule):
"""Retro encoder's layernorm operator.
This operator applies layernorm individually on each neighboring chunk that
is retrieved from the chunk database, and then concatenates the chunks into
a single tensor.
Args:
config (RetroConfig): Retro config.
submodules (Type): Layer norm class. (Named 'submodules' to fit external interface.)
"""
def __init__(self, config: RetroConfig, submodules: Type, **kwargs: dict):
super().__init__(config=config)
norm_class = submodules
self.norm = norm_class(config=config, **kwargs)
self.retro_num_neighbors = config.retro_num_neighbors
def forward(self, input: Tensor) -> Tensor:
"""Per-chunk layer norm.
Args:
input (Tensor): Input chunks, concatenated into a single tensor.
Returns:
Output of the layer norm.
"""
# Input shape: [ r, k*bs*l, d ]. (see notation above in attention module)
# Split input into 'num_neighbors' tensors.
chunk_size = input.shape[1] // self.retro_num_neighbors
inputs = torch.split(input, chunk_size, dim=1)
# Norm.
outputs = [self.norm(inp.contiguous()) for inp in inputs]
# Concatenate layer norms (to shape [r, k*bs*l, d]; see notation above).
r, _, d = inputs[0].shape
output = torch.stack(outputs, dim=1).reshape(r, -1, d)
# Output. [ r, k*bs*l, d ]
return output
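# Worked split example (illustrative only, hypothetical sizes): the per-neighbor norm
# splits dim 1 into retro_num_neighbors chunks, norms each independently, then stacks
# them back so the output keeps the [ r, k*bs*l, d ] layout of the input.
def _example_encoder_norm_split(r: int = 128, k: int = 2, bs_l: int = 128, d: int = 8) -> torch.Size:
    x = torch.randn(r, k * bs_l, d)
    chunks = torch.split(x, x.shape[1] // k, dim=1)  # k tensors of [ r, bs_l, d ]
    y = torch.stack(chunks, dim=1).reshape(r, -1, d)  # back to [ r, k*bs_l, d ]
    return y.shape  # torch.Size([128, 256, 8])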
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Specs for Retro encoder."""
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.encoder_attention import (
RetroEncoderBiasDropoutAdd,
RetroEncoderCrossAttention,
RetroEncoderLayerNorm,
)
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_retro_encoder_layer_te_spec() -> ModuleSpec:
"""Retro encoder TE spec (uses Transformer Engine components).
A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm
operators to encode neighboring chunks that are retrieved from the chunk
database. Each operator is responsible for iterating the retrieved chunks
and processing them individually.
Returns:
A module spec with Transformer Engine modules.
"""
spec = get_gpt_layer_with_transformer_engine_spec()
spec.submodules.pre_cross_attn_layernorm = TENorm
spec.submodules.cross_attention = ModuleSpec(
module=RetroEncoderCrossAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=TENorm)
spec.submodules.mlp = ModuleSpec(
module=MLP,
submodules=MLPSubmodules(linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear),
)
return spec
def get_retro_encoder_layer_local_spec() -> ModuleSpec:
"""Retro encoder local spec (uses Megatron-Core components).
A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm
operators to encode neighboring chunks that are retrieved from the chunk
database. Each operator is responsible for iterating the retrieved chunks
and processing them individually.
Returns:
A module spec with local modules.
"""
spec = get_gpt_layer_local_spec()
spec.submodules.pre_cross_attn_layernorm = LNImpl
spec.submodules.cross_attention = ModuleSpec(
module=RetroEncoderCrossAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=CrossAttentionSubmodules(
linear_q=ColumnParallelLinear,
linear_kv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=LNImpl)
spec.submodules.mlp = ModuleSpec(
module=MLP,
submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear),
)
spec.submodules.sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
} # pre_mlp_layernorm doesn't need remapping
return spec
def get_retro_encoder_block_spec(
config: RetroConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""Retro encoder block spec.
The retro encoder block consists of one customized Retro encoder layer
(layer 1), and all of the following layers are standard GPT layers.
Args:
config (RetroConfig): Retro config.
use_transformer_engine (bool): If True, use Transformer Engine (instead of local modules).
Returns:
Transformer block submodules for the given spec.
"""
# Num layers.
num_layers = config.retro_encoder_num_layers
retro_layer_numbers = [1]
# Layer specs.
gpt_layer_spec = (
get_gpt_layer_with_transformer_engine_spec()
if use_transformer_engine
else get_gpt_layer_local_spec()
)
get_retro_encoder_layer_spec = (
get_retro_encoder_layer_te_spec
if use_transformer_engine
else get_retro_encoder_layer_local_spec
)
retro_layer_spec = get_retro_encoder_layer_spec()
for spec in (gpt_layer_spec, retro_layer_spec):
spec.params["hidden_dropout"] = config.retro_encoder_hidden_dropout
spec.submodules.self_attention.params["attn_mask_type"] = AttnMaskType.padding
spec.submodules.self_attention.submodules.core_attention = ModuleSpec(
module=TEDotProductAttention if use_transformer_engine else DotProductAttention,
params={"attention_dropout": config.retro_encoder_attention_dropout},
)
layer_specs = []
for layer_number in range(1, num_layers + 1):
if layer_number in retro_layer_numbers:
layer_specs.append(retro_layer_spec)
else:
layer_specs.append(gpt_layer_spec)
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs)
return block_spec
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro Model."""
from typing import Dict, Optional
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.gpt import GPTModel
class RetroModel(GPTModel):
"""Retro Model.
A Retro model mostly re-uses the GPTModel interface, with the only difference
being the embedding of the 'context' that is used by Retro for processing
neighbor tokens. This embedded context is then forwarded to the Transformer
Block.
"""
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
context_input_ids: Tensor = None,
context_position_ids: Tensor = None,
context_mask: Tensor = None,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
) -> Tensor:
"""RetroModel forward method.
Forward input tokens & mask, along with neighbor tokens & mask, through
the Retro model.
Args:
input_ids (Tensor): Input token IDs.
position_ids (Tensor): Input position IDs.
attention_mask (Tensor): Input attention mask.
context_input_ids (Tensor): Context (i.e., neighbor) token IDs.
context_position_ids (Tensor): Context (i.e., neighbor) position IDs.
context_mask (Tensor): Context (i.e., neighbor) attention mask.
decoder_input (Tensor): When using pipeline parallelism, input_ids and position_ids will only be used on the first stage, and for all other stages decoder_input will be provided via communication from the previous stage.
labels (Tensor): The labels of dimension [batch size, seq length].
inference_params (InferenceParams): Parameters for inference.
Returns:
Output tensor of forward pass.
"""
# Argument shapes:
# Notation:
# ns : Sequence length.
# bs : Batch size.
# d : Hidden size.
# l : Number of chunks per sample (i.e., seq_length/chunk_length).
# k : Number of neighbors.
# r : Number of retrieved tokens (neighbors + continuation).
# - input_ids: [ bs, ns ]
# - context_ids: [ k*bs*l, r ]
# - context: [ r, k*bs*l, d ]
# - output: [ ns, bs, d ]
# Context embedding (e.g., for Retro neighbor tokens).
if context_input_ids is not None:
context = self.embedding(context_input_ids, context_position_ids)
else:
context = None
# Call GPTModel.forward, and pass in embedded context.
return super().forward(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
decoder_input=decoder_input,
labels=labels,
inference_params=inference_params,
extra_block_kwargs={"context": context, "context_mask": context_mask},
)
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
"""Get sharded state dict.
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): Offsets of local shard within global tensor.
metadata (Optional[Dict]): Shard metadata.
Returns:
The sharded state dict (<ShardedStateDict>).
"""
metadata = metadata or {}
metadata['non_homogeneous_layers'] = True
return super().sharded_state_dict(prefix, sharded_offsets, metadata)
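# Usage sketch (hypothetical; model construction follows GPTModel and is omitted here):
# a forward call feeds both the sample tokens and the flattened neighbor tokens. Shapes
# follow the notation above: input_ids is [ bs, ns ] and context_input_ids is [ k*bs*l, r ].
def _example_retro_forward(model: "RetroModel", tokens: Tensor, position_ids: Tensor,
                           attention_mask: Tensor, neighbor_tokens: Tensor,
                           neighbor_position_ids: Tensor, neighbor_mask: Tensor) -> Tensor:
    return model(
        input_ids=tokens,  # [ bs, ns ]
        position_ids=position_ids,  # [ bs, ns ]
        attention_mask=attention_mask,
        context_input_ids=neighbor_tokens,  # [ k*bs*l, r ]
        context_position_ids=neighbor_position_ids,
        context_mask=neighbor_mask,
    )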
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
import torch
def get_config_path(project_dir: str) -> str:
"""Config copy stored within retro project dir."""
return os.path.join(project_dir, "config.json")
def get_gpt_data_dir(project_dir: str) -> str:
"""Get project-relative directory of GPT bin/idx datasets."""
return os.path.join(project_dir, "data")
# ** Note ** : Retro's compatibility between cross attention and Flash/Fused
# Attention is currently a work in progress. We default to returning None for
# now.
# def get_all_true_mask(size, device):
# return torch.full(size=size, fill_value=True, dtype=torch.bool, device=device)
def get_all_true_mask(size, device):
return None
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional, Union
import torch
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
try:
import transformer_engine # pylint: disable=unused-import
from megatron.core.extensions.transformer_engine import TENorm
NORM_IMPL = TENorm
except ImportError:
NORM_IMPL = torch.nn.LayerNorm
# Note: This is under development and is missing features like position embedding interpolation.
class CLIPViTModel(VisionModule):
"""CLIP ViT vision model.
Args:
transformer_config (TransformerConfig): Transformer config.
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.
ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre.
add_class_token (bool, optional): Include a class token. Defaults to True.
class_token_len (int): Class token length. Defaults to 1 but 8 may be faster.
patch_dim (int): Image patch size.
img_h (int): Input image height.
img_w (int): Input image width.
"""
def __init__(
self,
transformer_config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
ln_pre_impl: Union[ModuleSpec, type] = NORM_IMPL,
ln_post_impl: Union[ModuleSpec, type] = NORM_IMPL,
add_class_token: bool = True,
class_token_len: int = 1,
patch_dim: int = 14,
img_h: int = 336,
img_w: int = 336,
model_subtype: str = "clip",
) -> None:
error_msg = f"CLIPViTModel model subtype {model_subtype} is not supported."
assert model_subtype in ["clip", "siglip", "internvit"], error_msg
if model_subtype == "siglip":
assert class_token_len == 0, "SigLIP does not support class tokens."
assert not add_class_token, "SigLIP does not support class tokens."
super().__init__(config=transformer_config)
if has_config_logger_enabled(transformer_config):
log_config_to_disk(transformer_config, locals(), prefix=type(self).__name__)
self.class_token_len = class_token_len
self.visual_hidden_size = transformer_config.hidden_size
self.patch_dim = patch_dim
self.img_h = img_h
self.img_w = img_w
assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
self.num_patches_per_dim_h = self.img_h // self.patch_dim
self.num_patches_per_dim_w = self.img_w // self.patch_dim
self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
self.add_class_token = add_class_token
self.class_token_len = class_token_len
self.seq_length = self.num_patches + (self.class_token_len if self.add_class_token else 0)
self.ln_pre = None
self.ln_post = None
if model_subtype == "clip":
self.ln_pre = build_module(
ln_pre_impl,
config=transformer_config,
hidden_size=self.visual_hidden_size,
eps=transformer_config.layernorm_epsilon,
)
conv_bias = False
padding = 0
elif model_subtype == "siglip":
self.ln_post = build_module(
ln_post_impl,
config=transformer_config,
hidden_size=self.visual_hidden_size,
eps=transformer_config.layernorm_epsilon,
)
conv_bias = True
padding = "valid"
elif model_subtype == "internvit":
conv_bias = True
padding = 0
else:
raise ValueError(f"unsupported vision model type {model_subtype}")
self.conv1 = torch.nn.Conv2d(
in_channels=3,
out_channels=self.visual_hidden_size,
kernel_size=self.patch_dim,
stride=self.patch_dim,
bias=conv_bias,
padding=padding,
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
self.position_embeddings = torch.nn.Embedding(self.seq_length, self.visual_hidden_size)
self.add_class_token = add_class_token
if self.add_class_token:
self.class_token = torch.nn.Parameter(
torch.randn(1, self.class_token_len, self.visual_hidden_size)
)
self.model_type = ModelType.encoder_or_decoder
# Transformer layers.
# TODO: Make pre_process and post_process configurable.
# NOTE: a final layer norm and/or linear layer in some implementations are omitted here.
# They can be added separately where needed.
self.decoder = TransformerBlock(
config=transformer_config,
spec=transformer_layer_spec,
pre_process=True,
post_process=False,
)
def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
"""Sets input tensor to the model.
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
self.decoder.set_input_tensor(input_tensor)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward function of the CLIP ViT Model. This function passes the input tensors
through the embedding layer and then the transformer.
Args:
x (torch.Tensor): input image data of shape [batch, 3, img_h, img_w]
attention_mask (torch.Tensor with dtype=bool): Attention mask to use.
Returns:
x (torch.Tensor): output after final transformer block of shape [b, s, h].
"""
x = self.conv1(x) # shape = [batch, hidden_size, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # [batch, hidden_size, grid ** 2]
x = x.permute(0, 2, 1) # [batch, grid ** 2, hidden_size]
if self.add_class_token:
class_token = self.class_token.expand(
x.shape[0], -1, -1
) # [batch, class_token_len, hidden_size]
x = torch.cat(
[class_token, x], dim=1
) # [batch, grid ** 2 + class_token_len, hidden_size]
assert x.shape[1] == self.seq_length, f"{x.shape[1]} != {self.seq_length}"
x = x + self.position_embeddings(self.position_ids)
if self.ln_pre:
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # [b, s, h] -> [s, b, h]
# `permute` can make the tensor non-contiguous, breaking pipelining.
x = x.contiguous()
x = self.decoder(x, attention_mask)
x = x.permute(1, 0, 2) # [s, b, h] -> [b, s, h]
x = x.contiguous()
if self.ln_post:
x = self.ln_post(x)
return x
def get_num_image_embeddings(
img_h,
img_w,
patch_dim,
vision_model_type,
disable_vision_class_token,
class_token_len,
pixel_shuffle=False,
use_tile_tags=False,
):
"""Get the number of image embeddings per image tile."""
if vision_model_type == "siglip":
keep_class_token = False
elif vision_model_type in ("clip", "internvit"):
keep_class_token = not disable_vision_class_token
else:
raise ValueError(f"unsupported vision model: {vision_model_type}")
num_patches_per_dim_h = img_h // patch_dim
num_patches_per_dim_w = img_w // patch_dim
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
num_image_embeddings_per_tile = num_patches + (class_token_len if keep_class_token else 0)
if pixel_shuffle:
num_image_embeddings_per_tile = int(num_image_embeddings_per_tile * (0.5**2))
if use_tile_tags:
# The length of tile tags tokenized. Currently, the same across tokenizers used.
num_image_embeddings_per_tile += 5
return num_image_embeddings_per_tile
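# Worked example (illustrative): a 336x336 image with patch_dim=14 has (336 // 14) ** 2
# = 576 patches. A CLIP model that keeps its length-1 class token therefore yields 577
# embeddings per tile, while a SigLIP model with pixel shuffle yields int(576 * 0.25) = 144.
def _example_num_image_embeddings() -> tuple:
    n_clip = get_num_image_embeddings(
        img_h=336, img_w=336, patch_dim=14, vision_model_type="clip",
        disable_vision_class_token=False, class_token_len=1,
    )  # 577
    n_siglip_shuffled = get_num_image_embeddings(
        img_h=336, img_w=336, patch_dim=14, vision_model_type="siglip",
        disable_vision_class_token=True, class_token_len=0, pixel_shuffle=True,
    )  # 144
    return n_clip, n_siglip_shuffled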
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
class MultimodalProjector(MegatronModule):
"""
MultimodalProjector takes the encoded input with `input_size` hidden state and projects
it into the hidden size of the language model for multimodal training. When the projector
type is 'affine', `linear_fc1` from the submodules is used.
Args:
transformer_config (TransformerConfig): Transformer config
submodules (MLPSubmodules): Specifies MLP submodules for mlp type projector
projector_type (str): Projector type
input_size (int): Input size from feature encoder
"""
def __init__(
self,
config: TransformerConfig,
submodules: MLPSubmodules,
projector_type: str,
input_size: int,
):
super().__init__(config=config)
self.projector_type = projector_type
assert submodules is not None, "MLPSubmodules must be provided"
if self.projector_type == "mlp":
self.encoder = MLP(config=config, submodules=submodules, input_size=input_size)
elif self.projector_type == "affine":
self.encoder = build_module(
submodules.linear_fc1,
input_size,
config.hidden_size,
config=config,
init_method=config.init_method,
gather_output=True,
bias=config.add_bias_linear,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name=None,
)
else:
raise Exception(f"Unsupported multimodal projection type {self.projector_type}")
def forward(self, hidden_states):
"""Run multimodal projector.
Args:
hidden_states (torch.Tensor): Input.
Returns:
torch.Tensor: The projected output.
"""
# Run encoder.
encoder_output, encoder_output_bias = self.encoder(hidden_states)
if encoder_output_bias is not None:
encoder_output = encoder_output + encoder_output_bias
# the encoder produces "viewed" tensor. This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
encoder_output = make_viewless_tensor(
inp=encoder_output, requires_grad=True, keep_graph=True
)
return encoder_output
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_vit_layer_with_transformer_engine_spec() -> ModuleSpec:
'''
Returns ViT layer spec with Transformer Engine layers
'''
mlp = _get_mlp_module_spec(use_te=True)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.no_mask},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def get_vit_layer_with_local_spec() -> ModuleSpec:
'''
Returns ViT layer spec with Mcore local layers
'''
mlp = _get_mlp_module_spec(use_te=False)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
# Helper function to get module spec for MLP/MoE
def _get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
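# Usage sketch (hypothetical values, not part of the original module): pairing the TE ViT
# layer spec with the CLIPViTModel defined earlier in this upload. The config values are
# placeholders, and constructing the model assumes megatron parallel state is initialized.
def _example_clip_vit_model():
    from megatron.core.models.vision.clip_vit_model import CLIPViTModel  # assumed module path
    from megatron.core.transformer.transformer_config import TransformerConfig
    config = TransformerConfig(num_layers=24, hidden_size=1024, num_attention_heads=16)
    spec = get_vit_layer_with_transformer_engine_spec()
    return CLIPViTModel(config, spec, patch_dim=14, img_h=336, img_w=336)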
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron Core number of microbatches calculators."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union
logger = logging.getLogger(__name__)
# TODO: global_var merge into mcore?
_GLOBAL_NUM_MICROBATCHES_CALCULATOR: Union[
'ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator'
] = None
def get_num_microbatches() -> int:
"""Get number of microbatches."""
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def get_current_global_batch_size() -> int:
"""Get current global batch size."""
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
def get_micro_batch_size() -> int:
"""Get micro batch size."""
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_micro_batch_size()
def get_current_running_global_batch_size() -> int:
"""Get current running global batch size, taking into account number of DP replicas might be
incompatible with true global batch size if `decrease_batch_size_if_needed` is True."""
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_running_global_batch_size()
def update_num_microbatches(
consumed_samples: int, consistency_check: bool = True, verbose: bool = False
) -> None:
"""Update number of microbatches.
Args:
consumed_samples (int):
Number of samples consumed.
consistency_check (bool, optional):
Option to check current schedule's consistency. Defaults to True.
verbose (bool, optional):
Option to control logging. Defaults to False.
"""
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check, verbose)
def unset_num_microbatches_calculator():
"""Unset microbatches calculator.
Useful for multiple runs. See `tests/unit_tests/ckpt_converter/test_ckpt_converter.py`
for an example.
"""
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
def init_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool = False,
) -> None:
"""Initialize number of microbatches calculator. Supporting backward compatibility.
Args:
rank (int):
Rank of the GPU, only rank 0 will log the information.
rampup_batch_size (Optional[List[int]]):
Rampup batch size, should be in format of [start_global_batch_size,
batch_size_increment, ramup_samples].
global_batch_size (int):
Global batch size for the model.
micro_batch_size (int):
Micro batch size at initialization.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool, optional):
If true, scale down batch size to ensure divisibility by DP size * microbatch size.
Defaults to False.
"""
_configure_global_num_microbatches_calculator(
rank,
rampup_batch_size,
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
init=True,
)
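# Usage sketch (hypothetical values): initialize the global calculator once, then query and
# update it during training. With global_batch_size=256, micro_batch_size=4 and
# data_parallel_size=8, get_num_microbatches() returns 256 // (4 * 8) = 8.
def _example_microbatches_setup() -> int:
    init_num_microbatches_calculator(
        rank=0,
        rampup_batch_size=None,  # constant global batch size
        global_batch_size=256,
        micro_batch_size=4,
        data_parallel_size=8,
    )
    update_num_microbatches(consumed_samples=0, consistency_check=True)
    return get_num_microbatches()  # 8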
def destroy_num_microbatches_calculator():
"""Destroy number of microbatches calculator."""
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
def reconfigure_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool = False,
) -> None:
"""Reconfigure number of microbatches calculator. Supporting backward compatibility.
Args:
rank (int):
Rank of the GPU, only rank 0 will log the information.
rampup_batch_size (Optional[List[int]]):
Rampup batch size, should be in format of
[start_global_batch_size, batch_size_increment, ramup_samples].
global_batch_size (int):
Global batch size for the model.
micro_batch_size (int):
Micro batch size at initialization.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool, optional):
If true, scale down batch size to ensure divisibility by DP size * microbatch size.
Defaults to False.
"""
_configure_global_num_microbatches_calculator(
rank,
rampup_batch_size,
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
init=False,
)
def _configure_global_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool = False,
init: bool = False,
) -> None:
"""Configure number of microbatches calculator. Can be used for initialization and
reconfiguration.
Args:
rank (int):
Rank of the GPU, only rank 0 will log the information.
rampup_batch_size (Optional[List[int]]):
Rampup batch size, should be in format of
[start_global_batch_size, batch_size_increment, ramup_samples].
global_batch_size (int):
Global batch size for the model.
micro_batch_size (int):
Micro batch size at initialization.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool, optional):
If true, scale down batch size to ensure divisibility by DP size * microbatch size.
Defaults to False.
init (bool, optional):
If true, initialize the calculator. Defaults to False.
"""
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
if init:
assert (
_GLOBAL_NUM_MICROBATCHES_CALCULATOR is None
), 'num microbatches calculator is already initialized.'
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = _build_num_microbatches_calculator(
rank,
rampup_batch_size,
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
)
def _build_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool,
) -> Union['ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator']:
"""Build number of microbatches calculator. Internal helper method.
Args:
rank (int):
Rank of the GPU, only rank 0 will log the information.
rampup_batch_size (Optional[List[int]]):
Rampup batch size, should be in format of
[start_global_batch_size, batch_size_increment, ramup_samples].
global_batch_size (int):
Global batch size for the model.
micro_batch_size (int):
Micro batch size at initialization.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool):
If true, scale down batch size to ensure divisibility by DP size * microbatch size.
"""
# Constant batch size.
if rampup_batch_size is None:
num_microbatches_calculator = ConstantNumMicroBatchesCalculator(
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
rank,
)
if rank == 0:
logger.info(
f'setting number of microbatches to constant {num_microbatches_calculator.get()}'
)
# Batch size ramp up.
else:
assert len(rampup_batch_size) == 3, (
'expected the following '
'format: --rampup-batch-size <start batch size> '
'<batch size increment> <ramp-up samples>'
)
start_global_batch_size = int(rampup_batch_size[0])
batch_size_increment = int(rampup_batch_size[1])
ramup_samples = int(rampup_batch_size[2])
if rank == 0:
logger.info(
f'will use batch size rampup starting from global batch size '
f'{start_global_batch_size} to global batch size {global_batch_size} with batch '
f'size increments {batch_size_increment} over {ramup_samples} samples.'
)
num_microbatches_calculator = RampupBatchsizeNumMicroBatchesCalculator(
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
rank,
start_global_batch_size,
batch_size_increment,
ramup_samples,
)
return num_microbatches_calculator
def _round(batch_size: int, divisor: int) -> int:
"""Round `batch_size` down to nearest batch size divisible by `divisor`."""
return (batch_size // divisor) * divisor
class NumMicroBatchesCalculator(ABC):
"""Base class for number of microbatches calculator."""
def __init__(self) -> None:
self.num_micro_batches = None
self.current_global_batch_size = None
self.micro_batch_size = None
self.current_running_global_batch_size = None
def get(self) -> int:
"""Get number of microbatches."""
return self.num_micro_batches
def get_current_global_batch_size(self) -> int:
"""Get current global batch size."""
return self.current_global_batch_size
def get_micro_batch_size(self) -> int:
"""Get current global batch size."""
return self.micro_batch_size
def get_current_running_global_batch_size(self) -> int:
"""Get current running global batch size. If decrease_batch_size_if_needed is False,
this just equals global batch size."""
return self.current_running_global_batch_size
@abstractmethod
def update(self, consumed_samples, consistency_check, verbose=False) -> None:
"""Update number of microbatches depending on batch size rampup."""
pass
class ConstantNumMicroBatchesCalculator(NumMicroBatchesCalculator):
"""Calculator of number of microbatches with constant global batch size.
Args:
global_batch_size (int):
Global batch size.
micro_batch_size (int):
Micro batch size.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool):
If true, decrease batch size to ensure divisibility by DP size * microbatch size
(if needed).
rank (int):
Rank (to determine whether logging should be performed).
"""
def __init__(
self,
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool,
rank: int,
) -> None:
micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size
if decrease_batch_size_if_needed:
running_global_batch_size = _round(
global_batch_size, micro_batch_times_data_parallel_size
)
assert running_global_batch_size % micro_batch_times_data_parallel_size == 0
if rank == 0:
logger.info(
f'decreasing batch size from {global_batch_size} to {running_global_batch_size} '
f'to keep divisibility by micro_batch_size={micro_batch_size} * '
f'data_parallel_size={data_parallel_size}'
)
self.num_micro_batches = (
running_global_batch_size // micro_batch_times_data_parallel_size
)
else:
assert global_batch_size % micro_batch_times_data_parallel_size == 0, (
'global batch size ({}) is not divisible by micro batch size ({})'
' times data parallel size ({})'.format(
global_batch_size, micro_batch_size, data_parallel_size
)
)
running_global_batch_size = global_batch_size
self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel_size
assert (
self.num_micro_batches >= 1
), 'number of microbatches should be at least 1, got {}.'.format(self.num_micro_batches)
self.current_global_batch_size = global_batch_size
self.current_running_global_batch_size = running_global_batch_size
self.micro_batch_size = micro_batch_size
def update(self, consumed_samples, consistency_check, verbose=False) -> None:
pass
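# Worked example (illustrative): with decrease_batch_size_if_needed=True, a requested
# global batch size of 1000, micro_batch_size=4 and data_parallel_size=8 gives divisor
# 4 * 8 = 32, so the running global batch size is rounded down to _round(1000, 32) = 992
# and the number of microbatches is 992 // 32 = 31.
def _example_constant_calculator() -> tuple:
    calc = ConstantNumMicroBatchesCalculator(
        global_batch_size=1000,
        micro_batch_size=4,
        data_parallel_size=8,
        decrease_batch_size_if_needed=True,
        rank=1,  # non-zero rank: skip the info log
    )
    return calc.get(), calc.get_current_running_global_batch_size()  # (31, 992)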
class RampupBatchsizeNumMicroBatchesCalculator(NumMicroBatchesCalculator):
"""Calculator of number of microbatches with batch size rampup.
Over `steps = (global_batch_size - start_global_batch_size) / batch_size_increment`
increments, ramp the batch size up from `start_global_batch_size` to `global_batch_size`,
consuming `ramup_samples / steps` samples per increment.
Args:
global_batch_size (int):
Global batch size post rampup.
micro_batch_size (int):
Micro batch size.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool):
If true, decrease batch size to ensure divisibility by DP size * microbatch size
(if needed).
rank (int):
Rank (to determine whether logging should be performed).
start_global_batch_size (int):
Global batch size to start with.
batch_size_increment (int):
Global batch size increments.
ramup_samples (int):
Number of samples over which to ramp up the global
batch size from `start_global_batch_size` to `global_batch_size`.
"""
def __init__(
self,
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool,
rank: int,
start_global_batch_size: int,
batch_size_increment: int,
ramup_samples: int,
) -> None:
assert global_batch_size > 0, 'global batch size should be positive, got {}.'.format(
global_batch_size
)
assert start_global_batch_size > 0, 'start batch size should be positive, got {}.'.format(
start_global_batch_size
)
assert batch_size_increment > 0, 'batch size increment should be positive, got {}.'.format(
batch_size_increment
)
assert ramup_samples >= 0, 'ramp-up samples should be non-negative, got {}.'.format(
ramup_samples
)
self.global_batch_size = global_batch_size
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
self.decrease_batch_size_if_needed = decrease_batch_size_if_needed
self.rank = rank
self.start_global_batch_size = start_global_batch_size
self.batch_size_increment = batch_size_increment
self.ramup_samples = ramup_samples
self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
assert self.micro_batch_times_data_parallel_size > 0
self.current_global_batch_size = None
diff_batch_size = self.global_batch_size - self.start_global_batch_size
assert diff_batch_size >= 0, (
'expected global batch size to be greater than or equal to start batch size, '
f'got {self.global_batch_size} and {self.start_global_batch_size}'
)
assert diff_batch_size % batch_size_increment == 0, (
'expected '
f'global batch size interval ({diff_batch_size}) to be divisible by global batch '
f'size increment ({batch_size_increment})'
)
num_increments = diff_batch_size // self.batch_size_increment
self.rampup_samples_per_increment = self.ramup_samples / num_increments
# Initialize number of microbatches.
self.update(0, consistency_check=False, verbose=True)
def update(self, consumed_samples: int, consistency_check: bool, verbose: bool = False) -> None:
"""Update number of microbatches.
Args:
consumed_samples (int): Number of samples consumed.
consistency_check (bool): Option to check current schedule's consistency.
verbose (bool, optional): Option to control logging. Defaults to False.
"""
# Update current global batch size.
global_batch_size_changed = False
old_current_global_batch_size = self.current_global_batch_size
if consumed_samples > self.ramup_samples:
self.current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
self.current_global_batch_size = (
self.start_global_batch_size + steps * self.batch_size_increment
)
assert self.current_global_batch_size <= self.global_batch_size
if old_current_global_batch_size != self.current_global_batch_size:
global_batch_size_changed = True
if self.rank == 0 and global_batch_size_changed and verbose:
if old_current_global_batch_size is None:
logger.info(f'setting initial batch size to {self.current_global_batch_size}')
else:
logger.info(
f'ramping up batch size from {old_current_global_batch_size} to '
f'{self.current_global_batch_size}'
)
# Check consistency of the current global batch size.
if consistency_check and not self.decrease_batch_size_if_needed:
assert (
self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0
), (
'current global '
                'batch size ({}) is not divisible by micro-batch-size ({}) times '
'data parallel size ({})'.format(
self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
)
)
if (
self.decrease_batch_size_if_needed
and self.current_global_batch_size % self.micro_batch_times_data_parallel_size != 0
):
self.current_running_global_batch_size = _round(
self.current_global_batch_size, self.micro_batch_times_data_parallel_size
)
if self.rank == 0 and global_batch_size_changed and verbose:
logger.info(
f'decreasing batch size from {self.current_global_batch_size} to '
                    f'{self.current_running_global_batch_size} to keep divisibility by '
f'micro_batch_size={self.micro_batch_size} * '
f'data_parallel_size={self.data_parallel_size}'
)
assert (
self.current_running_global_batch_size % self.micro_batch_times_data_parallel_size
== 0
)
else:
self.current_running_global_batch_size = self.current_global_batch_size
self.num_micro_batches = (
self.current_running_global_batch_size // self.micro_batch_times_data_parallel_size
)
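

# Illustrative sketch (not part of the original module): a hypothetical rampup
# schedule showing how RampupBatchsizeNumMicroBatchesCalculator grows the
# number of microbatches as samples are consumed. All numbers are made up.
def _example_rampup_num_microbatches() -> None:
    calc = RampupBatchsizeNumMicroBatchesCalculator(
        global_batch_size=64,
        micro_batch_size=2,
        data_parallel_size=4,
        decrease_batch_size_if_needed=False,
        rank=0,
        start_global_batch_size=16,
        batch_size_increment=16,
        ramup_samples=300,
    )
    # At initialization (0 consumed samples): batch size 16 -> 16 // (2 * 4) = 2.
    assert calc.num_micro_batches == 2
    # Rampup has 3 increments of 16, i.e. one increment every 100 consumed samples.
    calc.update(consumed_samples=150, consistency_check=True)
    assert calc.get_current_global_batch_size() == 32 and calc.num_micro_batches == 4
    # Past ramup_samples, the full global batch size (64) is used.
    calc.update(consumed_samples=400, consistency_check=True)
    assert calc.get_current_global_batch_size() == 64 and calc.num_micro_batches == 8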
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from typing import Callable, Dict, List, Optional, Tuple
import torch
try:
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
from transformer_engine.pytorch.optimizers import FusedSGD as SGD
except ImportError:
try:
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
except ImportError:
import warnings
warnings.warn(
            'Transformer Engine and Apex are not installed. Falling back to Torch optimizers.'
)
# Apex's FusedAdam is a drop-in replacement for torch's AdamW.
# pylint: disable-next=line-too-long.
# See https://github.com/NVIDIA/apex/blob/7b73b12361068a10b0f44844534613f252a5ea75/apex/optimizers/fused_adam.py#L16.
from torch.optim import AdamW as Adam, SGD
from megatron.core import mpu
from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer
from ..transformer.module import MegatronModule
from ..utils import log_single_rank
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import (
ChainedOptimizer,
Float16OptimizerWithFloat16Params,
FP32Optimizer,
MegatronOptimizer,
)
from .optimizer_config import OptimizerConfig
logger = logging.getLogger(__name__)
def _get_param_groups(
model_chunks: List[MegatronModule],
no_weight_decay_cond: Optional[Callable],
scale_lr_cond: Optional[Callable],
lr_mult: float,
lr: float,
min_lr: float,
decoupled_lr: Optional[float],
decoupled_min_lr: Optional[float],
) -> List[Dict]:
"""Create parameter groups for optimizer.
    Creates parameter groups based on the weight decay condition (regularized vs
    non-regularized), the learning rate scale condition (lr vs lr_mult * lr),
    and whether the parameters are expert-parallel. scale_lr_cond is used during
    finetuning, where the head of the network requires a scaled version of the base
    learning rate.
Args:
model_chunks (List[MegatronModule]): model chunks to create parameter
groups for.
no_weight_decay_cond (func, optional): function to determine whether a
parameter should not perform weight decay.
scale_lr_cond (func, optional): function to determine whether a parameter
should have a scaled learning rate.
lr_mult (float): learning rate multiplier for parameters that
satisfy scale_lr_cond.
lr (float): learning rate.
min_lr (float): minimum learning rate.
decoupled_lr (Optional[float]): optional decoupled learning rate.
decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.
Returns:
List of parameter groups.
"""
use_decoupled_learning_rate = decoupled_lr is not None
# Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
params_map = {}
for model_chunk in model_chunks:
for name, param in model_chunk.named_parameters():
if not param.requires_grad:
continue
is_expert_parallel = not getattr(param, 'allreduce', True)
if no_weight_decay_cond is not None:
no_wd = no_weight_decay_cond(name, param)
else:
# Do not regularize biases and norm parameters.
no_wd = name.endswith(".bias") or len(param.shape) == 1
if scale_lr_cond is not None:
scale_lr = scale_lr_cond(name, param)
else:
scale_lr = False
if not no_wd and not scale_lr:
wd_mult, _lr_mult = 1.0, 1.0
elif not no_wd and scale_lr:
wd_mult, _lr_mult = 1.0, lr_mult
elif no_wd and not scale_lr:
wd_mult, _lr_mult = 0.0, 1.0
else:
wd_mult, _lr_mult = 0.0, lr_mult
is_decoupled_lr = False
# For input/embedding and output layer: embedding.word_embeddings.weight /
# output_layer.weight.
if use_decoupled_learning_rate and getattr(
param, 'is_embedding_or_output_parameter', False
):
is_decoupled_lr = True
key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
if key not in params_map:
params_map[key] = []
params_map[key].append(param)
param_groups = []
for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
assert len(params) > 0
param_group = {
'params': params,
'wd_mult': wd_mult,
'lr_mult': _lr_mult,
'is_expert_parallel': is_expert_parallel,
'is_decoupled_lr': is_decoupled_lr,
}
param_groups.append(param_group)
param_groups = _update_min_and_max_lr_in_param_groups(
param_groups,
lr=lr,
min_lr=min_lr,
decoupled_lr=decoupled_lr,
decoupled_min_lr=decoupled_min_lr,
)
return param_groups
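

# Illustrative sketch (not part of the original module): the optional condition
# callables accepted by _get_param_groups take (name, param) and return a bool.
# The layer-name pattern below is a hypothetical finetuning choice, not a
# Megatron default.
def _example_no_weight_decay_cond(name: str, param: torch.Tensor) -> bool:
    # Mirror the built-in fallback: skip weight decay for biases and 1-D
    # (norm) parameters.
    return name.endswith(".bias") or param.ndim == 1


def _example_scale_lr_cond(name: str, param: torch.Tensor) -> bool:
    # Hypothetical: apply lr_mult only to the output head.
    return 'output_layer' in name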
def _update_min_and_max_lr_in_param_groups(
param_groups: List[Dict],
lr: float,
min_lr: float,
decoupled_lr: Optional[float],
decoupled_min_lr: Optional[float],
) -> List[Dict]:
"""
    Updates the `max_lr` and `min_lr` values in each parameter group and returns the updated list.
By default, each group will use `lr` / `min_lr` as `max_lr` / `min_lr`.
If `decoupled_lr` is provided, then `decoupled_lr` / `decoupled_min_lr` will be used
as `max_lr` / `min_lr` for the input and output layer.
Args:
param_groups (List): parameter groups whose 'max_lr' and `min_lr` fields need to
be adjusted.
lr (float): learning rate.
min_lr (float): minimum learning rate.
decoupled_lr (Optional[float]): optional decoupled learning rate.
decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.
Returns:
List of adjusted parameter groups.
"""
if decoupled_min_lr is None:
decoupled_min_lr = min_lr
for param_group in param_groups:
if param_group['is_decoupled_lr']:
assert decoupled_lr is not None
param_group['max_lr'] = decoupled_lr
param_group['min_lr'] = decoupled_min_lr
else:
param_group['max_lr'] = lr
param_group['min_lr'] = min_lr
return param_groups
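

# Illustrative sketch (not part of the original module): with a decoupled
# learning rate, only groups flagged `is_decoupled_lr` (input embeddings and
# the output layer) pick up `decoupled_lr`; the values below are placeholders.
def _example_decoupled_lr_groups() -> None:
    groups = [
        {'params': [], 'is_decoupled_lr': False},
        {'params': [], 'is_decoupled_lr': True},
    ]
    groups = _update_min_and_max_lr_in_param_groups(
        groups, lr=3e-4, min_lr=3e-5, decoupled_lr=1e-3, decoupled_min_lr=None
    )
    assert groups[0]['max_lr'] == 3e-4 and groups[0]['min_lr'] == 3e-5
    # decoupled_min_lr falls back to min_lr when not provided.
    assert groups[1]['max_lr'] == 1e-3 and groups[1]['min_lr'] == 3e-5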
def _get_param_groups_and_buffers(
model_chunks: List[MegatronModule],
model_chunk_offset: int,
config: OptimizerConfig,
no_weight_decay_cond: Optional[Callable],
scale_lr_cond: Optional[Callable],
lr_mult: float,
filter_fn: Callable,
buffer_name: str,
) -> Tuple[List[Dict], Dict[int, List[_ParamAndGradBuffer]]]:
"""Returns parameter groups and buffer for optimizer.
Args:
model_chunks (List[MegatronModule]): model chunks to create parameter
groups for.
model_chunk_offset (int): offset of model_chunks in global model_chunks list.
config (OptimizerConfig): optimizer configuration object.
no_weight_decay_cond (func, optional): function to determine whether a
parameter should not perform weight decay.
scale_lr_cond (func, optional): function to determine whether a parameter
should have a scaled learning rate.
lr_mult (float): learning rate multiplier for parameters that
satisfy scale_lr_cond.
filter_fn (callable): filtering function for param_groups.
buffer_name (str): name of buffer.
Returns:
List of parameter groups and dictionary of model chunk IDs to buffers.
"""
param_groups = _get_param_groups(
model_chunks,
no_weight_decay_cond,
scale_lr_cond,
lr_mult,
lr=config.lr,
min_lr=config.min_lr,
decoupled_lr=config.decoupled_lr,
decoupled_min_lr=config.decoupled_min_lr,
)
param_groups = list(filter(filter_fn, param_groups))
buffers = {}
for model_chunk_idx, model_chunk in enumerate(model_chunks):
if hasattr(model_chunk, buffer_name):
buffers[model_chunk_idx + model_chunk_offset] = getattr(model_chunk, buffer_name)
return param_groups, buffers
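

# Illustrative sketch (not part of the original module): `filter_fn` is how the
# caller splits parameter groups into dense and expert-parallel subsets (see
# get_megatron_optimizer below), keyed on the `is_expert_parallel` flag set by
# _get_param_groups.
def _example_dense_and_expert_filters() -> None:
    groups = [
        {'params': [], 'is_expert_parallel': False},
        {'params': [], 'is_expert_parallel': True},
    ]
    dense = list(filter(lambda g: not g['is_expert_parallel'], groups))
    expert = list(filter(lambda g: g['is_expert_parallel'], groups))
    assert len(dense) == 1 and len(expert) == 1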
def _get_megatron_optimizer_based_on_param_groups(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
param_groups: List,
per_model_buffers: Optional[Dict[int, List[_ParamAndGradBuffer]]] = None,
model_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group_idx: Optional[int] = None,
distributed_optimizer_instance_id: Optional[int] = 0,
) -> MegatronOptimizer:
"""Get Megatron optimizer based on parameter groups.
Args:
config (OptimizerConfig): optimizer configuration object.
model_chunks (list): list of model chunks.
param_groups (list): list of parameter groups.
        per_model_buffers (dict, optional): buffers for distributed optimizer. Defaults to None.
        model_parallel_group (torch.distributed.ProcessGroup, optional): model-parallel group
            used for gradient-statistics reductions. Defaults to None.
data_parallel_group (torch.distributed.ProcessGroup, optional): data-parallel group for
distributed optimizer. Defaults to None.
data_parallel_group_gloo (torch.distributed.ProcessGroup, optional): gloo data-parallel
group for distributed optimizer. Defaults to None.
data_parallel_group_idx (int, optional): data-parallel group index for distributed
optimizer. Defaults to None.
        distributed_optimizer_instance_id (int, optional): distributed optimizer instance ID.
            Defaults to 0.
Returns:
Instance of MegatronOptimizer.
"""
    # When freezing sub-models we may have no trainable parameters on a rank and
    # hence an empty param_groups. However, we still need to create an optimizer
    # for the purposes of grad stats reductions.
if param_groups:
if config.optimizer == 'adam':
kwargs = {
"params": param_groups,
"lr": config.lr,
"weight_decay": config.weight_decay,
"betas": (config.adam_beta1, config.adam_beta2),
"eps": config.adam_eps,
}
if config.use_precision_aware_optimizer:
kwargs.update(
{
"master_weights": True,
"use_decoupled_grad": True,
"master_weight_dtype": config.main_params_dtype,
"exp_avg_dtype": config.exp_avg_dtype,
"exp_avg_sq_dtype": config.exp_avg_sq_dtype,
}
)
optimizer = Adam(**kwargs)
def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
if config is None or not config.use_precision_aware_optimizer:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
else:
opt.initialize_state(p)
elif config.optimizer == 'sgd':
optimizer = SGD(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
momentum=config.sgd_momentum,
)
init_state_fn = None
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
else:
optimizer = None
init_state_fn = None
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
# from the MixedPrecisionOptimizer, which manages any optimizer where
# the model params and main params are distinct.
if config.fp16 or config.bf16 or config.use_distributed_optimizer:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
# if we are using fp16 and loss-scale is not present, use a
# dynamic scaler.
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale.
if config.loss_scale:
grad_scaler = ConstantGradScaler(config.loss_scale)
# Dynamic loss scale.
else:
if config.fp16:
grad_scaler = DynamicGradScaler(
initial_scale=config.initial_loss_scale,
min_scale=config.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=config.loss_scale_window,
hysteresis=config.hysteresis,
)
optimizer_args = [optimizer, config, grad_scaler, init_state_fn]
if config.use_distributed_optimizer:
optimizer = DistributedOptimizer(
*optimizer_args,
model_chunks=model_chunks,
per_model_buffers=per_model_buffers,
data_parallel_group=data_parallel_group,
data_parallel_group_gloo=data_parallel_group_gloo,
data_parallel_group_idx=data_parallel_group_idx,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
)
else:
optimizer = Float16OptimizerWithFloat16Params(*optimizer_args)
setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
else:
# FP32 optimizer.
optimizer = FP32Optimizer(optimizer, config, init_state_fn)
setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
return optimizer
def get_megatron_optimizer(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
no_weight_decay_cond: Optional[Callable] = None,
scale_lr_cond: Optional[Callable] = None,
lr_mult: float = 1.0,
) -> MegatronOptimizer:
"""Retrieve the Megatron optimizer for model chunks.
We use separate optimizers for expert parameters and non-expert parameters.
Args:
config (OptimizerConfig): optimizer configuration object.
model_chunks (List[MegatronModule]): model chunks to get optimizer for.
no_weight_decay_cond (func, optional): function to determine whether a parameter
should not perform weight decay. Defaults to None.
scale_lr_cond (func, optional): function to determine whether a parameter
should have a scaled learning rate. Defaults to None.
lr_mult (float, optional): learning rate multiplier for parameters that
satisfy scale_lr_cond. Defaults to 1.0.
Returns:
Instance of MegatronOptimizer.
"""
log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')
# Separate out first model chunk if overlapping param AG with optimizer step.
if config.overlap_param_gather_with_optimizer_step:
all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]]
overlap_param_gather_with_optimizer_step_flags = [True, False]
else:
all_dense_model_chunks = [model_chunks]
overlap_param_gather_with_optimizer_step_flags = [False]
model_parallel_rank = torch.distributed.get_rank(mpu.get_model_parallel_group())
if torch.distributed.get_world_size(
mpu.get_data_parallel_group(with_context_parallel=True, partial_data_parallel=False)
) > torch.distributed.get_world_size(
mpu.get_data_parallel_group(with_context_parallel=True, partial_data_parallel=True)
):
distributed_optimizer_instance_id = torch.distributed.get_rank(
mpu.get_inter_partial_data_parallel_group()
)
else:
distributed_optimizer_instance_id = 0
optimizers = []
model_chunk_offset = 0
for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
param_groups, buffers = _get_param_groups_and_buffers(
dense_model_chunks,
model_chunk_offset=model_chunk_offset,
config=config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
filter_fn=lambda g: not g['is_expert_parallel'],
buffer_name='buffers',
)
for model_chunk in dense_model_chunks:
model_chunk.overlap_param_gather_with_optimizer_step = (
overlap_param_gather_with_optimizer_step
)
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
model_chunks=dense_model_chunks,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mpu.get_model_parallel_group(),
data_parallel_group=mpu.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
),
data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(
with_context_parallel=True, partial_data_parallel=True
),
data_parallel_group_idx=model_parallel_rank,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
)
)
model_chunk_offset += 1
moe_param_groups, moe_buffers = _get_param_groups_and_buffers(
model_chunks,
model_chunk_offset=0,
config=config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
filter_fn=lambda g: g['is_expert_parallel'],
buffer_name='expert_parallel_buffers',
)
if len(moe_param_groups) > 0:
model_parallel_rank = torch.distributed.get_rank(
mpu.get_expert_tensor_model_pipeline_parallel_group()
)
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
model_chunks=model_chunks,
param_groups=moe_param_groups,
per_model_buffers=moe_buffers,
model_parallel_group=mpu.get_expert_tensor_model_pipeline_parallel_group(),
data_parallel_group=mpu.get_expert_data_parallel_group(),
data_parallel_group_gloo=mpu.get_expert_data_parallel_group_gloo(),
data_parallel_group_idx=model_parallel_rank,
)
)
if len(optimizers) == 1:
return optimizers[0]
return ChainedOptimizer(optimizers)
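

# Illustrative sketch (not part of the original module): a typical call site,
# assuming torch.distributed and Megatron's parallel state (mpu) have already
# been initialized and `model_chunks` wraps the model (e.g., in
# DistributedDataParallel). The hyperparameter values are placeholders.
def _example_build_optimizer(model_chunks: List[MegatronModule]) -> MegatronOptimizer:
    config = OptimizerConfig(
        optimizer='adam',
        lr=3e-4,
        min_lr=3e-5,
        weight_decay=0.1,
        bf16=True,
        use_distributed_optimizer=True,
    )
    return get_megatron_optimizer(config, model_chunks)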