Commit bc5c7fa7 authored by wxj

First test commit

parent 70fddd0f
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro's cross attention modules for the decoder block."""
from functools import partial
from typing import Callable
import numpy as np
import torch
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.retro.base_attention import BaseRetroCrossAttention
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.utils import get_all_true_mask
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_block import TransformerBlock
class RetroDecoderCrossAttention(BaseRetroCrossAttention):
"""Retro decoder's chunked cross attention operator.
See this paper for more details: https://arxiv.org/abs/2112.04426.
Neighboring chunks retrieved from the chunk database are used here for
chunked-cross attention.
** Note about 'encoder_block_spec' **
Retro is an encoder-decoder model that uses its encoder for encoding
neighboring chunks that are retrieved from a chunk database. These
encoded neighbors are then used in the decoder stack for performing
chunked-cross attention (see paper link above).
In contrast to the T5 model, the encoder and decoder are computationally
intertwined, since the input to the encoder is the output of the self-
attention of the first decoder layer. As such, the encoder block itself
is instantiated within the first Retro decoder layer, in order to receive
the self-attention's output. (Note, that only the first decoder layer
instantiates an encoder block, and the remaining decoder layers use the
encoder output from the first decoder layer.)
Args:
config (RetroConfig): Retro config.
submodules (CrossAttentionSubmodules): Cross attention submodules.
layer_number (int): Layer number within transformer block.
attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
encoder_block_spec (ModuleSpec): The first Retro decoder layer is provided with a transformer block spec to construct the neighbor encoder.
"""
def __init__(
self,
config: RetroConfig,
submodules: CrossAttentionSubmodules,
layer_number: int = 1,
attn_mask_type: AttnMaskType = AttnMaskType.padding,
encoder_block_spec: ModuleSpec = None,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
)
if encoder_block_spec:
self.encoder = TransformerBlock(
config=config, spec=encoder_block_spec, pre_process=True, post_process=False,
)
# self._encoder_key = 'encoder' # ... necessary?
else:
self.encoder = None
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Tensor = None,
inference_params: InferenceParams = None,
# rotary_pos_emb: Tensor = None, # ... unsupported for retro.
) -> dict:
"""Cross attention for Retro decoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
m : Number of tokens per chunk.
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
Args:
hidden_states (Tensor): Transformer layer hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Tensor): Neighbor embeddings if first decoder layer, else encoder output.
inference_params (InferenceParams): Inference params.
Returns:
A dict consisting of the attention output and context, along with other scalars necessary for performing the downstream bias-dropout-add.
"""
# hidden_states: [ ns, bs, d ]
# key_value_states: [ r, k*bs*l, d ]
ns, bs, d = hidden_states.shape
l = int(np.ceil(ns / self.retro_chunk_length))
# Retrieve neighbors.
if self.encoder:
# Sequence length remainder.
first_ns = ns % self.retro_chunk_length
# Case 1: Sequence length not divisible by chunk length.
if first_ns > 0:
# Split sequence into first partial chunk & remaining chunks.
first_chunk, rest_chunk = hidden_states[:first_ns], hidden_states[first_ns:]
# Pad partial chunk with zeros.
first_chunk = torch.nn.functional.pad(
first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0,
)
# Concatenate padded chunk with remaining chunks.
chunked_output = torch.cat((first_chunk, rest_chunk), dim=0) # [ l*m, bs, d ]
# Case 2: Sequence length is divisible by chunk length.
else:
chunked_output = hidden_states # [ l*m, bs, d ]
# Chunk & permute hidden states.
# - hidden_states: [ l*m, bs, d ]
# - chunked_output: [ m, bs*l, d ]
chunked_output = (
chunked_output.reshape(l, self.retro_chunk_length, bs, d)
.permute(1, 2, 0, 3)
.reshape(self.retro_chunk_length, bs * l, d)
.contiguous()
)
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
chunked_output_mask = get_all_true_mask(
size=(1, 1, chunked_output.shape[0], key_value_states.shape[0]),
device=chunked_output.device,
)
# Encode neighbors. (Note: 'key_value_states' re-assigned here.)
key_value_states = self.encoder(
hidden_states=key_value_states,
attention_mask=attention_mask,
context=chunked_output,
context_mask=chunked_output_mask,
inference_params=inference_params,
) # [ r, k*bs*l, d ]
key_value_states = key_value_states.reshape(
self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d
) # [ r*k, bs*l, d ]
# Attend starting at last token of first chunk.
pad = (ns - 1) % self.retro_chunk_length
attending_chunks = hidden_states[pad:]
# Pad attending tokens to sequence length.
padded_chunks = torch.nn.functional.pad(
attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0,
)
# Permute attending chunks.
# - padded_chunks: [ l*m, bs, d ]
# - padded_chunked_output: [ m, bs*l, d ] (matches 'chunked_output' above)
padded_chunked_output = padded_chunks.reshape(l, self.retro_chunk_length, bs, d).permute(
1, 2, 0, 3
)
padded_chunked_output = padded_chunked_output.reshape(
self.retro_chunk_length, bs * l, d
).contiguous()
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
padded_chunked_output_mask = get_all_true_mask(
size=(1, 1, padded_chunked_output.shape[0], key_value_states.shape[0]),
device=padded_chunked_output.device,
)
# Attend to encoded neighbors.
attention_output, attention_bias = self.attn(
hidden_states=padded_chunked_output,
attention_mask=padded_chunked_output_mask,
key_value_states=key_value_states,
)
# Return dimensions for bias-dropout step.
return {
"ns": ns,
"bs": bs,
"d": d,
"l": l,
"pad": pad,
"attention_output": attention_output, # [ m, bs*l, d ]
"attention_bias": attention_bias, # [ d ]
"context": key_value_states, # [ r*k, bs*l, d ]
}
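# Illustrative sketch (not part of the original commit): a shape walk-through of
# the chunking in RetroDecoderCrossAttention.forward above, using small
# hypothetical sizes and assuming the sequence length divides evenly by the
# chunk length. The '_example_' name is not part of Megatron-Core.
def _example_decoder_chunking():
    import torch

    ns, bs, d, m = 48, 2, 8, 16                 # m = retro_chunk_length (hypothetical)
    l = (ns + m - 1) // m                       # chunks per sample: ceil(48 / 16) = 3
    pad = (ns - 1) % m                          # attend from last token of first chunk: 15
    hidden_states = torch.randn(ns, bs, d)      # [ l*m, bs, d ]
    chunked = (
        hidden_states.reshape(l, m, bs, d)      # [ l, m, bs, d ]
        .permute(1, 2, 0, 3)                    # [ m, bs, l, d ]
        .reshape(m, bs * l, d)                  # [ m, bs*l, d ]
        .contiguous()
    )
    assert chunked.shape == (m, bs * l, d) and pad == 15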
class RetroDecoderBiasDropoutAdd(MegatronModule):
"""Retro decoder's bias-dropout-add operator.
This operator takes care of reshaping and permuting the output from the
chunk dimension to the sequence dimension.
Args:
config (RetroConfig): Retro config.
"""
def __init__(
self, config: RetroConfig,
):
super().__init__(config=config)
self.retro_chunk_length = config.retro_chunk_length
@classmethod
def _forward(
cls,
x_with_bias: dict,
residual: Tensor,
prob: float,
retro_chunk_length: int,
bias_dropout_add: Callable,
) -> Tensor:
"""Per-chunk bias-dropout-add.
Args:
x_with_bias (dict): Attention output and bias, along with other Retro relevant parameters.
residual (Tensor): Transformer layer residual.
prob (float): Dropout probability.
retro_chunk_length (int): Retro chunk length (e.g., 64).
bias_dropout_add (Callable): Bias-dropout-add function.
Returns:
Output of bias-dropout-add.
"""
# Extract input dict.
ns = x_with_bias["ns"]
bs = x_with_bias["bs"]
d = x_with_bias["d"]
l = x_with_bias["l"]
pad = x_with_bias["pad"]
attention_output = x_with_bias["attention_output"] # [ m, bs*l, d ]
attention_bias = x_with_bias["attention_bias"] # [ d ]
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
# Bias-dropout-add.
x = bias_dropout_add(
(
attention_output,
None if attention_bias is None else attention_bias.expand_as(attention_output),
),
torch.zeros_like(attention_output),
prob,
)
# Permute chunks back to sequence dimension.
# 1. [ m, bs*l, d ]
# 2. [ m, bs, l, d ]
# 3. [ l, m, bs, d ]
# 4. [ m*l, bs, d ] == [ ns, bs, d ]
x = (
x.reshape(retro_chunk_length, bs, l, d)
.permute(2, 0, 1, 3)
.reshape(retro_chunk_length * l, bs, d)
)
# Prepend zeros for non-attending tokens.
x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0,)[
:ns
] # [ ns, bs, d ]
# Add residual. [ ns, bs, d ]
x = x + residual
# Output. [ ns, bs, d ]
return x
def forward(self, training: bool, fused: bool) -> partial:
"""Retro decoder bias-dropout-add.
Args:
training (bool): If training, then apply dropout.
fused (bool): Fuse bias-dropout-add.
Returns:
The partial function for performing bias-dropout-add.
"""
return partial(
self._forward,
retro_chunk_length=self.retro_chunk_length,
bias_dropout_add=get_bias_dropout_add(training, fused),
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Specs for Retro decoder."""
import typing
from megatron.core import parallel_state
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.decoder_attention import (
RetroDecoderBiasDropoutAdd,
RetroDecoderCrossAttention,
)
from megatron.core.models.retro.encoder_spec import get_retro_encoder_block_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TENorm,
TERowParallelLinear,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
def get_retro_decoder_layer_te_spec(
encoder_block_spec: typing.Union[ModuleSpec, TransformerBlockSubmodules, None] = None
) -> ModuleSpec:
"""Retro decoder TE spec (uses Transformer Engine components).
A Retro decoder layer uses custom attention and bias-dropout-add operators
to perform chunked-cross attention. Additionally, the first Retro decoder
layer instantiates an entire encoder transformer block. As such, the decoder
cross attention module takes an optional encoder block spec, which is only
provided for the first Retro decoder layer.
Args:
encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided for the first Retro decoder layer.
Returns:
A module spec with Transformer Engine modules.
"""
spec = get_gpt_layer_with_transformer_engine_spec()
spec.submodules.pre_cross_attn_layernorm = TENorm
spec.submodules.cross_attention = ModuleSpec(
module=RetroDecoderCrossAttention,
params={"encoder_block_spec": encoder_block_spec,},
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd)
return spec
def get_retro_decoder_layer_local_spec(
encoder_block_spec: typing.Optional[ModuleSpec] = None,
) -> ModuleSpec:
"""Retro decoder local spec (uses Megatron-Core components).
A Retro decoder layer uses custom attention and bias-dropout-add operators
to perform chunked-cross attention. Additionally, the first Retro decoder
layer instantiates an entire encoder transformer block. As such, the decoder
cross attention module takes an optional encoder block spec, which is only
provided for the first Retro decoder layer.
Args:
encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided for the first Retro decoder layer.
Returns:
A module spec with local modules.
"""
spec = get_gpt_layer_local_spec()
spec.submodules.pre_cross_attn_layernorm = FusedLayerNorm
spec.submodules.cross_attention = ModuleSpec(
module=RetroDecoderCrossAttention,
params={"encoder_block_spec": encoder_block_spec,},
submodules=CrossAttentionSubmodules(
linear_q=ColumnParallelLinear,
linear_kv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd)
return spec
def get_retro_decoder_block_spec(
config: RetroConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""Retro decoder block spec.
Retro decoder block implementation details:
- The retro decoder block consists of interleaved GPT layers and customized Retro decoder layers.
- The Retro decoder layers are spaced three layers apart, and start on layer 6 or 9 (depending on the total number of layers).
- The first decoder layer instantiates an encoder block, and it therefore passes in an encoder_block_spec.
Args:
config (RetroConfig): Retro config.
use_transformer_engine (bool): If True, use Transformer Engine (instead of local modules).
Returns:
Transformer block submodules for the given spec.
"""
# Num layers.
assert (
parallel_state.get_pipeline_model_parallel_world_size() == 1
), "retro does not currently support pipeline parallelism."
assert (
parallel_state.get_virtual_pipeline_model_parallel_world_size() is None
), "retro does not currently support virtual pipeline parallelism."
num_layers = get_num_layers_to_build(config)
# Retro layer numbers.
retro_layer_start = 6 if num_layers <= 15 else 9
retro_layer_numbers = list(range(retro_layer_start, num_layers + 1, 3))
# Layer specs.
gpt_layer_spec = (
get_gpt_layer_with_transformer_engine_spec()
if use_transformer_engine
else get_gpt_layer_local_spec()
)
get_retro_decoder_layer_spec = (
get_retro_decoder_layer_te_spec
if use_transformer_engine
else get_retro_decoder_layer_local_spec
)
retro_layer_spec = get_retro_decoder_layer_spec()
retro_layer_spec_with_retriever = get_retro_decoder_layer_spec(
get_retro_encoder_block_spec(config, use_transformer_engine)
)
layer_specs = []
for layer_number in range(1, num_layers + 1):
if layer_number == retro_layer_numbers[0]:
layer_specs.append(retro_layer_spec_with_retriever)
elif layer_number in retro_layer_numbers:
layer_specs.append(retro_layer_spec)
else:
layer_specs.append(gpt_layer_spec)
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs)
return block_spec
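# Illustrative sketch (not part of the original commit): which decoder layers
# become Retro layers for a given depth, mirroring the rule used above
# (start at layer 6 when num_layers <= 15, else layer 9, then every 3rd layer).
# The '_example_' name is hypothetical.
def _example_retro_layer_numbers(num_layers: int) -> list:
    retro_layer_start = 6 if num_layers <= 15 else 9
    return list(range(retro_layer_start, num_layers + 1, 3))

# e.g. _example_retro_layer_numbers(12) -> [6, 9, 12]
#      _example_retro_layer_numbers(24) -> [9, 12, 15, 18, 21, 24]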
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro's cross attention modules for the encoder block."""
from functools import partial
from typing import Callable, List, Optional, Tuple, Type
import torch
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.retro.base_attention import BaseRetroCrossAttention
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.utils import get_all_true_mask
from megatron.core.transformer.module import MegatronModule
class RetroEncoderCrossAttention(BaseRetroCrossAttention):
"""Retro encoder's cross attention operator.
See this paper for more details: https://arxiv.org/abs/2112.04426.
Neighboring chunks are retrieved from the chunk database, encoded, and
used by the decoder layers for chunked cross attention.
Args:
config (RetroConfig): Retro config.
submodules (CrossAttentionSubmodules): Cross attention submodules.
layer_number (int): Layer number within transformer block.
attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
"""
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Tensor = None,
inference_params: InferenceParams = None,
# rotary_pos_emb: Tensor = None, # unsupported for retro.
) -> List[Tuple[Tensor, Optional[Tensor], Tensor]]:
"""Cross attention for Retro encoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
Args:
hidden_states (Tensor): Transformer layer hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Tensor): Neighbor embeddings.
inference_params (InferenceParams): Inference params.
Returns:
List of tuples, where each tuple is (attention_output, attention_bias, residual).
"""
# Input shape. [ r, bs*l*k, d ]
ns, bs, d = hidden_states.shape
# Reshape sequence into neighboring chunks.
# - hidden_states: [ r, bs*l*k, d ]
# - chunked_outputs: [ r, bs*l, k, d ]
chunked_outputs = hidden_states.reshape(
self.retro_retrieved_length, -1, self.retro_num_neighbors, d
)
# flash attn: [ b, h, sq, sk ]
# fused attn: [ b, 1, 1, sq ]
chunked_output_mask = get_all_true_mask(
size=(1, 1, chunked_outputs.shape[0], key_value_states.shape[0]),
device=chunked_outputs.device,
)
# Per-chunk attention.
attention_output_tuples = []
for k in range(self.retro_num_neighbors):
# Attend to current neighboring chunks.
# - chunked_output: [ r, bs*l, d ]
# - key_value_states: [ m, bs*l, d ]
# - attention_output: [ r, bs*l, d ]
# - attention_bias: [ d ]
chunked_output = chunked_outputs[:, :, k].contiguous()
attention_output, attention_bias = self.attn(
hidden_states=chunked_output, # Q (neighbor embedding)
attention_mask=chunked_output_mask,
key_value_states=key_value_states, # K, V (hidden act)
)
# Residual connection. [ r, bs*l, d ]
residual = chunked_output
# Collect tensors.
attention_output_tuples.append((attention_output, attention_bias, residual,))
# Output: list of (attention_output [ r, bs*l, d ], attention_bias [ d ], residual [ r, bs*l, d ]) tuples.
return attention_output_tuples
class RetroEncoderBiasDropoutAdd(MegatronModule):
"""Retro encoder's bias-dropout-add operator.
This operator applies bias-dropout-add individually on each neighboring
chunk that is retrieved from the chunk database.
Args:
config (RetroConfig): Retro config.
"""
def __init__(
self, config: RetroConfig,
):
super().__init__(config=config)
self.retro_num_neighbors = config.retro_num_neighbors
@classmethod
def _forward(
cls,
x_with_bias: List[Tuple[Tensor, Optional[Tensor], Tensor]],
residual: Tensor,
prob: float,
retro_num_neighbors: int,
bias_dropout_add: Callable,
) -> Tensor:
"""Per-chunk bias-dropout-add.
Args:
x_with_bias (List[Tuple[Tensor, Optional[Tensor], Tensor]]): Per-neighbor tuples of attention output, bias, and residual.
residual (Tensor): Transformer layer residual.
prob (float): Dropout probability.
retro_num_neighbors (int): Number of retrieved neighbor chunks (e.g., 2).
bias_dropout_add (Callable): Bias-dropout-add function.
Returns:
Output of bias-dropout-add.
"""
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
# Per-neighbor bias-dropout-add.
# - attention_output: [ r, bs*l, d ]
# - attention_bias: [ d ]
# - residual: [ r, bs*l, d ]
# - output: [ r, bs*l, d ]
outputs = [
bias_dropout_add(
(
attention_output,
None if attention_bias is None else attention_bias.expand_as(residual),
),
residual,
prob,
)
for attention_output, attention_bias, residual in x_with_bias
]
# Concatenate outputs (to shape [r, k*bs*l, d]; see notation above).
r, _, d = outputs[0].shape
output = torch.stack(outputs, dim=1).reshape(r, -1, d)
# Output. [ r, k*bs*l, d ]
return output
def forward(self, training: bool, fused: bool) -> partial:
"""Retro decoder bias-dropout-add.
Args:
training (bool): If training, then apply dropout.
fused (bool): Fuse bias-dropout-add.
Returns:
A partial function for performing bias-dropout-add.
"""
return partial(
self._forward,
retro_num_neighbors=self.retro_num_neighbors,
bias_dropout_add=get_bias_dropout_add(training, fused),
)
class RetroEncoderLayerNorm(MegatronModule):
"""Retro encoder's layernorm operator.
This operator applies layernorm individually on each neighboring chunk that
is retrieved from the chunk database, and then concatenates the chunks into
a single tensor.
Args:
config (RetroConfig): Retro config.
submodules (Type): Layer norm class. (Named 'submodules' to fit external interface.)
"""
def __init__(
self, config: RetroConfig, submodules: Type, **kwargs: dict,
):
super().__init__(config=config)
norm_class = submodules
self.norm = norm_class(config=config, **kwargs)
self.retro_num_neighbors = config.retro_num_neighbors
def forward(self, input: Tensor) -> Tensor:
"""Per-chunk layer norm.
Args:
input (Tensor): Input chunks, concatenated into a single tensor.
Returns:
Output of the layer norm.
"""
# Input shape: [ r, k*bs*l, d ]. (see notation above in attention module)
# Split input into 'num_neighbors' tensors.
chunk_size = input.shape[1] // self.retro_num_neighbors
inputs = torch.split(input, chunk_size, dim=1)
# Norm.
outputs = [self.norm(inp.contiguous()) for inp in inputs]
# Concatenate layer norms (to shape [r, k*bs*l, d]; see notation above).
r, _, d = inputs[0].shape
output = torch.stack(outputs, dim=1).reshape(r, -1, d)
# Output. [ r, k*bs*l, d ]
return output
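# Illustrative sketch (not part of the original commit): the per-neighbor
# split / process / re-concatenate pattern shared by RetroEncoderBiasDropoutAdd
# and RetroEncoderLayerNorm above, with small hypothetical sizes.
def _example_per_neighbor_processing():
    import torch

    r, k, bs_l, d = 128, 2, 6, 8              # retrieved length, neighbors, bs*l, hidden
    x = torch.randn(r, k * bs_l, d)           # [ r, k*bs*l, d ]
    chunks = torch.split(x, bs_l, dim=1)      # k tensors of [ r, bs*l, d ]
    outputs = [c * 2.0 for c in chunks]       # stand-in for the per-neighbor operator
    out = torch.stack(outputs, dim=1).reshape(r, -1, d)
    assert out.shape == (r, k * bs_l, d)      # back to [ r, k*bs*l, d ]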
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Specs for Retro encoder."""
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.models.retro.config import RetroConfig
from megatron.core.models.retro.encoder_attention import (
RetroEncoderBiasDropoutAdd,
RetroEncoderCrossAttention,
RetroEncoderLayerNorm,
)
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TENorm,
TERowParallelLinear,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
def get_retro_encoder_layer_te_spec() -> ModuleSpec:
"""Retro encoder TE spec (uses Transformer Engine components).
A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm
operators to encode neighboring chunks that are retrieved from the chunk
database. Each operator is responsible for iterating the retrieved chunks
and processing them individually.
Returns:
A module spec with Transformer Engine modules.
"""
spec = get_gpt_layer_with_transformer_engine_spec()
spec.submodules.pre_cross_attn_layernorm = TENorm
spec.submodules.cross_attention = ModuleSpec(
module=RetroEncoderCrossAttention,
params={"attn_mask_type": AttnMaskType.padding,},
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=TENorm,)
spec.submodules.mlp = ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear,
),
)
return spec
def get_retro_encoder_layer_local_spec() -> ModuleSpec:
"""Retro encoder local spec (uses Megatron-Core components).
A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm
operators to encode neighboring chunks that are retrieved from the chunk
database. Each operator is responsible for iterating the retrieved chunks
and processing them individually.
Returns:
A module spec with local modules.
"""
spec = get_gpt_layer_local_spec()
spec.submodules.pre_cross_attn_layernorm = FusedLayerNorm
spec.submodules.cross_attention = ModuleSpec(
module=RetroEncoderCrossAttention,
params={"attn_mask_type": AttnMaskType.padding,},
submodules=CrossAttentionSubmodules(
linear_q=ColumnParallelLinear,
linear_kv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
)
spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
spec.submodules.pre_mlp_layernorm = ModuleSpec(
module=RetroEncoderLayerNorm, submodules=FusedLayerNorm,
)
spec.submodules.mlp = ModuleSpec(
module=MLP,
submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,),
)
spec.submodules.sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
} # pre_mlp_layernorm doesn't need remapping
return spec
def get_retro_encoder_block_spec(
config: RetroConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""Retro encoder block spec.
The retro encoder block consists of one customized Retro encoder layer
(layer 1), and all of the following layers are standard GPT layers.
Args:
config (RetroConfig): Retro config.
use_transformer_engine (bool): If True, use Transformer Engine (instead of local modules).
Returns:
Transformer block submodules for the given spec.
"""
# Num layers.
num_layers = config.retro_encoder_num_layers
retro_layer_numbers = [1]
# Layer specs.
gpt_layer_spec = (
get_gpt_layer_with_transformer_engine_spec()
if use_transformer_engine
else get_gpt_layer_local_spec()
)
get_retro_encoder_layer_spec = (
get_retro_encoder_layer_te_spec
if use_transformer_engine
else get_retro_encoder_layer_local_spec
)
retro_layer_spec = get_retro_encoder_layer_spec()
for spec in (gpt_layer_spec, retro_layer_spec):
spec.params["hidden_dropout"] = config.retro_encoder_hidden_dropout
spec.submodules.self_attention.params["attn_mask_type"] = AttnMaskType.padding
spec.submodules.self_attention.submodules.core_attention = ModuleSpec(
module=TEDotProductAttention if use_transformer_engine else DotProductAttention,
params={"attention_dropout": config.retro_encoder_attention_dropout,},
)
layer_specs = []
for layer_number in range(1, num_layers + 1):
if layer_number in retro_layer_numbers:
layer_specs.append(retro_layer_spec)
else:
layer_specs.append(gpt_layer_spec)
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs)
return block_spec
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro Model."""
from typing import Dict, Optional
from torch import Tensor
from megatron.core import InferenceParams
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.gpt import GPTModel
class RetroModel(GPTModel):
"""Retro Model.
A Retro model mostly re-uses the GPTModel interface, with the only difference
being the embedding of the 'context' that is used by Retro for processing
neighbor tokens. This embedded context is then forwarded to the Transformer
Block.
"""
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
context_input_ids: Tensor = None,
context_position_ids: Tensor = None,
context_mask: Tensor = None,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
) -> Tensor:
"""RetroModel forward method.
Forward input tokens & mask, along with neighbor tokens & mask, through
the Retro model.
Args:
input_ids (Tensor): Input token IDs.
position_ids (Tensor): Input position IDs.
attention_mask (Tensor): Input attention mask.
context_input_ids (Tensor): Context (i.e., neighbor) token IDs.
context_position_ids (Tensor): Context (i.e., neighbor) position IDs.
context_mask (Tensor): Context (i.e., neighbor) attention mask.
decoder_input (Tensor): When using pipeline parallelism, input_ids and position_ids will only be used on the first stage, and for all other stages decoder_input will be provided via communication from the previous stage.
labels (Tensor): The labels of dimension [batch size, seq length].
inference_params (InferenceParams): Parameters for inference.
Returns:
Output tensor of forward pass.
"""
# Argument shapes:
# Notation:
# ns : Sequence length.
# bs : Batch size.
# d : Hidden size.
# l : Number of chunks per sample (i.e., seq_length/chunk_length).
# k : Number of neighbors.
# r : Number of retrieved tokens (neighbors + continuation).
# - input_ids: [ bs, ns ]
# - context_ids: [ k*bs*l, r ]
# - context: [ r, k*bs*l, d ]
# - output: [ ns, bs, d ]
# Context embedding (e.g., for Retro neighbor tokens).
if context_input_ids is not None:
context = self.embedding(context_input_ids, context_position_ids)
else:
context = None
# Call GPTModel.forward, and pass in embedded context.
return super().forward(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
decoder_input=decoder_input,
labels=labels,
inference_params=inference_params,
extra_block_kwargs={"context": context, "context_mask": context_mask,},
)
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
"""Get sharded state dict.
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): Offsets of local shard within global tensor.
metadata (Optional[Dict]): Shard metadata.
Returns:
The sharded state dict for the Retro model.
"""
metadata = metadata or {}
metadata['non_homogeneous_layers'] = True
return super().sharded_state_dict(prefix, sharded_offsets, metadata)
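# Illustrative sketch (not part of the original commit): dummy tensors with the
# shapes RetroModel.forward expects. The sizes are hypothetical; in practice
# they come from RetroConfig and the Retro data pipeline.
def _example_retro_forward_inputs():
    import torch

    bs, ns = 2, 2048                             # batch size, sequence length
    chunk_length, k, r = 64, 2, 128              # chunk length, neighbors, retrieved tokens
    l = ns // chunk_length                       # chunks per sample
    input_ids = torch.randint(0, 1000, (bs, ns))                  # [ bs, ns ]
    context_input_ids = torch.randint(0, 1000, (k * bs * l, r))   # [ k*bs*l, r ]
    context_position_ids = torch.arange(r).expand(k * bs * l, r)
    return input_ids, context_input_ids, context_position_ids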
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
import torch
def get_config_path(project_dir: str) -> str:
"""Config copy stored within retro project dir."""
return os.path.join(project_dir, "config.json")
def get_gpt_data_dir(project_dir: str) -> str:
"""Get project-relative directory of GPT bin/idx datasets."""
return os.path.join(project_dir, "data")
# ** Note ** : Retro's compatibility between cross attention and Flash/Fused
# Attention is currently a work in progress. We default to returning None for
# now.
# def get_all_true_mask(size, device):
# return torch.full(size=size, fill_value=True, dtype=torch.bool, device=device)
def get_all_true_mask(size, device):
return None
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
import torch
from megatron.core import tensor_parallel
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.transformer.custom_layers.transformer_engine import TENorm
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
# Note: This is under development and is missing features like position embedding interpolation.
class CLIPViTModel(VisionModule):
"""CLIP ViT vision model.
Args:
transformer_config (TransformerConfig): Transformer config
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
patch_dim (int): Image patch size.
img_h (int): Input image height.
img_w (int): Input image width.
add_class_token (bool, optional): Include a class token. Defaults to True.
class_token_len (int): Class token length. Defaults to 1 but 8 may be faster.
"""
def __init__(
self,
transformer_config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
patch_dim: int = 14,
img_h: int = 336,
img_w: int = 336,
add_class_token: bool = True,
class_token_len: int = 1,
) -> None:
super().__init__(config=transformer_config)
self.visual_hidden_size = transformer_config.hidden_size
self.patch_dim = patch_dim
self.img_h = img_h
self.img_w = img_w
assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
self.num_patches_per_dim_h = self.img_h // self.patch_dim
self.num_patches_per_dim_w = self.img_w // self.patch_dim
self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
self.add_class_token = add_class_token
self.class_token_len = class_token_len
self.seq_length = self.num_patches + (self.class_token_len if self.add_class_token else 0)
self.conv1 = torch.nn.Conv2d(
in_channels=3,
out_channels=self.visual_hidden_size,
kernel_size=self.patch_dim,
stride=self.patch_dim,
bias=False,
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
self.position_embeddings = torch.nn.Embedding(self.seq_length, self.visual_hidden_size)
self.add_class_token = add_class_token
if self.add_class_token:
self.class_token = torch.nn.Parameter(
torch.randn(1, self.class_token_len, self.visual_hidden_size)
)
self.ln_pre = TENorm(
config=self.config,
hidden_size=self.visual_hidden_size,
eps=self.config.layernorm_epsilon,
)
self.model_type = ModelType.encoder_or_decoder
# Transformer + final layer norm (via post_process)
# TODO: Follow-up changes will make pre and post_process configurable. They are needed for supporting pipeline parallelism.
self.transformer = TransformerBlock(
config=transformer_config,
spec=transformer_layer_spec,
pre_process=True,
post_process=True,
)
# Note: a final linear layer present in some implementations is omitted here. It can be added separately where needed.
def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
"""Sets input tensor to the model.
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
self.transformer.set_input_tensor(input_tensor)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward function of the CLIP ViT Model. This function passes the input tensors
through the embedding layer and then the transformer.
Args:
x (torch.Tensor): input images of shape [batch, 3, img_h, img_w].
attention_mask (torch.Tensor with dtype=bool): Attention mask to use. If None, all ones.
Returns:
x (torch.Tensor): output after final transformer block of shape [b, s, h].
"""
x = self.conv1(x) # shape = [batch, hidden_size, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # [batch, hidden_size, grid ** 2]
x = x.permute(0, 2, 1) # [batch, grid ** 2, hidden_size]
if self.add_class_token:
class_token = self.class_token.expand(
x.shape[0], -1, -1
) # [batch, class_token_len, hidden_size]
x = torch.cat(
[class_token, x], dim=1
) # [batch, grid ** 2 + class_token_len, hidden_size]
x = x + self.position_embeddings(self.position_ids)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # [b, s, h] -> [s, b, h]
if attention_mask is None:
attention_mask = torch.ones(1, 1, x.shape[0], x.shape[0]).cuda() # [1, 1, s, s]
attention_mask = attention_mask < 0.5 # to bool
x = self.transformer(x.contiguous(), attention_mask)
x = x.permute(1, 0, 2) # [s, b, h] -> [b, s, h]
x = x.contiguous()
return x
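# Illustrative sketch (not part of the original commit): the sequence-length
# arithmetic used by CLIPViTModel above, evaluated for its default sizes.
def _example_clip_vit_seq_length(img_h=336, img_w=336, patch_dim=14, class_token_len=1):
    num_patches = (img_h // patch_dim) * (img_w // patch_dim)   # 24 * 24 = 576
    return num_patches + class_token_len                        # 577 with the class token

# With the defaults, the transformer therefore runs on sequences of length 577
# and forward() returns a tensor of shape [batch, 577, hidden_size].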
from megatron.core import tensor_parallel
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
class MultimodalProjector(MegatronModule):
"""
MultimodalProjector takes the encoded input with hidden size input_size and projects
it into the hidden size of the language model for multimodal training. When the
projector type is 'affine', linear_fc1 from the submodules is used.
Args:
config (TransformerConfig): Transformer config.
submodules (MLPSubmodules): Specifies MLP submodules for mlp type projector
projector_type (str): Projector type
input_size (int): Input size from feature encoder
"""
def __init__(
self,
config: TransformerConfig,
submodules: MLPSubmodules,
projector_type: str,
input_size: int,
):
super().__init__(config=config)
self.projector_type = projector_type
assert submodules is not None, "MLPSubmodules must be provided"
if self.projector_type == "mlp":
self.encoder = MLP(config=config, submodules=submodules, input_size=input_size)
elif self.projector_type == "affine":
self.encoder = build_module(
submodules.linear_fc1,
input_size,
config.hidden_size,
config=config,
init_method=config.init_method,
gather_output=True,
bias=config.add_bias_linear,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name=None,
)
else:
raise Exception(f"Unsupported multimodal projection type {self.projector_type}")
def forward(self, hidden_states):
# Run encoder.
encoder_output, encoder_output_bias = self.encoder(hidden_states)
if encoder_output_bias is not None:
encoder_output = encoder_output + encoder_output_bias
return encoder_output
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TERowParallelLinear,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_vit_layer_with_transformer_engine_spec() -> ModuleSpec:
mlp = _get_mlp_module_spec(use_te=True)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.no_mask},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
# Helper function to get module spec for MLP/MoE
def _get_mlp_module_spec(use_te: bool = True,) -> ModuleSpec:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
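# Illustrative sketch (not part of the original commit): wiring the ViT layer
# spec above into CLIPViTModel (defined earlier in this commit and assumed to be
# importable here). The TransformerConfig values are hypothetical minimums;
# actually building the model requires a GPU and Transformer Engine installed.
def _example_build_clip_vit():
    from megatron.core.transformer.transformer_config import TransformerConfig

    config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4)
    layer_spec = get_vit_layer_with_transformer_engine_spec()
    return CLIPViTModel(config, layer_spec, patch_dim=14, img_h=336, img_w=336)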
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from logging import getLogger
from typing import Callable, Dict, List, Optional
import torch
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron.core import mpu
from ..distributed import ParamAndGradBuffer
from ..transformer.module import MegatronModule
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import (
ChainedOptimizer,
Float16OptimizerWithFloat16Params,
FP32Optimizer,
MegatronOptimizer,
)
from .optimizer_config import OptimizerConfig
logger = getLogger(__name__)
def _get_param_groups(
model_chunks: List[MegatronModule],
no_weight_decay_cond: Callable,
scale_lr_cond: Callable,
lr_mult: float,
use_decoupled_learning_rate: bool,
) -> List[Dict]:
"""Create parameter groups for optimizer.
Creates parameter groups based on weight decay condition (regularized vs
non regularized), learning rate scale condition (lr vs lr_mult * lr),
and whether the parameters are expert-parallel. scale_lr_cond is used during finetuning,
where the head of the network requires a scaled version of the base learning rate.
Args:
model_chunks (List[MegatronModule]): model chunks to create parameter
groups for.
no_weight_decay_cond (func): function to determine whether a parameter
should not perform weight decay.
scale_lr_cond (func): function to determine whether a parameter
should have a scaled learning rate.
lr_mult (float): learning rate multiplier for parameters that
satisfy scale_lr_cond.
use_decoupled_learning_rate (bool): true if using decoupled learning rate.
Returns:
List of parameter groups.
"""
# Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
params_map = {}
for model_chunk in model_chunks:
for name, param in model_chunk.named_parameters():
if not param.requires_grad:
continue
is_expert_parallel = not getattr(param, 'allreduce', True)
if no_weight_decay_cond is not None:
no_wd = no_weight_decay_cond(name, param)
else:
# Do not regularize biases and norm parameters.
no_wd = name.endswith(".bias") or len(param.shape) == 1
if scale_lr_cond is not None:
scale_lr = scale_lr_cond(name, param)
else:
scale_lr = False
if not no_wd and not scale_lr:
wd_mult, lr_mult = 1.0, 1.0
elif not no_wd and scale_lr:
wd_mult, lr_mult = 1.0, lr_mult
elif no_wd and not scale_lr:
wd_mult, lr_mult = 0.0, 1.0
else:
wd_mult, lr_mult = 0.0, lr_mult
is_decoupled_lr = False
# For input/embedding and output layer: embedding.word_embeddings.weight / output_layer.weight.
if use_decoupled_learning_rate and getattr(
param, 'is_embedding_or_output_parameter', False
):
is_decoupled_lr = True
key = (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr)
if key not in params_map:
params_map[key] = []
params_map[key].append(param)
param_groups = []
for (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
assert len(params) > 0
param_groups.append(
{
'params': params,
'wd_mult': wd_mult,
'lr_mult': lr_mult,
'is_expert_parallel': is_expert_parallel,
'is_decoupled_lr': is_decoupled_lr,
}
)
return param_groups
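# Illustrative sketch (not part of the original commit): the kind of callables
# that can be passed as no_weight_decay_cond / scale_lr_cond above. The exact
# policy is up to the caller; these mirror the defaults and a hypothetical
# finetuning rule ('output_layer' is an assumed parameter-name substring).
def _example_no_weight_decay_cond(name, param) -> bool:
    # Skip weight decay for biases and 1-D (norm) parameters.
    return name.endswith(".bias") or len(param.shape) == 1

def _example_scale_lr_cond(name, param) -> bool:
    # Apply lr_mult to the output head only.
    return "output_layer" in name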
def _update_min_and_max_lr_in_param_groups(
param_groups: List[Dict],
lr: float,
min_lr: float,
decoupled_lr: Optional[float],
decoupled_min_lr: Optional[float],
) -> List[Dict]:
"""
Updates `max_lr` and `min_lr` values in each parameter group, and returns new list.
By default, each group will use `lr` / `min_lr` as `max_lr` / `min_lr`.
If `decoupled_lr` is provided, then `decoupled_lr` / `decoupled_min_lr` will be used
as `max_lr` / `min_lr` for the input and output layer.
Args:
param_groups (List): parameter groups whose 'max_lr' and `min_lr` fields need to
be adjusted.
lr (float): learning rate.
min_lr (float): minimum learning rate.
decoupled_lr (Optional[float]): optional decoupled learning rate.
decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.
Returns:
List of adjusted parameter groups.
"""
if decoupled_min_lr is None:
decoupled_min_lr = min_lr
for param_group in param_groups:
if param_group['is_decoupled_lr']:
assert decoupled_lr is not None
param_group['max_lr'] = decoupled_lr
param_group['min_lr'] = decoupled_min_lr
else:
param_group['max_lr'] = lr
param_group['min_lr'] = min_lr
return param_groups
def _get_megatron_optimizer_based_on_param_groups(
config: OptimizerConfig,
param_groups: List,
per_model_buffers: Optional[Dict[int, List[ParamAndGradBuffer]]] = None,
data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group_idx: Optional[int] = None,
) -> MegatronOptimizer:
"""Get Megatron optimizer based on parameter groups.
Args:
config (OptimizerConfig): optimizer configuration object.
param_groups (list): list of parameter groups.
per_model_buffers (dict, optional): buffers for distributed optimizer. Defaults to None.
data_parallel_group (torch.distributed.ProcessGroup, optional): data-parallel group for
distributed optimizer. Defaults to None.
data_parallel_group_gloo (torch.distributed.ProcessGroup, optional): gloo data-parallel
group for distributed optimizer. Defaults to None.
data_parallel_group_idx (int, optional): data-parallel group index for distributed
optimizer. Defaults to None.
Returns:
Instance of MegatronOptimizer.
"""
if config.optimizer == 'adam':
optimizer = Adam(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.adam_beta1, config.adam_beta2),
eps=config.adam_eps,
)
def init_state_fn(opt):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
elif config.optimizer == 'sgd':
optimizer = SGD(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
momentum=config.sgd_momentum,
)
init_state_fn = None
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
# from the MixedPrecisionOptimizer, which manages any optimizer where
# the model params and main params are distinct.
if config.fp16 or config.bf16 or config.use_distributed_optimizer:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
# if we are using fp16 and loss-scale is not present, use a
# dynamic scaler.
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale.
if config.loss_scale:
grad_scaler = ConstantGradScaler(config.loss_scale)
# Dynamic loss scale.
else:
if config.fp16:
grad_scaler = DynamicGradScaler(
initial_scale=config.initial_loss_scale,
min_scale=config.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=config.loss_scale_window,
hysteresis=config.hysteresis,
)
optimizer_args = [
optimizer,
config,
grad_scaler,
init_state_fn,
]
if config.use_distributed_optimizer:
optimizer = DistributedOptimizer(
*optimizer_args,
per_model_buffers=per_model_buffers,
data_parallel_group=data_parallel_group,
data_parallel_group_gloo=data_parallel_group_gloo,
data_parallel_group_idx=data_parallel_group_idx,
)
else:
optimizer = Float16OptimizerWithFloat16Params(*optimizer_args)
return optimizer
# FP32.
return FP32Optimizer(optimizer, config, init_state_fn,)
def get_megatron_optimizer(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
no_weight_decay_cond: Optional[Callable] = None,
scale_lr_cond: Optional[Callable] = None,
lr_mult: float = 1.0,
) -> MegatronOptimizer:
"""Retrieve the Megatron optimizer for model chunks.
We use separate optimizers for expert parameters and non-expert parameters.
Args:
config (OptimizerConfig): optimizer configuration object.
model_chunks (List[MegatronModule]): model chunks to get optimizer for.
no_weight_decay_cond (func, optional): function to determine whether a parameter
should not perform weight decay. Defaults to None.
scale_lr_cond (func, optional): function to determine whether a parameter
should have a scaled learning rate. Defaults to None.
lr_mult (float, optional): learning rate multiplier for parameters that
satisfy scale_lr_cond. Defaults to 1.0.
Returns:
Instance of MegatronOptimizer.
"""
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info(f'Setting up optimizer with {type(config).__name__}: {config}')
# Collect param groups.
param_groups = _get_param_groups(
model_chunks,
no_weight_decay_cond,
scale_lr_cond,
lr_mult,
use_decoupled_learning_rate=config.decoupled_lr is not None,
)
param_groups = _update_min_and_max_lr_in_param_groups(
param_groups,
lr=config.lr,
min_lr=config.min_lr,
decoupled_lr=config.decoupled_lr,
decoupled_min_lr=config.decoupled_min_lr,
)
# Collect grad buffers for distributed optimizer.
per_model_buffers = {}
per_model_ep_buffers = {}
for model_idx, model_chunk in enumerate(model_chunks):
if hasattr(model_chunk, 'buffers'):
per_model_buffers[model_idx] = model_chunk.buffers
per_model_ep_buffers[model_idx] = model_chunk.expert_parallel_buffers
# Split param groups into dense and MoE params (since data-parallel groups for MoE
# parameters can be different with expert parallelism).
dense_param_groups = list(filter(lambda g: not g['is_expert_parallel'], param_groups))
moe_param_groups = list(filter(lambda g: g['is_expert_parallel'], param_groups))
# Create optimizers.
model_parallel_rank = torch.distributed.get_rank(mpu.get_model_parallel_group())
optimizers = [
_get_megatron_optimizer_based_on_param_groups(
config,
param_groups=dense_param_groups,
per_model_buffers=per_model_buffers,
data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(with_context_parallel=True),
data_parallel_group_idx=model_parallel_rank,
)
]
if len(moe_param_groups) > 0:
model_parallel_world_size = torch.distributed.get_world_size(mpu.get_model_parallel_group())
expert_parallel_rank = mpu.get_expert_model_parallel_rank()
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
param_groups=moe_param_groups,
per_model_buffers=per_model_ep_buffers,
data_parallel_group=mpu.get_data_modulo_expert_parallel_group(),
data_parallel_group_gloo=mpu.get_data_modulo_expert_parallel_group_gloo(),
data_parallel_group_idx=expert_parallel_rank * model_parallel_world_size
+ model_parallel_rank,
)
)
if len(optimizers) == 1:
return optimizers[0]
return ChainedOptimizer(optimizers)
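# Illustrative sketch (not part of the original commit): a minimal call into
# get_megatron_optimizer, assuming Megatron parallel state is initialized and
# that `model_chunks` / `config` are a prepared List[MegatronModule] and an
# OptimizerConfig whose fields match those referenced above.
def _example_setup_optimizer(model_chunks, config):
    return get_megatron_optimizer(
        config,
        model_chunks,
        no_weight_decay_cond=None,   # fall back to the default bias/norm rule
        scale_lr_cond=None,
        lr_mult=1.0,
    )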
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Gradient clipping."""
import os
from typing import List, Optional, Union
import amp_C
import torch
from apex.multi_tensor_apply import multi_tensor_applier
from torch import inf
from ..tensor_parallel import param_is_not_tensor_parallel_duplicate
from ..transformer.module import param_is_not_shared
def clip_grad_norm_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
grads_for_norm: Union[List[torch.Tensor], torch.Tensor],
max_norm: Union[int, float],
norm_type: Union[int, float] = 2,
model_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> float:
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized.
grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
Tensor that will be used for calculating the grad norm.
max_norm (float or int): max norm of the gradients.
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
model_parallel_group (torch.distributed.ProcessGroup, optional): model-parallel
group over which grad norm needs to be aggregated.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
if isinstance(grads_for_norm, torch.Tensor):
grads_for_norm = [grads_for_norm]
# Grads.
grads = []
for param in parameters:
if param.grad is not None:
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(param.grad.detach())
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0.0
# Calculate norm.
if norm_type == inf:
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.tensor([float(total_norm)], dtype=torch.float, device='cuda')
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(
total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=model_parallel_group
)
total_norm = total_norm_cuda[0].item()
else:
if norm_type == 2.0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
# Use apex's multi-tensor applier for efficiency reasons.
# Multi-tensor applier takes a function and a list of list
# and performs the operation on that list all in one kernel.
if grads_for_norm:
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_for_norm],
False, # no per-parameter norm
)
else:
grad_norm = torch.tensor([0], dtype=torch.float, device='cuda')
# Since we will be summing across data parallel groups,
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
total_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group
)
total_norm = total_norm.item() ** (1.0 / norm_type)
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
multi_tensor_applier(
amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff
)
return total_norm
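# Illustrative sketch (not part of the original commit): the same L2 clipping
# math in plain PyTorch, without apex's fused multi-tensor kernels and without
# the model-parallel all-reduce. A single-GPU reference only.
def _example_clip_grad_norm_reference(grads, max_norm):
    import torch

    total_norm = torch.norm(torch.stack([torch.norm(g, 2.0) for g in grads]), 2.0)
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm.item()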
def count_zeros_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
model_parallel_group: torch.distributed.ProcessGroup,
) -> float:
"""Counts the number of zeros in gradients associated with the passed-in list of
parameters.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have the number of zeros in its corresponding
gradient counted.
model_parallel_group (torch.distributed.ProcessGroup, optional): model-parallel
group over which grad norm needs to be aggregated.
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
total_num_zeros = torch.tensor([0.0], dtype=torch.float, device='cuda')
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grad = param.grad.detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group
)
total_num_zeros = total_num_zeros.item()
return total_num_zeros
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron distributed optimizer."""
import itertools
from logging import getLogger
from typing import Callable, Dict, List, Optional, Tuple
import torch
from apex.optimizers import FusedAdam as Adam
from .. import parallel_state, tensor_parallel
from ..dist_checkpointing import ShardedTensor
from ..dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject, ShardedStateDict
from ..distributed import ParamAndGradBuffer, shard_buffer
from .grad_scaler import MegatronGradScaler
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
from .optimizer_config import OptimizerConfig
logger = getLogger(__name__)
class Range:
"""
A range represents a start and end points for indexing a shard
from a full tensor.
"""
def __init__(self, start: int, end: int):
self.start = start
self.end = end
self.size = end - start
def normalize(self, start: int = 0):
return Range(start, start + self.size)
def __str__(self):
return "%d,%d [%d]" % (self.start, self.end, self.size)
def __len__(self):
return self.end - self.start
class DistributedOptimizer(MixedPrecisionOptimizer):
@classmethod
def _build_model_gbuf_param_range_map(
cls,
param_world_index_map: Dict[torch.nn.Parameter, Tuple],
gbuf_world_range: Range,
bucket_offset: int,
):
"""
Build mapping from param reference to grad buffer shard ranges.
This method builds a mapping from parameter references to grad
buffer shard ranges, specific to each data-parallel (DP) rank's
set of 'owned' parameters. Each grad buffer (padded to be an even
multiple of DP-world-size) is conceptually divided into DP-world-size
contiguous regions, where each DP rank 'owns' a contiguous region.
Ownership in this sense means DP rank is responsible for reducing
the relevant subset of grads, and updating the relevant subset of
params.
This conceptual partitioning of the grad buffer does NOT respect
parameter boundaries, and as such it is assumed that each created
range references a shard (or subset) of the full parameter. It is
easiest to think of each DP rank as operating (i.e., reducing,
gathering) purely on views into the grad buffer, for all model-to-
main & main-to-model operations.
This method creates four ranges:
- The param's range within the entire grad buffer (i.e., world index).
- The param's range within the relevant grad bucket's buffer.
- The param's range within the DP rank's local view of the grad buffer.
- The param's range within itself (i.e., its shard).
"""
# Param range map.
param_range_map = {}
for param, param_world_indexes in param_world_index_map.items():
# Param range.
param_world_start, param_world_end, _ = param_world_indexes
param_local_start = max(0, param_world_start - gbuf_world_range.start)
param_local_end = min(gbuf_world_range.size, param_world_end - gbuf_world_range.start)
# Add param, if within local gbuf range.
if param_local_end > param_local_start:
param_local_range = Range(param_local_start, param_local_end)
param_world_range = param_local_range.normalize(
param_local_start + gbuf_world_range.start
)
param_world_range_in_bucket = Range(
param_world_range.start - bucket_offset, param_world_range.end - bucket_offset
)
sub_param_start = max(0, gbuf_world_range.start - param_world_start)
sub_param_range = param_local_range.normalize(sub_param_start)
param_range_map[param] = {
"gbuf_world": param_world_range,
"gbuf_world_in_bucket": param_world_range_in_bucket,
"gbuf_local": param_local_range,
"param": sub_param_range,
}
return param_range_map
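# Worked example for the method above (illustrative, assumed numbers): suppose this
# DP rank owns gbuf_world_range = Range(256, 512) with bucket_offset = 0, and a
# parameter occupies world indexes [200, 400) of the grad buffer. The entry created
# for that parameter would be:
#   "gbuf_world"           -> Range(256, 400)   # overlap, in world coordinates
#   "gbuf_world_in_bucket" -> Range(256, 400)   # same, shifted by the bucket offset
#   "gbuf_local"           -> Range(0, 144)     # overlap, in this rank's local view
#   "param"                -> Range(56, 200)    # the shard of the parameter itself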
@classmethod
def _build_model_gbuf_range(cls, param_and_grad_buffer: ParamAndGradBuffer, bucket_index: int):
"""
Build mapping between params and their grad buffers.
This method does the initial setup for the method above. This setup
includes determining the shard ranges into the param_and_grad_buffer
for each data-parallel (DP) rank. Each DP rank keeps range info for
all other DP ranks, for the purpose of creating args for
reduce-scatter and all-gather.
"""
data_parallel_rank = torch.distributed.get_rank(param_and_grad_buffer.data_parallel_group)
data_parallel_world_size = param_and_grad_buffer.data_parallel_group.size()
bucket = param_and_grad_buffer.buckets[bucket_index]
gbuf_size = bucket.grad_data.numel()
assert (
gbuf_size % data_parallel_world_size == 0
), f"Each bucket's buffer size should be divisible by {data_parallel_world_size}"
max_gbuf_range_size = gbuf_size // data_parallel_world_size
# All world ranges (i.e., across all data parallel ranks).
gbuf_world_all_ranges = []
for r in range(data_parallel_world_size):
# Compute start of chunk in this bucket.
gbuf_world_start = r * max_gbuf_range_size
gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_range_size)
# Add bucket's offset in grad buffer.
gbuf_world_range = Range(
gbuf_world_start + bucket.offset, gbuf_world_end + bucket.offset
)
gbuf_world_all_ranges.append(gbuf_world_range)
# Local DP's ranges.
gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
# Get each param's ranges.
param_range_map = cls._build_model_gbuf_param_range_map(
param_and_grad_buffer.param_index_map, gbuf_world_range, bucket.offset
)
# Group into dict.
data = {
"param_map": param_range_map,
}
return data
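# Sketch of the sharding arithmetic above (assumed numbers): with a bucket of
# gbuf_size = 1024 elements, bucket.offset = 4096, and data_parallel_world_size = 4,
# max_gbuf_range_size is 256 and the per-rank world ranges (bucket offset included) are:
#   rank 0 -> Range(4096, 4352)
#   rank 1 -> Range(4352, 4608)
#   rank 2 -> Range(4608, 4864)
#   rank 3 -> Range(4864, 5120)
# Only the local rank's range is then used to build the per-param range map.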
@classmethod
def _build_gbuf_range_map(cls, param_and_grad_buffer: ParamAndGradBuffer):
"""
Build mapping between params and their grad buffers. These mappings are
partitioned according to data type.
Iterate through all buckets of grad buffer to construct param ranges
that this rank "owns" (the dp_rank'th shard of each bucket, where each
shard is 1/dp_world_size of the bucket).
Args:
param_and_grad_buffer (ParamAndGradBuffer): buffer to build mapping for.
"""
return {
(param_and_grad_buffer.param_dtype, param_and_grad_buffer.grad_dtype): [
cls._build_model_gbuf_range(param_and_grad_buffer, bucket_index)
for bucket_index in range(len(param_and_grad_buffer.buckets))
]
}
@classmethod
def _build_model_param_gbuf_map(
cls, gbuf_ranges: List[Dict]
) -> Dict[torch.nn.Parameter, Tuple]:
"""
Create a reverse mapping of gbuf_ranges, for referencing in the opposite direction.
"""
param_gbuf_map = {}
for gbuf_index, gbuf_range_map in enumerate(gbuf_ranges):
for dtype, gbuf_range_map_for_all_buckets in gbuf_range_map.items():
for bucket_index, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
for param, _ in gbuf_range_map["param_map"].items():
assert (
param not in param_gbuf_map
), "Param should not be in param_gbuf_map; each param only belongs to a single bucket"
param_gbuf_map[param] = (gbuf_index, dtype, bucket_index)
return param_gbuf_map
@classmethod
def _build_optimizer_group_ranges(cls, param_groups: List[Dict], gbuf_ranges: List[Dict]):
"""
Create optimizer groups.
Given the set of parameter shard ranges that are owned by the current
data-parallel (DP) rank, gather the set of parameters that will be
used (in the method below) to create the current DP's optimizer
groups.
"""
# Param group map.
# World param group map.
# - Store a mapping of <model_parameter:group_index> for all parameters
# across all DP ranks. This is necessary because it is our first
# cross reference between the DDP mappings and the optimizer group
# parameters. This mapping is only for use in the next step of building
# the local mapping over this DP rank's parameters.
world_param_group_map = {}
for group_index, group in enumerate(param_groups):
for param in group["params"]:
assert param.requires_grad
world_param_group_map[param] = group_index
# Optimizer group ranges & param-group mapping.
# - Build a mapping from groups to their contained parameters, and also
# from parameters to their containing group index and order within
# the group. The group index and order are particularly important for
# saving and loading checkpoints.
local_param_group_map = {}
group_ranges = [{"params": []} for _ in param_groups]
for gbuf_range_map in gbuf_ranges:
for dtype, gbuf_range_map_for_all_buckets in gbuf_range_map.items():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for param in gbuf_range_map["param_map"]:
group_index = world_param_group_map[param]
group_range = group_ranges[group_index]
group_range["params"].append(param)
local_param_group_map[param] = (group_index, len(group_range["params"]) - 1)
# Record the original param group (and its index) for each group range.
for group_index, group_range in enumerate(group_ranges):
group_range["orig_group"] = param_groups[group_index]
group_range["orig_group_idx"] = group_index
return local_param_group_map, group_ranges
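# Illustrative example of the returned structures (assumed setup): with two optimizer
# param groups, and this DP rank owning shards of p0 and p2 from group 0 and p5 from
# group 1, the method would produce roughly:
#   local_param_group_map = {p0: (0, 0), p2: (0, 1), p5: (1, 0)}
#   group_ranges          = [{"params": [p0, p2], "orig_group": ...},
#                            {"params": [p5],     "orig_group": ...}]
# The (group_index, group_order) pairs are later used to line the local shards up
# with the inner optimizer's state ordering when saving/loading checkpoints.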
@classmethod
def _build_model_and_main_param_groups(
cls,
gbuf_ranges: List[Dict],
param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
opt_group_ranges: List,
):
"""
Create main parameter groups needed for the optimizer step.
These groups encompass both: 1) groups used by this class, for
reducing/gathering, and 2) groups used by the inner optimizer for the
parameter update. Given that the conceptual grad buffer partitioning
(created in earlier method) doesn't respect parameter boundaries,
the optimizer operates on shards of the model parameters, rather than
the full parameters.
"""
# Parameter groups:
# model_float16_groups: original float16 parameters
# model_fp32_groups: original fp32 parameters
# shard_float16_groups: shards of original float16 parameters
# shard_fp32_groups: shards of original fp32 parameters
# shard_fp32_from_float16_groups: fp32 copy of float16 parameters
model_float16_groups = []
model_fp32_groups = []
shard_float16_groups = []
shard_fp32_groups = []
shard_fp32_from_float16_groups = []
# Allocate (or slice) each group's param shard.
for group_range in opt_group_ranges:
# Params of this group.
model_float16_params_this_group = []
model_fp32_params_this_group = []
shard_float16_params_this_group = []
shard_fp32_params_this_group = []
shard_fp32_from_float16_params_this_group = []
model_float16_groups.append(model_float16_params_this_group)
model_fp32_groups.append(model_fp32_params_this_group)
shard_float16_groups.append(shard_float16_params_this_group)
shard_fp32_groups.append(shard_fp32_params_this_group)
shard_fp32_from_float16_groups.append(shard_fp32_from_float16_params_this_group)
for model_param in group_range["params"]:
assert model_param.requires_grad
gbuf_index, dtype, bucket_index = param_gbuf_map[model_param]
gbuf_range = gbuf_ranges[gbuf_index][dtype][bucket_index]
param_range = gbuf_range["param_map"][model_param]["param"]
# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
# Clone model -> main.
shard_model_param = model_param.detach().view(-1)[
param_range.start : param_range.end
]
shard_main_param = shard_model_param.clone().float()
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param
)
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_main_param, model_param
)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
shard_main_param.shared = model_param.shared
# Add to group.
model_float16_params_this_group.append(model_param)
shard_float16_params_this_group.append(shard_model_param)
shard_fp32_from_float16_params_this_group.append(shard_main_param)
# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
model_fp32_params_this_group.append(model_param)
shard_fp32_params_this_group.append(shard_model_param)
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param
)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(model_param.type())
)
# Update optimizer's params.
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_fp32_from_float16_params_this_group,
]
return (
model_float16_groups,
model_fp32_groups,
shard_float16_groups,
shard_fp32_groups,
shard_fp32_from_float16_groups,
)
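# Minimal sketch of the model -> main shard creation for a float16 param (assumed
# shapes, outside of any parallel context):
#   model_param = torch.nn.Parameter(torch.zeros(1024, dtype=torch.half, device='cuda'))
#   param_range = Range(256, 512)   # this rank's slice of the flattened param
#   shard_model_param = model_param.detach().view(-1)[param_range.start:param_range.end]
#   shard_main_param = shard_model_param.clone().float()   # fp32 "main" copy
# The inner optimizer steps only on shard_main_param; the updated values are later
# copied back into the param buffer and all-gathered across DP ranks.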
def __init__(
self,
optimizer: torch.optim.Optimizer,
config: OptimizerConfig,
grad_scaler: MegatronGradScaler,
init_state_fn: Optional[Callable],
per_model_buffers: Dict[int, List[ParamAndGradBuffer]],
data_parallel_group: torch.distributed.ProcessGroup,
data_parallel_group_gloo: torch.distributed.ProcessGroup,
data_parallel_group_idx: int,
):
"""
Distributed optimizer, for all data types (fp16, bf16, and fp32).
The steps in this method create the core mapping between param and grad buffers,
parameters, and parameter shard ranges that is needed for converting between model
param indexes and main parameter shard indexes. This method also updates the optimizer
parameter groups with the newly created shards.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
this can be None: this happens when `bf16 = True` and no loss scale is
used. For `bf16 = True`, a constant gradient scaler can also be used;
for `bf16 = False`, a grad scaler is always required.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
per_model_buffers (Dict[int, List[ParamAndGradBuffer]]): the implementation of the
distributed optimizer is centered on using a contiguous buffer for
communicating grads & params between the model state and the optimizer state.
You can find a more detailed description in
https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/distrib_optimizer.md.
data_parallel_group (torch.distributed.ProcessGroup): data-parallel group to use to
all-gather params after optimizer.step().
data_parallel_group_gloo (torch.distributed.ProcessGroup): gloo data-parallel group
(used in checkpoint loading and saving).
data_parallel_group_idx (int): index in data-parallel group (used by
distributed checkpointing logic).
"""
super().__init__(
optimizer, config, grad_scaler, init_state_fn,
)
assert isinstance(
optimizer, Adam
), "Only Adam currently supported, due to checkpointing requirements."
# Model grad buffer ranges.
assert per_model_buffers is not None, "per_model_buffers must be provided"
self.buffers = list(itertools.chain(*per_model_buffers.values()))
self.per_model_buffers = per_model_buffers
self.data_parallel_group = data_parallel_group
self.data_parallel_group_gloo = data_parallel_group_gloo
self.data_parallel_group_idx = data_parallel_group_idx
self.gbuf_idx_to_model_idx_map = {}
gbuf_idx = 0
for model_idx, buffers in self.per_model_buffers.items():
for _ in buffers:
self.gbuf_idx_to_model_idx_map[gbuf_idx] = model_idx
gbuf_idx += 1
self.gbuf_ranges = []
self.per_bucket_numel = []
self.per_bucket_numel_unpadded = []
for buffer in self.buffers:
self.per_bucket_numel.append(
{
(buffer.param_dtype, buffer.grad_dtype): [
bucket.grad_data.numel() for bucket in buffer.buckets
]
}
)
self.per_bucket_numel_unpadded.append(
{
(buffer.param_dtype, buffer.grad_dtype): [
bucket.numel_unpadded for bucket in buffer.buckets
]
}
)
self.gbuf_ranges.append(self._build_gbuf_range_map(buffer))
self.model_param_gbuf_map = self._build_model_param_gbuf_map(self.gbuf_ranges)
# Optimizer ranges.
(
self.model_param_group_index_map,
self.opt_group_ranges,
) = self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges)
# Allocate main param shards.
(
self.model_float16_groups,
self.model_fp32_groups,
self.shard_float16_groups,
self.shard_fp32_groups,
self.shard_fp32_from_float16_groups,
) = self._build_model_and_main_param_groups(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges
)
# Now construct data structures to manage all-gather handles.
self.all_gather_handles = []
self.all_gather_handle_index_to_bucket_index_map = []
self.model_index_to_all_gather_handle_index_map = {}
self.all_gather_handle_indices = []
self.param_to_all_gather_handle_index_map = {}
self.pbuf_view_items = self._get_model_param_buffer_dp_views()
for (gbuf_index, dtype, bucket_index, _, _) in self.pbuf_view_items:
self.all_gather_handle_index_to_bucket_index_map.append(
(gbuf_index, dtype, bucket_index)
)
all_gather_handle_index = len(self.all_gather_handle_index_to_bucket_index_map) - 1
self.all_gather_handles.append(None)
# Store all all_gather_handle_indices.
model_idx = self.gbuf_idx_to_model_idx_map[gbuf_index]
if model_idx not in self.model_index_to_all_gather_handle_index_map:
self.model_index_to_all_gather_handle_index_map[model_idx] = []
self.model_index_to_all_gather_handle_index_map[model_idx].append(
all_gather_handle_index
)
for param in self.buffers[gbuf_index].buckets[bucket_index].params_list:
self.param_to_all_gather_handle_index_map[param] = all_gather_handle_index
self.num_all_gather_handles = len(self.all_gather_handle_index_to_bucket_index_map)
self.overlap_param_gather = self.config.overlap_param_gather
self.remove_pre_hook_handle = None
if self.overlap_param_gather:
self.enable_pre_hook()
self.update_successful = False
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())
def enable_pre_hook(self):
"""
Enable forward pre-hook needed for param all-gather overlap with forward compute.
"""
assert self.remove_pre_hook_handle is None
self.remove_pre_hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(
self._make_forward_pre_hook()
)
def disable_pre_hook(self):
"""
Disable forward pre-hook needed for param all-gather overlap with forward compute.
"""
assert self.remove_pre_hook_handle is not None
self.remove_pre_hook_handle.remove()
self.remove_pre_hook_handle = None
# Make sure all-gathers are completed as needed.
self._reset_metadata_and_sync_gather_all_model_params(force_sync=True)
def _get_model_param_range_map(self, param: torch.nn.Parameter):
"""
Given a model param, get the index sub-range of the param that this
data-parallel rank owns.
"""
gbuf_index, dtype, bucket_index = self.model_param_gbuf_map[param]
gbuf_range_map = self.gbuf_ranges[gbuf_index][dtype][bucket_index]
param_range_map = gbuf_range_map["param_map"][param]
return param_range_map
def get_model_parallel_group(self) -> torch.distributed.ProcessGroup:
"""
With the distributed optimizer, the model parallel group is the
entire world.
"""
return None
def state_dict(self):
"""
The state dict contains all non-DP-rank-dependent (i.e., non-parameter-
related) optimizer variables. The returned state dict can be stored in
the standard model/RNG checkpoint file. The parameter and dependent
optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate
checkpoint file by calling 'save_parameter_state()'.
"""
state_dict = {}
# Optimizer state (do not store parameter state here).
state_dict['optimizer'] = {
k: v for k, v in self.optimizer.state_dict().items() if k != "state"
}
for param_group in state_dict["optimizer"]["param_groups"]:
del param_group["params"]
# Grad scaler state.
if self.grad_scaler:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
return state_dict
def load_state_dict(self, state_dict):
"""Load the state dict.
As detailed in state_dict(), the state dict contains all non-
parameter-related variables. This method is notably longer than
state_dict(), because the Torch optimizer's state has yet to be
allocated at this point, and so we must cross-reference the
optimizer's state (and the ordering it expects for parameter state)
with this DP rank's shards. The optimizer at this point does not contain
any tensor dimension information, so we must get these dimensions from
the DP shards mapped during DistributedOptimizer.__init__().
The tensor parameter state is loaded via load_parameter_state(), and
so this method also must populate the loaded state dict with dummy
tensor data (i.e., via torch.empty() below). This will be overwritten
during load_parameter_state().
** Note: Torch optimizer's state structure. **
The Torch optimizer stores its state in two levels. The top level is a
list of groups, where each group contains a list of integer indexes
(corresponding to parameters) that index into a master parameter list
that is shared by all groups. As such, three values are necessary for
maintaining this ordering:
- group_index : The group to which a parameter belongs.
- group_order : The index of a parameter within its group.
- state_order : The index of a parameter within the shared parameter
list.
"""
# Get the Torch optimizer's state dict.
# - This 'inner' optimizer at this point is unallocated, and only
# contains an integer ordering of parameters within each group, and
# the ordering of parameters within its flattened parameter state
# list.
inner_state_dict = self.optimizer.state_dict()
state_dict_param_groups = [
{**group, "params": list(inner_state_dict["param_groups"][idx]["params"]),}
for idx, group in enumerate(state_dict["optimizer"]["param_groups"])
]
# Allocate 'dummy' data for optimizer state (i.e., torch.empty() below)
# - Real data is overwritten during load_parameter_state().
state_dict_state = []
for gbuf_range_maps in self.gbuf_ranges:
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Get parameter ordering information (see method docstring
# for details).
group_index, group_order = self.model_param_group_index_map[model_param]
state_order = inner_state_dict["param_groups"][group_index]["params"][
group_order
]
# Allocate dummy tensors.
numel = len(param_range_map["gbuf_world"])
init_shard = lambda: torch.empty(
(numel,), dtype=torch.float32, device=torch.cuda.current_device()
)
state_dict_state.append(
(state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard(),})
)
# Sort by state order (see method docstring for details).
state_dict_state.sort(key=lambda s: s[0])
state_dict_state = {s[0]: s[1] for s in state_dict_state}
# Optimizer.
self.optimizer.load_state_dict(
{"state": state_dict_state, "param_groups": state_dict_param_groups,}
)
# Grad scaler.
if 'grad_scaler' not in state_dict:
if self.config.fp16:
logger.info(
'***WARNING*** found an old checkpoint, will not ' 'load grad scaler ...'
)
else:
if self.grad_scaler:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
else:
logger.info(
'***WARNING*** found the grad scaler in the '
'checkpoint but it is None in the class. '
'Skipping loading grad scaler ...'
)
if 'param_state' in state_dict:
assert 'param_state_sharding_type' in state_dict, state_dict.keys()
param_state = state_dict['param_state']
sharding_type = state_dict['param_state_sharding_type']
logger.info(f'Loading distributed optimizer sharded state of type {sharding_type}')
if sharding_type == 'dp_zero_gather_scatter':
self.load_parameter_state_from_dp_zero(param_state)
elif sharding_type == 'fully_sharded_bucket_space':
self.load_parameter_state_from_fs_bucket_space(param_state)
else:
raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')
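# Illustrative example of the three orderings described in the docstring above
# (assumed values): for an inner optimizer with
#   inner_state_dict["param_groups"] == [{"params": [0, 1]}, {"params": [2]}],
# a parameter with group_index = 1 and group_order = 0 has
#   state_order = inner_state_dict["param_groups"][1]["params"][0] == 2,
# i.e. it is entry 2 in the flat optimizer state, which is where the dummy
# exp_avg / exp_avg_sq tensors allocated above are placed.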
def get_parameter_state_fs_bucket_space(self):
"""Get internal representation of parameter state without any copies and modifications.
This is referred to as "fully sharded bucket space" because the optimizer state is
fully sharded (i.e., no gather involved) and bucket-centric (the state
follows the internal structure of the Distributed Optimizer buckets),
as opposed to model-centric (the typical structure of PyTorch optimizers).
"""
state = {
"per_bucket_numel": self.per_bucket_numel,
"per_bucket_numel_unpadded": self.per_bucket_numel_unpadded,
}
for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
# Iterate grad buffers (by data type).
dtype_state = {}
assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
buckets_state = []
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
bucket_state = []
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {
"param": main_param,
**optim_state,
"gbuf_local_start": param_range_map["gbuf_local"].start,
"gbuf_local_end": param_range_map["gbuf_local"].end,
}
bucket_state.append(tensors)
buckets_state.append(bucket_state)
dtype_state[dtype] = buckets_state
state[gbuf_idx] = dtype_state
return state
def get_parameter_state_dp_zero(self):
"""Get parameter state (i.e., parameter & optimizer tensors).
This method performs two steps:
- For each DP rank, copy param & optimizer shards to contiguous CPU
buffers (e.g., one buffer each for main_param, exp_avg, and
exp_avg_sq).
- Gather contiguous buffers on DP rank 0 and concatenate to world
buffers.
"""
# Data parallelism variables.
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
self.data_parallel_group_gloo
)
# Collect param states.
state = {
"per_bucket_numel": self.per_bucket_numel,
"per_bucket_numel_unpadded": self.per_bucket_numel_unpadded,
}
for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
# Iterate grad buffers (by data type).
dtype_state = {}
assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
world_tensors = {}
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
# Compute local DP contiguous shard's size.
gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
assert gbuf_world_numel % data_parallel_world_size == 0
gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
local_shards = {
key: torch.empty((gbuf_local_numel,), dtype=torch.float32, device="cpu")
for key in ("param", "exp_avg", "exp_avg_sq")
}
# Build contiguous DP rank shards (for param + optim states).
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {
"param": main_param,
**optim_state,
}
# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
for key in local_shards:
local_shards[key][gbuf_local_start:gbuf_local_end].data.copy_(
tensors[key].detach().cpu()
)
# Gather contiguous shards on DP rank 0.
for key, send_tensor in local_shards.items():
# Gather tensor list.
if data_parallel_rank == 0:
recv_tensors = [
torch.empty((gbuf_local_numel,), dtype=torch.float32, device="cpu")
for _ in range(data_parallel_world_size)
]
else:
recv_tensors = None
# Gather.
torch.distributed.gather(
send_tensor,
recv_tensors,
data_parallel_global_ranks[0],
data_parallel_group_gloo,
)
# Concatenate.
if data_parallel_rank == 0:
if key not in world_tensors:
world_tensors[key] = []
world_tensors[key].append(torch.cat(recv_tensors))
# Collect world state.
dtype_state[dtype] = world_tensors
state[gbuf_idx] = dtype_state
return state
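# Sketch of the structure returned on DP rank 0 (assumed sizes): with
# data_parallel_world_size = 4 and a single bucket of gbuf_local_numel = 256 per
# rank, each of world_tensors["param"], ["exp_avg"] and ["exp_avg_sq"] holds one
# concatenated CPU tensor of 4 * 256 = 1024 fp32 elements for that bucket; ranks
# other than 0 only contribute their shards via torch.distributed.gather and keep
# nothing locally.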
def save_parameter_state(self, filename: str):
"""Save the distributed parameter state on DP rank 0.
Args:
filename (str): path to save parameter state to.
"""
state_dict = self.get_parameter_state_dp_zero()
if torch.distributed.get_rank(self.data_parallel_group) == 0:
torch.save(state_dict, filename)
def sharded_state_dict(
self,
model_sharded_state_dict: ShardedStateDict,
is_loading: bool = False,
sharding_type: str = 'fully_sharded_bucket_space',
):
"""
Chooses between 3 param state sharding implementations as requested by `sharding_type`.
Regular state dict parameters are saved on DP rank 0 and loaded on all ranks.
"""
state_dict = {
k: ShardedObject(
f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.{k}',
v,
(1,),
(0,),
replica_id=torch.distributed.get_rank(self.data_parallel_group),
)
for k, v in self.state_dict().items()
}
if is_loading:
self.init_state_fn(self.optimizer)
if sharding_type == 'fully_sharded_bucket_space':
param_state = self.sharded_param_state_fs_bucket_space(
model_sharded_state_dict, is_loading
)
elif sharding_type == 'dp_zero_gather_scatter':
param_state = self.sharded_param_state_dp_zero(model_sharded_state_dict, is_loading)
elif sharding_type == 'fully_sharded_model_space':
# In this approach the tensors could be directly related to model parameters
# by linking them with metadata from `model_sharded_state_dict`.
# This would allow changing TP and PP while using DistOpt (as with other optimizers).
# This implementation is more involved and left out for now.
raise NotImplementedError(
f'The fully sharded model space version for'
f' {self.__class__.__name__}.sharded_state_dict'
f' not implemented.'
)
else:
raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')
state_dict['param_state'] = param_state
state_dict['param_state_sharding_type'] = sharding_type
return state_dict
def sharded_param_state_dp_zero(
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
):
"""Naive implementation which reuses gather/scatter from the legacy ckpt format.
During saving, gathers the parameters state on DP rank 0 and saves a ShardedObject
with fixed TPxPP structure. During loading, loads the saved data on DP rank 0
(None on other ranks). Relies on the parameters scatter done in load_state_dict.
"""
if is_loading:
param_state_data = None
else:
# Gather on rank 0
param_state_data = self.get_parameter_state_dp_zero()
if torch.distributed.get_rank(self.data_parallel_group) == 0:
# Fixed TPxPP. Save on DP rank 0 only
param_state = ShardedObject(
f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.param_state',
param_state_data,
(1,),
(0,),
)
else:
# DP ranks > 0 don't save. During loading, the param_state needs to be None.
param_state = LocalNonpersitentObject(None)
return param_state
def sharded_param_state_fs_bucket_space(
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
):
"""Sharded state dict where each noncontiguous buffer is a separate ShardedTensor.
Results in fully parallel save and load without any inter-process
communication or intermediate buffers/copies.
"""
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group)
data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group)
state = self.get_parameter_state_fs_bucket_space()
# per_bucket_numel metadata is saved separately for each TPxPP domain.
for per_bucket_key in ('per_bucket_numel', 'per_bucket_numel_unpadded'):
state[per_bucket_key] = ShardedObject(
f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.{per_bucket_key}',
state[per_bucket_key],
(1,),
(0,),
replica_id=data_parallel_rank,
)
for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
for dtype, gbuf_range_map_for_all_buckets in state[gbuf_idx].items():
for bucket_idx, bucket_state in enumerate(gbuf_range_map_for_all_buckets):
# Compute local DP contiguous shard's size.
gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
assert gbuf_world_numel % data_parallel_world_size == 0
gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
sharded_bucket_key = f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.gbuf_idx_{gbuf_idx}.dtype_{dtype}.bucket_idx_{bucket_idx}'
# The global ckpt tensors must be fully covered.
# We add extra empty padding if necessary
assert bucket_state, 'empty bucket encountered'
if bucket_state[-1]['gbuf_local_end'] != gbuf_local_numel:
assert (
data_parallel_rank == data_parallel_world_size - 1
), 'encountered padding on non-last DP rank'
pad_tensors = {
k: torch.empty(
gbuf_local_numel - bucket_state[-1]['gbuf_local_end'],
dtype=v.dtype,
device=v.device,
)
for k, v in bucket_state[-1].items()
if isinstance(v, torch.Tensor)
}
bucket_state.append(
{
**pad_tensors,
'gbuf_local_start': bucket_state[-1]['gbuf_local_end'],
'gbuf_local_end': gbuf_local_numel,
}
)
# Each tensor is mapped to a slice (`flattened_range`)
# of a DP-local shard of size `gbuf_local_numel`.
for bucket_params_idx in range(len(bucket_state)):
tensors = bucket_state[bucket_params_idx]
gbuf_local_start = tensors.pop('gbuf_local_start')
gbuf_local_end = tensors.pop('gbuf_local_end')
for key in tensors:
assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), (
tensors[key].shape,
gbuf_local_start,
gbuf_local_end,
)
tensors[key] = ShardedTensor(
f'{sharded_bucket_key}.{key}',
tensors[key],
tensors[key].dtype,
(gbuf_local_numel,),
(data_parallel_world_size * gbuf_local_numel,),
(data_parallel_rank * gbuf_local_numel,),
axis_fragmentations=(data_parallel_world_size,),
flattened_range=slice(gbuf_local_start, gbuf_local_end),
allow_shape_mismatch=True,
)
return state
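# Illustrative shapes for one of the ShardedTensors registered above (assumed
# numbers): with data_parallel_world_size = 4, data_parallel_rank = 2 and
# gbuf_local_numel = 256, each per-param tensor of the bucket is described by
#   local shape     : (256,)
#   global shape    : (1024,)   # 4 * 256
#   global offset   : (512,)    # 2 * 256
#   flattened_range : slice(gbuf_local_start, gbuf_local_end)
# so every DP rank writes (and reads) exactly its own shard, fully in parallel.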
def load_parameter_state_from_fs_bucket_space(self, state_dict):
""" Loads the parameter state from an internal representation.
Inverse of the `get_parameter_state_internal_repr` method.
"""
if state_dict is not None and "per_bucket_numel_unpadded" in state_dict:
per_bucket_numel_unpadded_in_checkpoint = state_dict["per_bucket_numel_unpadded"]
assert self.per_bucket_numel_unpadded == per_bucket_numel_unpadded_in_checkpoint, (
f"Number of unpadded elements in each bucket need to be the same in current run "
f"({self.per_bucket_numel_unpadded}) and checkpoint "
f"({per_bucket_numel_unpadded_in_checkpoint})"
)
for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
bucket_state = state_dict[gbuf_idx][dtype][bucket_idx]
# State dict bucket state can be 1 entry longer in case of padding
assert len(bucket_state) in (
len(gbuf_range_map["param_map"]),
len(gbuf_range_map["param_map"]) + 1,
), (len(bucket_state), len(gbuf_range_map["param_map"]))
for src_tensors, (model_param, param_range_map) in zip(
bucket_state, gbuf_range_map["param_map"].items()
):
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
dst_tensors = {
"param": main_param,
**optim_state,
}
for key in dst_tensors:
dst_tensors[key].copy_(src_tensors[key])
def load_parameter_state_from_dp_zero(self, state_dict):
"""Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank.
This method performs the reverse of get_parameter_state_dp_zero():
- Scatter contiguous buffers from DP rank 0 to each DP rank (each DP
rank receives its relevant subset of the world buffers).
- For each DP rank, copy param & optimizer shards from contiguous CPU
buffers (e.g., one buffer each for main_param, exp_avg, and
exp_avg_sq).
"""
if state_dict is not None and "per_bucket_numel_unpadded" in state_dict:
per_bucket_numel_unpadded_in_checkpoint = state_dict["per_bucket_numel_unpadded"]
assert self.per_bucket_numel_unpadded == per_bucket_numel_unpadded_in_checkpoint, (
f"Number of unpadded elements in each bucket need to be the same in current run "
f"({self.per_bucket_numel_unpadded}) and checkpoint "
f"({per_bucket_numel_unpadded_in_checkpoint})"
)
# Data parallelism variables.
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
self.data_parallel_group_gloo
)
# Scatter tensors to all DP ranks.
for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
# Compute local DP contiguous shard's size.
gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
assert gbuf_world_numel == self.per_bucket_numel[gbuf_idx][dtype][bucket_idx]
assert gbuf_world_numel % data_parallel_world_size == 0
gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
# Contiguous local shards (received from DP rank 0).
local_shards = {
key: torch.empty((gbuf_local_numel,), dtype=torch.float32, device="cpu")
for key in ("param", "exp_avg", "exp_avg_sq")
}
# Scatter local shards from DP rank 0.
for key, recv_tensor in local_shards.items():
# Scatter tensor list.
if data_parallel_rank == 0:
world_tensor_for_all_buckets = state_dict[gbuf_idx][dtype][key]
if not isinstance(world_tensor_for_all_buckets, list):
world_tensor_for_all_buckets = [world_tensor_for_all_buckets]
assert bucket_idx < len(world_tensor_for_all_buckets), (
f"Trying to load state for bucket_id {bucket_idx} (out of "
f"{len(gbuf_range_map_for_all_buckets)} buckets) from checkpoint; "
f"checkpoint only has {len(world_tensor_for_all_buckets)} bucket(s)"
)
# This tensor might be bigger or smaller than expected (depending on
# relative sizes of per_bucket_numel_in_checkpoint and self.per_bucket_numel).
world_tensor = world_tensor_for_all_buckets[bucket_idx]
if "per_bucket_numel" in state_dict:
numel_in_checkpoint = state_dict["per_bucket_numel"][gbuf_idx][
dtype
][bucket_idx]
numel = self.per_bucket_numel[gbuf_idx][dtype][bucket_idx]
numel_unpadded = self.per_bucket_numel_unpadded[gbuf_idx][dtype][
bucket_idx
]
assert world_tensor.numel() == numel_in_checkpoint
assert numel_unpadded <= world_tensor.numel(), (
"True number of elements should be fewer than number of elements in "
"checkpoint tensor"
)
if world_tensor.numel() > numel:
# Truncate extra values, which are padding anyway.
logger.info(
f"Truncating extra values from checkpoint (numel_in_checkpoint={numel_in_checkpoint}, "
f"numel={numel}, numel_unpadded={numel_unpadded})"
)
world_tensor = world_tensor[:numel]
elif world_tensor.numel() < numel:
# In this case, numel > world_tensor.numel() (which is numel_in_checkpoint).
# Create new tensor with right number of values, then copy and use new tensor.
logger.info(
f"Expanding tensor from checkpoint (numel_in_checkpoint={numel_in_checkpoint}, "
f"numel={numel}, numel_unpadded={numel_unpadded})"
)
world_tensor_reshaped = torch.empty(
(numel,),
dtype=world_tensor.dtype,
device=world_tensor.device,
)
world_tensor_reshaped[:numel_in_checkpoint].copy_(world_tensor)
world_tensor = world_tensor_reshaped
else:
logger.info(
"***WARNING*** Using older checkpoint so skipping padding checks"
)
gbuf_start_idxs = list(range(0, gbuf_world_numel, gbuf_local_numel))
send_tensors = [
world_tensor[i : (i + gbuf_local_numel)] for i in gbuf_start_idxs
]
else:
send_tensors = None
# Scatter.
torch.distributed.scatter(
recv_tensor,
send_tensors,
data_parallel_global_ranks[0],
data_parallel_group_gloo,
)
# Copy local contiguous shards to param/optim shards.
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {
"param": main_param,
**optim_state,
}
# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
for key in local_shards:
tensors[key].data.copy_(
local_shards[key][gbuf_local_start:gbuf_local_end]
)
def load_parameter_state(self, filename: str):
"""Load the distributed parameter state from disk.
Args:
filename (str): path to load parameter state from.
"""
state_dict = None
if torch.distributed.get_rank(self.data_parallel_group) == 0:
state_dict = torch.load(filename)
self.load_parameter_state_from_dp_zero(state_dict)
def zero_grad(self, set_to_none: bool = True):
"""
Zeroes grads for the model-related parameters, i.e., model_float16_groups
and model_fp32_groups. We additionally zero the remaining groups as a
memory optimization to reduce fragmentation; when set_to_none == True,
the space used by these grads can be safely deallocated.
Args:
set_to_none (bool): if true, set grads to None.
"""
for groups in (
self.model_float16_groups,
self.model_fp32_groups,
self.shard_float16_groups, # grad empty/unused here?
self.shard_fp32_groups, # throws grad-access warning
self.shard_fp32_from_float16_groups,
):
for group in groups:
_zero_grad_group_helper(group, set_to_none)
# If overlapping param all-gather with forward compute, launch all-gather
# for first accessed bucket here before forward compute is initiated.
# The all-gather for the next bucket will be launched in the forward
# pre-hook when this all-gather finishes (to ensure that the communication
# kernels don't head-of-line block the compute kernels since we run with
# CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence parallelism).
if self.overlap_param_gather:
self._dispatch_gather_model_params(all_gather_handle_index=0)
def _get_model_param_buffer_dp_views(self):
"""
Get shard views of each of the param buffers.
In this nested list, the top level is grouped by the virtual model
index and the buffer's data type. The sub-level is a list of
shards of that buffer, where each shard in the list represents
a contiguous view of the buffer, that is owned by a data-parallel
rank. The shard boundary does not respect parameter boundaries, and
so the elements of some parameters are split across data parallel
ranks.
Additionally, return references to the entire buffers, for use
in _all_gather_base.
"""
# Buffer views.
# Add in reverse order in each model chunk since buckets start from the end of the model but we want
# all-gathers to run first for the start of the model (same order as forward pass).
# We keep the view_items in model chunk order since we still want to run all_gather and
# all_gather_handle.wait() first for the first model chunk.
# In all cases, we want all_gather and all_gather_handle.wait() to be called in the same order,
# and all_gather_handle.wait() needs to be called just before the corresponding forward pass.
view_items = []
for gbuf_index, buffer in enumerate(self.buffers):
view_items_per_model_chunk = []
dtype = self.buffers[gbuf_index].param_dtype
for bucket_index, bucket in enumerate(buffer.buckets):
data_parallel_world_size = torch.distributed.get_world_size(
self.data_parallel_group
)
buf_views = shard_buffer(bucket.param_data, data_parallel_world_size)
view_items_per_model_chunk.insert(
0, (gbuf_index, dtype, bucket_index, bucket.param_data, buf_views)
)
view_items.extend(view_items_per_model_chunk)
return view_items
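# Sketch of the resulting ordering (assumed: one buffer with two buckets): bucket 0
# is visited first and bucket 1 is inserted in front of it, so
#   view_items == [(0, dtype, 1, bucket1.param_data, views_of_bucket1),
#                  (0, dtype, 0, bucket0.param_data, views_of_bucket0)]
# i.e. the bucket covering the start of the model (the last-created bucket) is
# all-gathered first, matching the order of the forward pass.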
def _dispatch_gather_model_params(self, all_gather_handle_index: int, force_sync: bool = False):
"""
All-gather updated model params.
When using the distributed optimizer, the params are already laid out in a contiguous
buffer (see mcore/distributed/param_and_grad_buffer.py for details), and so the
all-gather will put the results in the right region of memory.
"""
async_op = self.overlap_param_gather and not force_sync
if self.update_successful:
data_parallel_group = self.data_parallel_group
data_parallel_rank = torch.distributed.get_rank(data_parallel_group)
# All-gather updated main params.
# All param_buf views are guaranteed to have the same number of elements
# across all data-parallel ranks, due to padding done in
# param_and_grad_buffer.py. Thus, all sub-views will have consistent
# start / end indexes across data-parallel ranks.
(gbuf_index, dtype, bucket_index, pbuf, pbuf_views) = self.pbuf_view_items[
all_gather_handle_index
]
assert all_gather_handle_index < len(self.all_gather_handles)
all_gather_handle = torch.distributed._all_gather_base(
pbuf, pbuf_views[data_parallel_rank], group=data_parallel_group, async_op=async_op,
)
self.all_gather_handles[all_gather_handle_index] = all_gather_handle
assert self.all_gather_handle_index_to_bucket_index_map[all_gather_handle_index] == (
gbuf_index,
dtype,
bucket_index,
)
def _make_forward_pre_hook(self):
"""
Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
when a module uses a parameter in a bucket with a still incomplete all-gather)
and then copy the results from the param_buffer into model_params.
"""
def hook(module, *unused):
assert (
self.overlap_param_gather
), "Should use pre-hook only when overlap_param_gather is True"
# Make sure all parameters in this module have been all-gathered as necessary.
for param in module.parameters(recurse=False):
# Skip parameters that don't require grad.
if not param.requires_grad:
continue
# Some params might be handled in another DistributedOptimizer instance; for
# example, we use separate DistributedOptimizer instances for expert and
# non-expert params.
if param in self.param_to_all_gather_handle_index_map:
all_gather_handle_index = self.param_to_all_gather_handle_index_map[param]
self._finish_param_sync_helper(all_gather_handle_index)
return hook
def finish_param_sync(self, model_index: int, *unused):
"""
Finishes all necessary param syncs for the model_index'th model chunk.
Args:
model_index (int): index of model chunk to synchronize params.
"""
if model_index not in self.model_index_to_all_gather_handle_index_map:
return
all_gather_handle_indices = self.model_index_to_all_gather_handle_index_map[model_index]
for all_gather_handle_index in all_gather_handle_indices:
self._finish_param_sync_helper(all_gather_handle_index)
def _finish_param_sync_helper(self, all_gather_handle_index: int):
"""
Waits on all_gather_handle if necessary, then dispatches the next all-gather
as necessary.
"""
# First check if there is an outstanding all-gather handle for this param.
# If so, wait on the handle to ensure the communication is finished.
assert all_gather_handle_index < len(self.all_gather_handles)
all_gather_handle = self.all_gather_handles[all_gather_handle_index]
if all_gather_handle is not None:
all_gather_handle.wait()
self.all_gather_handles[all_gather_handle_index] = None
# Launch the all-gather for the next bucket now.
# We can't pre-launch all-gathers for all buckets at once since we don't
# want to head-of-line block the compute kernels with communication kernels
# (since we run with CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence
# parallelism).
next_all_gather_handle_index = all_gather_handle_index + 1
if next_all_gather_handle_index < self.num_all_gather_handles:
self._dispatch_gather_model_params(next_all_gather_handle_index)
def _collect_main_grad_data_for_unscaling(self):
"""
Note: this should be equivalent to the float-16 optimizer's method,
but written differently, so the two should be combined.
"""
return [
param.grad.data for group in self.optimizer.param_groups for param in group["params"]
]
def _get_model_and_main_params_data_float16(self):
"""
Get aligned list of model and main params.
"""
model_data = []
main_data = []
for model_group, main_group in zip(
self.shard_float16_groups, self.shard_fp32_from_float16_groups
):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_model_grads_to_main_grads(self):
"""
Copy model grads to main grads.
Since this step follows a reduce-scatter through the DDP's grad
buffer, this method is responsible for copying the updated grads
from the grad buffer to the main shard's grad field.
"""
# Utility method for copying group grads.
def copy_group_grads(model_groups, shard_main_groups):
for model_group, shard_main_group in zip(model_groups, shard_main_groups):
for model_param, shard_main_param in zip(model_group, shard_main_group):
param_range_map = self._get_model_param_range_map(model_param)
param_range = param_range_map["param"]
assert param_range.size == shard_main_param.nelement()
model_grad = model_param.main_grad
shard_model_grad = model_grad.view(-1)[param_range.start : param_range.end]
shard_main_param.grad = shard_model_grad.float()
# Copy model groups to shard groups.
copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
def _copy_main_params_to_model_params(self):
"""
Copy main params to model params.
Since this step is followed by an all-gather through the DDP's param
buffer, this method is responsible for copying the updated params
from the main shards into the correct position in the param buffer.
"""
# Utility method for copying group params.
def copy_group_params(shard_main_groups, model_groups):
for shard_main_group, model_group in zip(shard_main_groups, model_groups):
for shard_main_param, model_param in zip(shard_main_group, model_group):
param_range_map = self._get_model_param_range_map(model_param)
world_range = param_range_map["gbuf_world_in_bucket"]
assert world_range.size == shard_main_param.nelement()
gbuf_index, _, bucket_id = self.model_param_gbuf_map[model_param]
model_param_buffer = self.buffers[gbuf_index].buckets[bucket_id].param_data
shard_model_param = model_param_buffer.view(-1)[
world_range.start : world_range.end
]
shard_model_param.data.copy_(shard_main_param)
# Copy shard groups to model groups.
copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups)
copy_group_params(self.shard_fp32_groups, self.model_fp32_groups)
def _copy_model_params_to_main_params(self):
"""
Copy model params to main params.
During finetuning, this method is used to reload the main params from
the model params. This copy does not make use of the grad buffer as
an intermediary.
"""
# Utility method for copying group params.
def copy_group_params(model_groups, shard_main_groups):
for model_group, shard_main_group in zip(model_groups, shard_main_groups):
for model_param, shard_main_param in zip(model_group, shard_main_group):
param_range_map = self._get_model_param_range_map(model_param)
param_range = param_range_map["param"]
assert param_range.size == shard_main_param.nelement()
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_main_param.data.copy_(shard_model_param)
# Copy model groups to shard groups.
copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_params(self.model_fp32_groups, self.shard_fp32_groups)
def _reset_metadata_and_sync_gather_all_model_params(self, force_sync: bool):
"""
Reset metadata needed to track results of all-gathers.
"""
self.all_gather_handles = [None for _ in range(len(self.all_gather_handles))]
# Launch synchronous all-gather if --overlap-param-gather is turned on or if force_sync
# is explicitly set to True (e.g., if we are going to turn off all-gather overlapping for
# validation / test iterations).
if not self.overlap_param_gather or force_sync:
for all_gather_handle_index in range(self.num_all_gather_handles):
self._dispatch_gather_model_params(all_gather_handle_index, force_sync=force_sync)
@torch.no_grad()
def step(self):
"""
Step optimizer.
Under the hood, either launch synchronous param all-gathers or get ready to launch
asynchronous all-gathers that get overlapped with the next forward pass.
"""
self.update_successful, grad_norm, num_zeros_in_grad = super().step()
timers = self.config.timers
if timers is not None:
timers('params-all-gather', log_level=1).start(barrier=self.config.barrier_with_L1_time)
# If not overlapping all-gather for parameters, launch synchronous all-gather
# communication calls here. If overlapping all-gather for parameters, the following
# call to _reset_metadata_and_sync_gather_all_model_params is a no-op: the first all-gather is launched
# asynchronously in the next optimizer.zero_grad() call and subsequent all-gathers
# are launched in the forward pre-hook.
self._reset_metadata_and_sync_gather_all_model_params(force_sync=False)
if timers is not None:
timers('params-all-gather').stop()
return self.update_successful, grad_norm, num_zeros_in_grad
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron grad scaler."""
from abc import ABC, abstractmethod
from typing import Dict
import torch
class MegatronGradScaler(ABC):
def __init__(self, initial_scale: float):
"""Initialize scale value with the input initial scale."""
assert initial_scale > 0.0
self._scale = torch.tensor([initial_scale], dtype=torch.float, device='cuda')
@property
def scale(self):
return self._scale
@property
def inv_scale(self):
return self._scale.double().reciprocal().float()
@abstractmethod
def update(self, found_inf: bool):
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict: Dict):
pass
class ConstantGradScaler(MegatronGradScaler):
"""
Constant grad scaler (loss scale is never adjusted regardless of NaNs seen in gradients).
"""
def update(self, found_inf: bool):
pass
def state_dict(self):
return dict()
def load_state_dict(self, state_dict):
pass
class DynamicGradScaler(MegatronGradScaler):
"""
Grad scaler with dynamic scale that gets adjusted during training.
Reduces the loss scale by `backoff_factor` if NaNs are seen in `hysteresis` consecutive iterations.
Increases the loss scale by `growth_factor` if no NaNs are seen for `growth_interval` iterations.
"""
def __init__(
self,
initial_scale: float,
min_scale: float,
growth_factor: float,
backoff_factor: float,
growth_interval: int,
hysteresis: int,
):
"""
Grad scaler with dynamic scale that gets adjusted during training.
Args:
initial_scale (float): Initial loss scale value.
min_scale (float): Minimum loss scale value.
growth_factor (float): Factor to grow loss scale by if NaNs are not seen in `growth_interval`
training iterations. Must be greater than 1.
backoff_factor (float): Factor to decrease loss scale by if NaNs are seen in `hysteresis`
consecutive training iterations. Must be between 0 and 1.
growth_interval (int): Number of training iterations of no NaNs before loss scale is increased.
hysteresis (int): Number of training iterations of consecutive NaNs before loss scale is decreased.
"""
super(DynamicGradScaler, self).__init__(initial_scale)
# Lower bound on the scale.
assert min_scale > 0.0
assert min_scale <= initial_scale
self.min_scale = torch.tensor([min_scale], dtype=torch.float, device='cuda')
# Growth and backoff factors for the scale.
assert growth_factor > 1.0
self.growth_factor = torch.tensor([growth_factor], dtype=torch.float, device='cuda')
assert backoff_factor < 1.0
assert backoff_factor > 0.0
self.backoff_factor = torch.tensor([backoff_factor], dtype=torch.float, device='cuda')
# Interval over which if we don't see any inf/nan,
# we will scale the grad scale by the growth factor.
assert growth_interval > 0
self.growth_interval = growth_interval
# Number of inf/nans we should see before scaling down
# the grad scale by the backoff factor.
assert hysteresis > 0
self.hysteresis = hysteresis
# Trackers.
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
def update(self, found_inf: bool):
"""
Updates internal state in grad scaler based on whether NaNs are seen in grads or not.
"""
# If we have an inf/nan, growth tracker is set to 0
# and the hysteresis tracker is reduced by 1.
if found_inf:
self._growth_tracker = 0
self._hysteresis_tracker -= 1
# Now if we are out of hysteresis count, scale down the loss.
if self._hysteresis_tracker <= 0:
self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
else:
# If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1
# If we have had enough consecutive intervals with no nan/inf:
if self._growth_tracker == self.growth_interval:
# Reset the tracker and hysteresis trackers,
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
# and scale up the loss scale.
self._scale = self._scale * self.growth_factor
def state_dict(self):
state_dict = {}
state_dict['scale'] = self._scale
state_dict['growth_tracker'] = self._growth_tracker
state_dict['hysteresis_tracker'] = self._hysteresis_tracker
return state_dict
def load_state_dict(self, state_dict: Dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
self._growth_tracker = state_dict['growth_tracker']
self._hysteresis_tracker = state_dict['hysteresis_tracker']
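# Illustrative walk-through of the dynamic scaling policy above (assumed settings):
#   scaler = DynamicGradScaler(initial_scale=2**16, min_scale=1.0,
#                              growth_factor=2.0, backoff_factor=0.5,
#                              growth_interval=1000, hysteresis=2)
#   scaler.update(found_inf=True)   # growth tracker reset, hysteresis 2 -> 1, scale unchanged
#   scaler.update(found_inf=True)   # hysteresis reaches 0, scale backs off: 2**16 -> 2**15
#   # ... after 1000 consecutive updates with found_inf=False, the trackers reset
#   # and the scale grows once by growth_factor: 2**15 -> 2**16.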
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron optimizer."""
import math
from abc import ABC, abstractmethod
from itertools import chain
from logging import getLogger
from typing import Callable, List, Optional
import amp_C
import torch
from apex.multi_tensor_apply import multi_tensor_applier
from .. import parallel_state, tensor_parallel
from ..dist_checkpointing.mapping import ShardedStateDict
from ..dist_checkpointing.optimizer import (
get_param_id_to_sharded_param_map,
make_sharded_optimizer_tensor,
optim_state_to_sharding_state,
)
from ..dist_checkpointing.utils import add_prefix_for_sharding
from ..transformer.module import param_is_not_shared
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
from .grad_scaler import MegatronGradScaler
from .optimizer_config import OptimizerConfig
logger = getLogger(__name__)
def _zero_grad_group_helper(group: List[torch.nn.Parameter], set_to_none: bool):
"""
Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer.
"""
for param in group:
if param.grad is not None:
if set_to_none:
param.grad = None
else:
if param.grad.grad_fn is not None:
param.grad.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
def _multi_tensor_copy_this_to_that(
this: List[torch.Tensor], that: List[torch.Tensor], overflow_buf: Optional[torch.Tensor] = None
):
"""
Use multi-tensor-applier to copy values from one list to another.
We don't have a bfloat16 implementation, so for now, if overflow_buf
is not provided, we fall back to a simple loop copy to remain
compatible with bfloat16.
"""
if overflow_buf:
overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
else:
for this_, that_ in zip(this, that):
that_.copy_(this_)
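# Minimal usage sketch (assumed tensors): copying a list of fp32 source tensors into
# pre-allocated destinations. Note that `if overflow_buf:` evaluates the value of the
# single-element buffer, so both a missing buffer and a zero-valued one take the
# simple loop copy; the fused amp_C kernel runs only for a truthy buffer.
#   src = [torch.randn(1024, device='cuda') for _ in range(3)]
#   dst = [torch.empty_like(t) for t in src]
#   _multi_tensor_copy_this_to_that(src, dst)                     # plain loop copy
#   buf = torch.ones(1, dtype=torch.int, device='cuda')
#   _multi_tensor_copy_this_to_that(src, dst, overflow_buf=buf)   # fused multi_tensor_scale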
class MegatronOptimizer(ABC):
"""
Base class for all Megatron optimizers.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
config: OptimizerConfig,
init_state_fn: Callable = lambda x: None,
):
"""Input optimizer is the base optimizer (e.g., Adam)."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
self.config = config
self.init_state_fn = init_state_fn
def get_parameters(self) -> List[torch.nn.Parameter]:
"""
Get list of parameters wrapped in optimizer.
"""
params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
return params
def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
"""
Get main_grads that should be taken into account to compute the grad norm.
Filter parameters based on:
- grad should not be None.
- parameter should not be shared (i.e., grads shouldn't be double counted while
computing norms).
- should not be a replica due to tensor model parallelism.
"""
params = self.get_parameters()
grads_for_norm = []
for param in params:
grad = param.grad
grad_not_none = grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
return grads_for_norm
def get_model_parallel_group(self) -> torch.distributed.ProcessGroup:
"""Default returned here, but the distributed optimizer overrides this."""
return parallel_state.get_model_parallel_group()
def clip_grad_norm(self, clip_grad: float) -> float:
"""Compute grad norm."""
params = self.get_parameters()
grads_for_norm = self.get_main_grads_for_grad_norm()
return clip_grad_norm_fp32(
params, grads_for_norm, clip_grad, model_parallel_group=self.get_model_parallel_group(),
)
def count_zeros(self) -> float:
"""Count number of zeros in model's gradients."""
params = self.get_parameters()
return count_zeros_fp32(params, model_parallel_group=self.get_model_parallel_group())
@abstractmethod
def zero_grad(self, set_to_none: bool = True):
pass
@abstractmethod
def get_loss_scale(self) -> torch.Tensor:
"""
Get current loss scale factor.
NOTE: The output should be a CUDA tensor of size 1.
"""
pass
def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
"""Simple scaling."""
return self.get_loss_scale() * loss
def finish_param_sync(self, model_index: int):
"""
Finish parameter synchronization for all optimizers.
This is a no-op for all non-distributed optimizers.
"""
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
Call whenever the parameters are changed outside of the optimizer.
For example, when we load a model from a checkpoint without loading
the optimizer, the model parameters are updated but for fp16 optimizer
with main parameters, the main parameters need to also be updated."""
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
# Promote state so it can be retrieved or set via
# "optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
@abstractmethod
def step(self):
"""Step the optimizer."""
pass
@abstractmethod
def sharded_state_dict(
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
) -> ShardedStateDict:
""" Builds sharded state dict for the optimizer, based on model's sharded state dict.
Args:
model_sharded_state_dict (ShardedStateDict): sharded state dict of the model
is_loading (bool, optional): flag indicating whether the state dict will be used to save or load the optimizer state.
Defaults to False.
Returns: optimizer sharded state dict
"""
class MixedPrecisionOptimizer(MegatronOptimizer):
"""Base class for both the float-16 and the distributed optimizer.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
this can be None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
config: OptimizerConfig,
grad_scaler: Optional[MegatronGradScaler],
init_state_fn: Callable,
):
super().__init__(
optimizer, config, init_state_fn,
)
self.grad_scaler = grad_scaler
# None grad scaler is only supported for bf16.
if self.grad_scaler is None:
assert not self.config.fp16, 'fp16 expects a grad scaler.'
        # Tensor used to determine if an inf/nan has happened.
        # Any non-zero value indicates inf/nan.
        # Note that this buffer is only needed when a grad scaler is present;
        # bfloat16 with a (constant) grad scaler still records inf/nan here.
if self.grad_scaler:
self.found_inf = torch.tensor([0.0], dtype=torch.float, device='cuda')
# Dummy tensor needed for apex multi-apply tensor.
# For bfloat, we don't have multi-tensor apply and for now
# we set it to none so the multi-tensor apply gets ignored.
if self.config.bf16:
self._dummy_overflow_buf = None
else:
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
# In case grad scaler is not passed, define the unity scale.
if self.grad_scaler is None:
self._scale_one = torch.tensor([1.0], dtype=torch.float, device='cuda')
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
return self.grad_scaler.scale
def reload_model_params(self):
self._copy_model_params_to_main_params()
def _unscale_main_grads_and_check_for_nan(self):
# Collect main grads.
main_grads = self._collect_main_grad_data_for_unscaling()
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale
)
# Update across all model parallel instances.
torch.distributed.all_reduce(
self.found_inf, op=torch.distributed.ReduceOp.MAX, group=self.get_model_parallel_group()
)
# Check for nan.
found_inf_flag = self.found_inf.item() > 0
return found_inf_flag
@torch.no_grad()
def step(self):
timers = self.config.timers
# Copy gradients from model params to main params.
if timers is not None:
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self._copy_model_grads_to_main_grads()
if timers is not None:
timers('optimizer-copy-to-main-grad').stop()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
# Unscale and check for inf/nan.
if timers is not None:
timers('optimizer-unscale-and-check-inf', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
if timers is not None:
timers('optimizer-unscale-and-check-inf').stop()
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None, None
# Clip the main gradients.
if timers is not None:
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
grad_norm = None
if self.config.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.config.clip_grad)
if timers is not None:
timers('optimizer-clip-main-grad').stop()
# Count the zeros in the grads.
if timers is not None:
timers('optimizer-count-zeros', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
if timers is not None:
timers('optimizer-count-zeros').stop()
# Step the optimizer.
if timers is not None:
timers('optimizer-inner-step', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self.optimizer.step()
if timers is not None:
timers('optimizer-inner-step').stop()
# Update params from main params.
if timers is not None:
timers('optimizer-copy-main-to-model-params', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self._copy_main_params_to_model_params()
if timers is not None:
timers('optimizer-copy-main-to-model-params').stop()
# Successful update.
return True, grad_norm, num_zeros_in_grad
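# A minimal sketch of how a training loop consumes the step() contract above,
# assuming `optimizer` is an already-constructed MixedPrecisionOptimizer and
# `loss` is a scalar loss tensor with a grad graph; names are illustrative only.
def _example_mixed_precision_step(optimizer: MixedPrecisionOptimizer, loss: torch.Tensor):
    # Scale the loss before backward so fp16 gradients stay in range.
    optimizer.scale_loss(loss).backward()
    # step() returns (update_successful, grad_norm, num_zeros_in_grad); the update
    # is skipped (update_successful=False) when an inf/nan is found in the grads.
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    if not update_successful:
        pass  # e.g., skip the LR-scheduler step for this iteration
    return update_successful, grad_norm, num_zeros_in_grad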
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
this can be None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
config: OptimizerConfig,
grad_scaler: MegatronGradScaler,
init_state_fn: Callable,
):
super().__init__(
optimizer, config, grad_scaler, init_state_fn,
)
# Handle main parameters.
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.float16_groups = []
self.fp32_from_float16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
float16_params_this_group = []
fp32_params_this_group = []
fp32_from_float16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
# float16 params:
if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
float16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] = self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type())
)
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
float16_groups & fp32_from_fp32_groups. We additionally zero
fp32_from_float16_groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point."""
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_fp32_groups:
_zero_grad_group_helper(group, set_to_none)
def _collect_main_grad_data_for_unscaling(self):
main_grads = []
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
return main_grads
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_model_grads_to_main_grads(self):
# This only needs to be done for the float16 group.
for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
if hasattr(model_param, 'main_grad'):
main_param.grad = model_param.main_grad.float()
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
# Safe to deallocate model's grad/main_grad after copying.
# (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.)
model_param.grad = None
# For fp32 grads, we need to reset the grads to main grad.
for model_group in self.fp32_from_fp32_groups:
for model_param in model_group:
model_param.grad = model_param.main_grad
def _copy_main_params_to_model_params(self):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(
this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf
)
def _copy_model_params_to_main_params(self):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(
this=model_data, that=main_data, overflow_buf=self._dummy_overflow_buf
)
def state_dict(self):
state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict()
if self.grad_scaler:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
return state_dict
def sharded_state_dict(
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
):
if is_loading:
self.init_state_fn(self.optimizer)
state_dict = self.state_dict()
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
model_sharded_state_dict, chain.from_iterable(g for g in self.float16_groups)
)
# Convert fp32_from_fp16_params
assert len(state_dict['fp32_from_fp16_params']) == len(
state_dict['optimizer']['param_groups']
)
state_dict['fp32_from_fp16_params'] = [
[
make_sharded_optimizer_tensor(
id_to_sharded_param_map[param_id],
fp32_param,
                    prefix='optimizer.state.fp32_param',
)
for param_id, fp32_param in zip(state_group['params'], fp32_group)
]
for fp32_group, state_group in zip(
state_dict['fp32_from_fp16_params'], state_dict['optimizer']['param_groups']
)
]
# Convert regular optimizer state
optim_state_to_sharding_state(state_dict['optimizer'], id_to_sharded_param_map)
return state_dict
def load_state_dict(self, state_dict):
# Optimizer.
optimizer_key = 'optimizer'
if optimizer_key not in state_dict:
optimizer_key = 'optimizer_state_dict'
            logger.info('***WARNING*** loading optimizer from an old checkpoint ...')
self.optimizer.load_state_dict(state_dict[optimizer_key])
# Grad scaler.
if 'grad_scaler' not in state_dict:
if self.config.fp16:
logger.info(
                    '***WARNING*** found an old checkpoint, will not load grad scaler ...'
)
else:
if self.grad_scaler:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
else:
logger.info(
                    '***WARNING*** found the grad scaler in the '
'checkpoint but it is None in the class. '
'Skipping loading grad scaler ...'
)
# Copy data for the main params.
fp32_from_float16_params_key = 'fp32_from_fp16_params'
if fp32_from_float16_params_key not in state_dict:
fp32_from_float16_params_key = 'fp32_from_fp16'
for current_group, saved_group in zip(
self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]
):
for current_param, saved_param in zip(current_group, saved_group):
current_param.data.copy_(saved_param.data)
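# A minimal sketch (hypothetical helper, not part of the upstream API) of the
# dtype-based grouping that Float16OptimizerWithFloat16Params.__init__ performs:
# float16/bf16 params get an fp32 "main" copy, while fp32 params are used directly.
def _example_classify_params(params: List[torch.nn.Parameter]):
    float16_params, fp32_params = [], []
    for param in params:
        if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
            float16_params.append(param)  # shadowed by an fp32 main copy in the optimizer
        elif param.type() == 'torch.cuda.FloatTensor':
            fp32_params.append(param)  # stepped directly in fp32
    return float16_params, fp32_params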
class FP32Optimizer(MegatronOptimizer):
"""Float32 optimizer.
Args:
optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
config (OptimizerConfig): configuration object for optimizer.
init_state_fn (Callable, optional): function to initialize state in the optimizer.
"""
def __init__(
self, optimizer: torch.optim.Optimizer, config: OptimizerConfig, init_state_fn: Callable,
):
super(FP32Optimizer, self).__init__(
optimizer, config, init_state_fn,
)
self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')
def zero_grad(self, set_to_none=True):
"""Copied from torch.optim.optimizer"""
for group in self.optimizer.param_groups:
_zero_grad_group_helper(group['params'], set_to_none)
def get_loss_scale(self):
"""FP32 optimizer does not do any scaling."""
return self._scale
@torch.no_grad()
def step(self):
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
timers = self.config.timers
# Copy main_grads to grads.
if timers is not None:
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
param.grad = param.main_grad
if timers is not None:
timers('optimizer-copy-to-main-grad').stop()
# Clip gradients.
if timers is not None:
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
grad_norm = None
if self.config.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.config.clip_grad)
if timers is not None:
timers('optimizer-clip-main-grad').stop()
# Count the zeros in the grads.
if timers is not None:
timers('optimizer-count-zeros', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
if timers is not None:
timers('optimizer-count-zeros').stop()
# Update parameters.
if timers is not None:
timers('optimizer-inner-step', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self.optimizer.step()
if timers is not None:
timers('optimizer-inner-step').stop()
# No overflow for FP32 optimizer.
return True, grad_norm, num_zeros_in_grad
def reload_model_params(self):
pass
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict)
class ChainedOptimizer(MegatronOptimizer):
"""ChainedOptimizer is designed for a collection of optimizers.
These optimizers are responsible for different parts of multiple models for
a training task and will be executed one-by-one when the model is updated.
Args:
chained_optimizers: a list of optimizers.
"""
    # Remove these attributes, which are inherited from MegatronOptimizer.
state = None
param_groups = None
def __init__(self, chained_optimizers: List[MegatronOptimizer]):
self.chained_optimizers = chained_optimizers
self.param_groups = []
for optimizer in self.chained_optimizers:
self.param_groups += optimizer.param_groups
def zero_grad(self, set_to_none=True):
for optimizer in self.chained_optimizers:
optimizer.zero_grad(set_to_none)
def get_loss_scale(self):
return self.chained_optimizers[0].get_loss_scale()
def reload_model_params(self):
for optimizer in self.chained_optimizers:
optimizer.reload_model_params()
def state_dict(self):
return [optimizer.state_dict() for optimizer in self.chained_optimizers]
def sharded_state_dict(
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False, **kwargs
):
sharded_state_dict = {}
for optimizer_idx, optimizer in enumerate(self.chained_optimizers):
optim_state_dict = optimizer.sharded_state_dict(
model_sharded_state_dict, is_loading, **kwargs
)
add_prefix_for_sharding(optim_state_dict, f'chained_{optimizer_idx}.')
sharded_state_dict[optimizer_idx] = optim_state_dict
return sharded_state_dict
def load_state_dict(self, state_dict):
if len(self.chained_optimizers) != len(state_dict):
raise RuntimeError(
f'Expected {len(self.chained_optimizers)} entries'
f' in state dict, but got {len(state_dict)}.'
)
if isinstance(state_dict, dict):
state_dict = (v for k, v in sorted(state_dict.items()))
for optimizer, state in zip(self.chained_optimizers, state_dict):
optimizer.load_state_dict(state)
        # Reset param_groups, since load_state_dict resets the chained optimizers' attribute.
self.param_groups = []
for optimizer in self.chained_optimizers:
self.param_groups += optimizer.param_groups
    def disable_pre_hook(self):
        for optimizer in self.chained_optimizers:
            if (
                not optimizer.config.use_distributed_optimizer
                or not optimizer.config.overlap_param_gather
            ):
                raise ValueError(
                    "disable_pre_hook should only be called when 'use_distributed_optimizer' "
                    "and 'overlap_param_gather' are both enabled."
                )
            optimizer.disable_pre_hook()
    def enable_pre_hook(self):
        for optimizer in self.chained_optimizers:
            if (
                not optimizer.config.use_distributed_optimizer
                or not optimizer.config.overlap_param_gather
            ):
                raise ValueError(
                    "enable_pre_hook should only be called when 'use_distributed_optimizer' "
                    "and 'overlap_param_gather' are both enabled."
                )
            optimizer.enable_pre_hook()
def step(self):
"""ChainedOptimizer will step all optimizers one by one.
"""
update_successful, grad_norm, num_zeros_in_grad = True, 0, 0
grad_norms = []
for optimizer in self.chained_optimizers:
_update_successful, _grad_norm, _num_zeros_in_grad = optimizer.step()
update_successful &= _update_successful
grad_norms += [_grad_norm if _grad_norm else 0.0]
num_zeros_in_grad += _num_zeros_in_grad if _num_zeros_in_grad else 0
grad_norm = math.sqrt(sum([x ** 2 for x in grad_norms]))
return update_successful, grad_norm, num_zeros_in_grad
def save_parameter_state(self, filename: str):
"""Save the distributed parameter states of all optimizers to a file.
Args:
filename (str): path to save parameter state to.
"""
save_states = False
states = []
for optimizer in self.chained_optimizers:
if hasattr(optimizer, 'get_parameter_state_dp_zero'):
state_dict = optimizer.get_parameter_state_dp_zero()
                # Save the checkpoint economically: the state dict only needs
                # to be saved on DP rank 0.
if torch.distributed.get_rank(optimizer.data_parallel_group) == 0:
states.append(state_dict)
save_states = True
else:
states.append(None)
else:
states.append(None)
if save_states:
torch.save(states, filename)
def load_parameter_state(self, filename: str):
"""Load the distributed parameter states of all optimizers from a file.
Args:
filename (str): path to load parameter state from.
"""
states = None
for idx, optimizer in enumerate(self.chained_optimizers):
if not hasattr(optimizer, 'load_parameter_state_from_dp_zero'):
continue
            # Lazily load the checkpoint: the state dict is needed only on DP rank 0.
if torch.distributed.get_rank(optimizer.data_parallel_group) == 0 and states is None:
states = torch.load(filename)
state_dict = states[idx] if states else None
optimizer.load_parameter_state_from_dp_zero(state_dict)
def finish_param_sync(self, model_index: int):
"""Finish parameter synchronization for all optimizers.
"""
for optimizer in self.chained_optimizers:
optimizer.finish_param_sync(model_index)
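# A small worked example (illustrative only) of how ChainedOptimizer.step()
# combines per-optimizer gradient norms: the global L2 norm is the square root
# of the sum of the squared partial norms.
def _example_combined_grad_norm():
    partial_norms = [3.0, 4.0]  # e.g., norms returned by two chained optimizers
    combined = math.sqrt(sum(x ** 2 for x in partial_norms))
    assert combined == 5.0
    return combined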
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable, Optional
import torch
@dataclass
class OptimizerConfig:
"""Configuration for optimizer."""
##############
# General
##############
optimizer: str = 'adam'
"""Optimizer to use (one of Adam or SGD)."""
lr: Optional[float] = None
"""Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
iteration would be different.
"""
min_lr: Optional[float] = None
"""Minumum value for learning rate. The scheduler clip values below this threshold."""
decoupled_lr: Optional[float] = None
"""Separate learning rate for the input and output layer."""
decoupled_min_lr: Optional[float] = None
"""Minimum value for learning rate for the input and output layer. The scheduler clip values
below this threshold.
"""
weight_decay: float = 0.01
"""Weight decay coefficient for L2 regularization."""
##############
# Precision
##############
fp16: bool = False
"""If true, train with fp16 mixed precision training. Defaults to False."""
bf16: bool = False
"""If true, train with bf16 mixed precision training. Defaults to False."""
params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights. Defaults to torch.float32."""
###############
# Loss scaling
###############
loss_scale: Optional[float] = None
"""Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
dynamic loss scaling is used.
"""
initial_loss_scale: float = 2 ** 32
"""Initial loss-scale for dynamic loss scaling."""
min_loss_scale: float = 1.0
"""Minimum loss scale for dynamic loss scaling."""
loss_scale_window: float = 1000
"""Window over which to raise/lower dynamic scale."""
hysteresis: int = 2
"""Hysteresis for dynamic loss scaling."""
##############
# Optimizer
##############
# Adam
adam_beta1: float = 0.9
"""First coefficient for computing running averages of gradient and its square in Adam
optimizer.
"""
adam_beta2: float = 0.999
"""Second coefficient for computing running averages of gradient and its square in Adam
optimizer.
"""
adam_eps: float = 1e-08
"""Term added to the denominator to improve numerical stability in Adam optimizer."""
# SGD.
sgd_momentum: float = 0.9
"""Momentum factor for SGD optimizer."""
#######################
# Distributed optimizer
#######################
use_distributed_optimizer: bool = False
"""Distribute optimizer state over data-parallel replicas."""
overlap_grad_reduce: bool = False
"""If true, overlap grad reduce-scatter with backward compute in distributed optimizer."""
overlap_param_gather: bool = False
"""If true, overlap param all-gather with forward compute in distributed optimizer."""
################
# Miscellaneous
################
clip_grad: float = 1.0
"""Gradient clipping based on global L2 norm."""
log_num_zeros_in_grad: bool = False
"""If true, calculate and log the number of zeros in gradient."""
barrier_with_L1_time: bool = False
"""If true, use barrier with level 1 time measurements."""
    timers: Optional[Callable] = None
"""Function to get timers."""
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
MAJOR = 0
MINOR = 7
PATCH = 0
PRE_RELEASE = 'b0'
# Use the following formatting: (major, minor, patch, pre-release)
VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
__shortversion__ = '.'.join(map(str, VERSION[:3]))
__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:])
__package_name__ = 'megatron_core'
__contact_names__ = 'NVIDIA'
__contact_emails__ = 'nemo-toolkit@nvidia.com' # use NeMo Email
__homepage__ = (
'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' # use NeMo homepage
)
__repository_url__ = 'https://github.com/NVIDIA/Megatron-LM/megatron/core'
__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
__description__ = (
'Megatron Core - a library for efficient and scalable training of transformer based models'
)
__license__ = 'BSD-3'
__keywords__ = (
'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch'
)
from dataclasses import dataclass
from torch import Tensor
@dataclass
class PackedSeqParams:
    # Parameters for TEDotProductAttention and fused RoPE kernels for the `thd` (packed) sequence format.
qkv_format: str = None
cu_seqlens_q: Tensor = None
cu_seqlens_kv: Tensor = None
max_seqlen_q: Tensor = None
max_seqlen_kv: Tensor = None
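# A minimal sketch of packing two sequences of lengths 3 and 5 into the `thd`
# format: cumulative sequence lengths delimit each sequence in the packed batch.
# The values below are illustrative only.
def _example_packed_seq_params() -> PackedSeqParams:
    import torch  # local import; this module only imports `Tensor` above

    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # boundaries of the two sequences
    return PackedSeqParams(
        qkv_format='thd',
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=torch.tensor(5),
        max_seqlen_kv=torch.tensor(5),
    )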
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import os
import warnings
from datetime import timedelta
from typing import List, Optional
import torch
from .utils import GlobalMemoryBuffer
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
# tensor model parallel group and data parallel group combined
# used for fp8 and moe training
_TENSOR_AND_DATA_PARALLEL_GROUP = None
# Expert parallel group that the current rank belongs to.
_EXPERT_MODEL_PARALLEL_GROUP = None
_TENSOR_AND_EXPERT_PARALLEL_GROUP = None
_DATA_MODULO_EXPERT_PARALLEL_GROUP = None
_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
_MPU_EXPERT_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None
# A list of global ranks for each tensor model parallel group to ease calculation of
# the first local rank in the tensor model parallel group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None
# Context parallel group that the current rank belongs to
_CONTEXT_PARALLEL_GROUP = None
# A list of global ranks for each context parallel group to ease calculation of the
# destination rank when exchanging KV/dKV between context parallel ranks
_CONTEXT_PARALLEL_GLOBAL_RANKS = None
# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
# combined parallel group of TP, DP, and CP used for fp8
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
# MOE logging
_MOE_AUX_LOSSES_LOGGING_TRACKER = {}
def get_nccl_options(pg_name, nccl_comm_cfgs):
"""Set the NCCL process group options.
Args:
pg_name (str): process group name
nccl_comm_cfgs (dict): nccl communicator configurations
When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting.
"""
if pg_name in nccl_comm_cfgs:
nccl_options = torch.distributed.ProcessGroupNCCL.Options()
nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
return nccl_options
else:
return None
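# A minimal sketch of the per-group NCCL config dict this function expects (the
# same shape as the YAML file passed to initialize_model_parallel); the group
# names and values below are illustrative only.
def _example_nccl_comm_cfgs():
    nccl_comm_cfgs = {
        'dp': {'cga_cluster_size': 4, 'max_ctas': 32, 'min_ctas': 1},
        'tp': {'max_ctas': 16},  # unlisted options fall back to the .get defaults above
    }
    # Returns ProcessGroupNCCL.Options for 'dp' and None for the unlisted 'pp' group.
    return get_nccl_options('dp', nccl_comm_cfgs), get_nccl_options('pp', nccl_comm_cfgs)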
def generate_masked_orthogonal_rank_groups(
world_size: int, parallel_size: List[int], mask: List[bool],
) -> List[List[int]]:
"""Generate orthogonal parallel groups based on the parallel size and mask.
Arguments:
world_size (int): world size
parallel_size (List[int]):
The parallel size of each orthogonal parallel type. For example, if
            tensor_model_parallel_size = 2, pipeline_model_parallel_size = 3, data_parallel_size = 4,
and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].
mask (List[bool]):
The mask controls which parallel methods the generated groups represent. If mask[i] is
True, it means the generated group contains the i-th parallelism method. For example,
            if parallel_size = [tp_size, pp_size, dp_size] and mask = [True, False, True], then
            the generated group is the `tp-dp` group; if the mask = [False, True, False], then the
            generated group is the `pp` group.
Algorithm:
For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and
local_rank satisfy the following equation:
global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1)
tp_rank \in [0, tp_size)
dp_rank \in [0, dp_size)
pp_rank \in [0, pp_size)
        Suppose we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each;
        for example, if the gpu count is 8, the order is 'tp-pp-dp', and the sizes are '2-2-2',
        then the dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]]).
        The tp_rank and pp_rank are combined to form the `dp_group_index`:
            dp_group_index = tp_rank + pp_rank * tp_size (2)
        So, given that tp_rank and pp_rank satisfy equation (2), and dp_rank is in
        range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy
        equation (1).
        This function solves this math problem.
For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],
and the mask = [False, True, False]. Then,
dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
...
dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2
dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
...
dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
"""
def prefix_product(a: List[int], init=1) -> List[int]:
r = [init]
for v in a:
init = init * v
r.append(init)
return r
def inner_product(a: List[int], b: List[int]) -> int:
return sum([x * y for x, y in zip(a, b)])
def decompose(index, shape, stride=None):
'''
        This function solves the math problem below:
            There is an equation:
                index = sum(idx[i] * stride[i])
            Given the values of index and stride, return idx.
        This function is used to recover the tp/dp/pp rank
        from group_index and rank_in_group.
'''
if stride is None:
stride = prefix_product(shape)
idx = [(index // d) % s for s, d in zip(shape, stride)]
# stride is a prefix_product result. And the value of stride[-1]
# is not used.
assert (
sum([x * y for x, y in zip(idx, stride[:-1])]) == index
), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
return idx
masked_shape = [s for s, m in zip(parallel_size, mask) if m]
unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]
global_stride = prefix_product(parallel_size)
masked_stride = [d for d, m in zip(global_stride, mask) if m]
unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]
group_size = prefix_product(masked_shape)[-1]
num_of_group = world_size // group_size
ranks = []
for group_index in range(num_of_group):
        # Get indices from the unmasked dims for group_index.
decomposed_group_idx = decompose(group_index, unmasked_shape)
rank = []
for rank_in_group in range(group_size):
            # Get indices from the masked dims for rank_in_group.
decomposed_rank_idx = decompose(rank_in_group, masked_shape)
rank.append(
inner_product(decomposed_rank_idx, masked_stride)
+ inner_product(decomposed_group_idx, unmasked_stride)
)
ranks.append(rank)
return ranks
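# A runnable check of the docstring example above: parallel_size = [2, 3, 4]
# (tp, dp, pp) with mask = [False, True, False] yields the 8 dp groups, the
# first being [0, 2, 4] and the last [19, 21, 23]. Illustrative only.
def _example_masked_rank_groups():
    groups = generate_masked_orthogonal_rank_groups(
        world_size=24, parallel_size=[2, 3, 4], mask=[False, True, False]
    )
    assert groups[0] == [0, 2, 4]
    assert groups[-1] == [19, 21, 23]
    return groups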
class RankGenerator(object):
def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str) -> None:
self.tp = tp
self.ep = ep
self.dp = dp
self.pp = pp
self.cp = cp
self.world_size = tp * dp * pp * cp
self.name_to_size = {
"tp": self.tp,
"pp": self.pp,
"dp": self.dp,
"ep": self.ep,
"cp": self.cp,
}
self.order = order
order = order.lower()
if 'ep' in order:
if 'ep-dp' not in order and 'dp-ep' not in order:
raise RuntimeError(f"The ep and dp must be adjacent in order ({self.order}).")
for name in self.name_to_size.keys():
if name not in order and self.name_to_size[name] != 1:
raise RuntimeError(
f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})."
)
elif name not in order:
order = order + '-' + name
self.order_w_ep = order
self.order_wo_ep = '-'.join([token for token in order.split('-') if token != 'ep'])
self.ordered_size_wo_ep = []
self.ordered_size_w_ep = []
for token in order.split('-'):
if token == 'dp':
self.ordered_size_w_ep.append(self.dp // self.ep)
self.ordered_size_wo_ep.append(self.dp)
elif token == 'ep':
self.ordered_size_w_ep.append(self.ep)
else:
self.ordered_size_w_ep.append(self.name_to_size[token])
self.ordered_size_wo_ep.append(self.name_to_size[token])
def get_mask(self, order: str, token: str):
ordered_token = order.split('-')
token = token.split('-')
mask = [False] * len(ordered_token)
for t in token:
mask[ordered_token.index(t)] = True
return mask
def get_ranks(self, token, independent_ep=False):
'''Get rank group by input token.
Arguments:
token (str):
                Specify the rank types to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.
            independent_ep (bool, default = False):
                This flag controls whether we treat EP and DP independently.
                EP shares ranks with DP; if we want to get ranks related to
                EP, we should set this flag. For example, get_ranks('dp', True)
                returns the DP-modulo-EP group, and get_ranks('dp', False)
                returns the full DP group.
'''
if independent_ep:
parallel_size = self.ordered_size_w_ep
order = self.order_w_ep
else:
parallel_size = self.ordered_size_wo_ep
order = self.order_wo_ep
mask = self.get_mask(order, token)
ranks = generate_masked_orthogonal_rank_groups(self.world_size, parallel_size, mask)
return ranks
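# A minimal sketch of RankGenerator usage for 16 GPUs with tp=2, pp=4 (so dp=2),
# reproducing the group layout described in the initialize_model_parallel
# docstring below; purely illustrative.
def _example_rank_generator():
    gen = RankGenerator(tp=2, ep=1, dp=2, pp=4, cp=1, order='tp-cp-ep-dp-pp')
    tp_groups = gen.get_ranks('tp')  # [[0, 1], [2, 3], ..., [14, 15]]
    dp_groups = gen.get_ranks('dp')  # [[0, 2], [1, 3], [4, 6], ..., [13, 15]]
    pp_groups = gen.get_ranks('pp')  # [[0, 4, 8, 12], ..., [3, 7, 11, 15]]
    assert dp_groups[0] == [0, 2] and pp_groups[0] == [0, 4, 8, 12]
    return tp_groups, dp_groups, pp_groups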
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
use_sharp: bool = False,
context_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
nccl_communicator_config_path: Optional[str] = None,
distributed_timeout_minutes: int = 30,
order: str = "tp-cp-ep-dp-pp",
) -> None:
"""Initialize model data parallel groups.
Args:
tensor_model_parallel_size (int, default = 1):
The number of GPUs to split individual tensors across.
pipeline_model_parallel_size (int, default = 1):
The number of tensor parallel GPU groups to split the
Transformer layers across. For example, if
tensor_model_parallel_size is 4 and
pipeline_model_parallel_size is 2, the model will be split
into 2 groups of 4 GPUs.
virtual_pipeline_model_parallel_size (int, optional):
The number of stages that each pipeline group will have,
interleaving as necessary. If None, no interleaving is
performed. For example, if tensor_model_parallel_size is 1,
pipeline_model_parallel_size is 4,
virtual_pipeline_model_parallel_size is 2, and there are
16 transformer layers in the model, the model will be
split into 8 stages with two layers each and each GPU
would get 2 stages as such (layer number starting with 1):
GPU 0: [1, 2] [9, 10]
GPU 1: [3, 4] [11, 12]
GPU 2: [5, 6] [13, 14]
GPU 3: [7, 8] [15, 16]
pipeline_model_parallel_split_rank (int, optional):
For models with both an encoder and decoder, the rank in
pipeline to switch between encoder and decoder (i.e. the
first rank of the decoder). This allows the user to set
the pipeline parallel size of the encoder and decoder
independently. For example, if
pipeline_model_parallel_size is 8 and
pipeline_model_parallel_split_rank is 3, then ranks 0-2
will be the encoder and ranks 3-7 will be the decoder.
use_sharp (bool, default = False):
Set the use of SHARP for the collective communications of
data-parallel process groups. When `True`, run barrier
within each data-parallel process group, which specifies
the SHARP application target groups.
context_parallel_size (int, default = 1):
The number of tensor parallel GPU groups to split the
network input sequence length across. Compute of attention
module requires tokens of full sequence length, so GPUs
in a context parallel group need to communicate with each
other to exchange information of other sequence chunks.
Each GPU and its counterparts in other tensor parallel
groups compose a context parallel group.
For example, assume we have 8 GPUs, if tensor model parallel
size is 4 and context parallel size is 2, the network input
will be split into two sequence chunks, which are processed
by 2 different groups of 4 GPUs. One chunk is processed by
GPU0-3, the other chunk is processed by GPU4-7. Four groups
            are built to do context parallel communications: [GPU0, GPU4],
[GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].
Context parallelism partitions sequence length, so it has no
impact on weights, which means weights are duplicated among
GPUs in a context parallel group. Hence, weight gradients
all-reduce is required in backward. For simplicity, we piggyback
GPUs of context parallelism on data parallel group for
weight gradient all-reduce.
expert_model_parallel_size (int, default = 1):
The number of Mixture of Experts parallel GPUs in each expert
parallel group.
nccl_communicator_config_path (str, default = None):
Path to the yaml file of NCCL communicator configurations.
`min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
for each communicator.
distributed_timeout_minutes (int, default = 30): Timeout, in
            minutes, for operations executed against distributed
process groups. See PyTorch documentation at
https://pytorch.org/docs/stable/distributed.html for
caveats.
        order (str, default = tp-cp-ep-dp-pp):
The rank initialization order of parallelism. Now we support
tp-dp-pp and tp-pp-dp orders.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
if (
world_size
% (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size)
!= 0
):
raise RuntimeError(
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) "
f"x context_parallel_size ({context_parallel_size})"
)
data_parallel_size: int = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
)
if data_parallel_size % expert_model_parallel_size != 0:
raise RuntimeError(
f"data_parallel_size ({data_parallel_size}) is not divisible by expert_model_parallel_size "
)
if expert_model_parallel_size > 1 and context_parallel_size > 1:
raise RuntimeError(
f"combination of expert model prallellism and context parallelism is not supported"
)
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
if virtual_pipeline_model_parallel_size is not None:
if not pipeline_model_parallel_size > 2:
raise RuntimeError(
"pipeline-model-parallel size should be greater than 2 with interleaved schedule"
)
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
if pipeline_model_parallel_split_rank is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
rank = torch.distributed.get_rank()
nccl_comm_cfgs = {}
if nccl_communicator_config_path is not None:
try:
import yaml
except ImportError:
raise RuntimeError(
"Cannot import `yaml`. Setting custom nccl communicator configs "
"requires the yaml package."
)
with open(nccl_communicator_config_path, "r") as stream:
nccl_comm_cfgs = yaml.safe_load(stream)
rank_generator = RankGenerator(
tp=tensor_model_parallel_size,
ep=expert_model_parallel_size,
dp=data_parallel_size,
pp=pipeline_model_parallel_size,
cp=context_parallel_size,
order=order,
)
timeout = timedelta(minutes=distributed_timeout_minutes)
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS
global _DATA_PARALLEL_GROUP_WITH_CP
global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
for ranks in rank_generator.get_ranks('dp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
)
group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
_DATA_PARALLEL_GLOBAL_RANKS = ranks
for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
group_with_cp = torch.distributed.new_group(
ranks_with_cp, timeout=timeout, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
)
group_with_cp_gloo = torch.distributed.new_group(
ranks_with_cp, timeout=timeout, backend="gloo"
)
if rank in ranks_with_cp:
_DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp
# Apply SHARP to DP process groups
if use_sharp:
if rank == 0:
print(
"The number of process groups to use SHARP with depends on the type "
"of the network switch. Nvidia QM1 switch supports SAHRP up to 8 "
"process groups and QM2 supports up to 256 process groups. We apply "
"SHARP to the communications of the data-parallel domain. If the "
"number of data-parallel process groups is larger than the max "
"process groups that the network switch supports, the communication "
"will fall back to non-SHARP operators. To enable SHARP, "
"`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
)
torch.distributed.barrier(
group=get_data_parallel_group(with_context_parallel=True),
device_ids=[torch.cuda.current_device()],
)
# Set `NCCL_COLLNET_ENABLE=0` to restrict SHARP application to DP process groups
os.environ["NCCL_COLLNET_ENABLE"] = "0"
# Build the context-parallel groups.
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_GLOBAL_RANKS
assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
for ranks in rank_generator.get_ranks('cp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
)
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
_CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for ranks in rank_generator.get_ranks('tp-pp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs)
)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
assert (
_TENSOR_MODEL_PARALLEL_GROUP is None
), 'tensor model parallel group is already initialized'
for ranks in rank_generator.get_ranks('tp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks
# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert (
_PIPELINE_MODEL_PARALLEL_GROUP is None
), 'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
for ranks in rank_generator.get_ranks('pp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('pp', nccl_comm_cfgs)
)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
position_embedding_ranks = [ranks[0]]
if pipeline_model_parallel_split_rank is not None:
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
embedding_ranks = [
ranks[0],
ranks[pipeline_model_parallel_split_rank],
ranks[-1],
]
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
else:
embedding_ranks = ranks
position_embedding_ranks = ranks
group = torch.distributed.new_group(
embedding_ranks, timeout=timeout, pg_options=get_nccl_options('embd', nccl_comm_cfgs)
)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
group = torch.distributed.new_group(
position_embedding_ranks,
timeout=timeout,
pg_options=get_nccl_options('embd', nccl_comm_cfgs),
)
if rank in position_embedding_ranks:
_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
# Build the tensor + data parallel groups.
global _TENSOR_AND_DATA_PARALLEL_GROUP
global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is None
), 'Tensor + data parallel group is already initialized'
for ranks in rank_generator.get_ranks('tp-dp-cp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs)
)
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
for ranks in rank_generator.get_ranks('tp-dp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs)
)
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP = group
# Build the tensor + expert parallel groups
global _EXPERT_MODEL_PARALLEL_GROUP
assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized'
global _TENSOR_AND_EXPERT_PARALLEL_GROUP
assert (
_TENSOR_AND_EXPERT_PARALLEL_GROUP is None
), 'Tensor + expert parallel group is already initialized'
global _DATA_MODULO_EXPERT_PARALLEL_GROUP
assert (
_DATA_MODULO_EXPERT_PARALLEL_GROUP is None
), 'Data modulo expert group is already initialized'
global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
for ranks in rank_generator.get_ranks('tp-ep', independent_ep=True):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs)
)
if rank in ranks:
_TENSOR_AND_EXPERT_PARALLEL_GROUP = group
for ranks in rank_generator.get_ranks('ep', independent_ep=True):
group = torch.distributed.new_group(
ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs)
)
if rank in ranks:
_EXPERT_MODEL_PARALLEL_GROUP = group
for ranks in rank_generator.get_ranks('dp', independent_ep=True):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('dp_modulo_exp', nccl_comm_cfgs)
)
group_gloo = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks:
_DATA_MODULO_EXPERT_PARALLEL_GROUP = group
_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = group_gloo
# Initialize global memory buffer
# This isn't really "parallel state" but there isn't another good place to
# put this. If we end up with a more generic initialization of megatron-core
# we could stick it there
_set_global_memory_buffer()
def is_initialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is not None
def is_unitialized() -> bool:
"""Check if parallel state has been initialized
Deprecated. Use is_initialized instead.
"""
warnings.warn(
"is_unitialized is deprecated, use is_initialized instead", DeprecationWarning,
)
return not is_initialized()
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if (
_TENSOR_MODEL_PARALLEL_GROUP is None
or _PIPELINE_MODEL_PARALLEL_GROUP is None
or _DATA_PARALLEL_GROUP is None
):
return False
return True
def get_model_parallel_group():
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized'
return _MODEL_PARALLEL_GROUP
def get_tensor_model_parallel_group(check_initialized=True):
"""Get the tensor model parallel group the caller rank belongs to."""
if check_initialized:
assert (
_TENSOR_MODEL_PARALLEL_GROUP is not None
), 'tensor model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert (
_PIPELINE_MODEL_PARALLEL_GROUP is not None
), 'pipeline_model parallel group is not initialized'
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_data_parallel_group(with_context_parallel=False):
"""Get the data parallel group the caller rank belongs to."""
if with_context_parallel:
assert (
_DATA_PARALLEL_GROUP_WITH_CP is not None
), 'data parallel group with context parallel combined is not initialized'
return _DATA_PARALLEL_GROUP_WITH_CP
else:
assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_data_parallel_group_gloo(with_context_parallel=False):
"""Get the data parallel group-gloo the caller rank belongs to."""
if with_context_parallel:
assert (
_DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
), 'data parallel group-gloo with context parallel combined is not initialized'
return _DATA_PARALLEL_GROUP_WITH_CP_GLOO
else:
assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
return _DATA_PARALLEL_GROUP_GLOO
def get_context_parallel_group(check_initialized=True):
"""Get the context parallel group the caller rank belongs to."""
if check_initialized:
assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
return _CONTEXT_PARALLEL_GROUP
def get_context_parallel_global_ranks(check_initialized=True):
"""Get all global ranks of the context parallel group that the caller rank belongs to."""
if check_initialized:
assert (
_CONTEXT_PARALLEL_GLOBAL_RANKS is not None
), 'context parallel group is not initialized'
return _CONTEXT_PARALLEL_GLOBAL_RANKS
def get_embedding_group():
"""Get the embedding group the caller rank belongs to."""
assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
return _EMBEDDING_GROUP
def get_position_embedding_group():
"""Get the position embedding group the caller rank belongs to."""
assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized'
return _POSITION_EMBEDDING_GROUP
def get_amax_reduction_group(with_context_parallel=False):
"""Get the FP8 amax reduction group the caller rank belongs to."""
if with_context_parallel:
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
), 'FP8 amax reduction group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
else:
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is not None
), 'FP8 amax reduction group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP
def get_tensor_and_data_parallel_group(with_context_parallel=False):
"""Get the tensor and data parallel group the caller rank belongs to."""
if with_context_parallel:
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
), 'tensor and data parallel group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
else:
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is not None
), 'tensor and data parallel group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP
def get_expert_model_parallel_group():
assert (
_EXPERT_MODEL_PARALLEL_GROUP is not None
), 'expert model parallel group is not initialized'
return _EXPERT_MODEL_PARALLEL_GROUP
def get_tensor_and_expert_parallel_group():
assert (
_TENSOR_AND_EXPERT_PARALLEL_GROUP is not None
), 'tensor and expert parallel group is not initialized'
return _TENSOR_AND_EXPERT_PARALLEL_GROUP
def get_data_modulo_expert_parallel_group():
assert (
_DATA_MODULO_EXPERT_PARALLEL_GROUP is not None
), 'data modulo expert parallel group is not initialized'
return _DATA_MODULO_EXPERT_PARALLEL_GROUP
def get_data_modulo_expert_parallel_group_gloo():
assert (
_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is not None
), 'data modulo expert parallel group-gloo is not initialized'
return _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
def set_expert_model_parallel_world_size(world_size):
global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
"""Set the pipeline model parallel size"""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_virtual_pipeline_model_parallel_world_size(world_size):
"""Set the pipeline model parallel size"""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
def set_expert_model_parallel_rank(rank):
"""Set expert model parallel rank."""
global _MPU_EXPERT_MODEL_PARALLEL_RANK
_MPU_EXPERT_MODEL_PARALLEL_RANK = rank
def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_rank(rank):
"""Set pipeline model parallel rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_split_rank(rank):
"""Set pipeline model parallel split rank."""
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
return _MPU_TENSOR_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_pipeline_model_parallel_split_rank():
"""Return pipeline model parallel split rank."""
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
if (
get_virtual_pipeline_model_parallel_world_size() is not None
and get_virtual_pipeline_model_parallel_rank() != 0
):
return False
return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
virtual_pipeline_model_parallel_world_size = (
get_virtual_pipeline_model_parallel_world_size()
)
if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
virtual_pipeline_model_parallel_world_size - 1
):
return False
return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
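# Illustrative sketch (not part of the original module): pipeline schedules typically
# gate embedding and loss computation on the stage checks above, e.g.
#
#     if is_pipeline_first_stage():
#         hidden = embed(tokens)      # `embed` and `tokens` are hypothetical
#     if is_pipeline_last_stage():
#         loss = loss_func(output)    # `loss_func` and `output` are hypothetical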
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _EMBEDDING_GLOBAL_RANKS
if ignore_virtual:
return rank in _EMBEDDING_GLOBAL_RANKS
if rank in _EMBEDDING_GLOBAL_RANKS:
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
return is_pipeline_first_stage(ignore_virtual=False)
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
return is_pipeline_last_stage(ignore_virtual=False)
else:
return True
return False
def is_rank_in_position_embedding_group():
"""Return true if current rank is in position embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _POSITION_EMBEDDING_GLOBAL_RANKS
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
def set_virtual_pipeline_model_parallel_rank(rank):
"""Set the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_virtual_pipeline_model_parallel_world_size():
"""Return the virtual pipeline-parallel world size."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
assert (
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None
), "Tensor model parallel group is not initialized"
return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0]
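# Illustrative sketch (not part of the original module): the source rank returned
# above is commonly used together with the tensor model parallel group to broadcast
# replicated data, e.g.
#
#     torch.distributed.broadcast(
#         tensor,
#         src=get_tensor_model_parallel_src_rank(),
#         group=get_tensor_model_parallel_group(),
#     )
#
# where `tensor` is a hypothetical buffer that must be identical across the TP group.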
def get_data_parallel_src_rank(with_context_parallel=False):
"""Calculate the global rank corresponding to the first local rank
in the data parallel group."""
if with_context_parallel:
assert (
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
), "Data parallel group with context parallel combined is not initialized"
return _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
else:
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
return _DATA_PARALLEL_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that preceeds the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
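# Illustrative sketch (not part of the original module): the next/prev rank helpers
# above are the natural peers for point-to-point pipeline communication, e.g.
#
#     torch.distributed.send(output_tensor, dst=get_pipeline_model_parallel_next_rank())
#     torch.distributed.recv(input_tensor, src=get_pipeline_model_parallel_prev_rank())
#
# where `output_tensor` and `input_tensor` are hypothetical activation buffers.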
def get_data_parallel_world_size(with_context_parallel=False):
"""Return world size for the data parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_world_size(
group=get_data_parallel_group(with_context_parallel=with_context_parallel)
)
else:
return 0
def get_data_parallel_rank(with_context_parallel=False):
"""Return my rank for the data parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank(
group=get_data_parallel_group(with_context_parallel=with_context_parallel)
)
else:
return 0
def get_context_parallel_world_size():
"""Return world size for the context parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_world_size(group=get_context_parallel_group())
else:
return 0
def get_context_parallel_rank():
"""Return my rank for the context parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank(group=get_context_parallel_group())
else:
return 0
def get_expert_model_parallel_world_size():
"""Return world size for the expert model parallel group"""
if _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE:
return _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
if torch.distributed.is_available() and torch.distributed.is_initialized():
tensor_and_expert_parallel_world_size = torch.distributed.get_world_size(
group=get_tensor_and_expert_parallel_group()
)
return tensor_and_expert_parallel_world_size // get_tensor_model_parallel_world_size()
else:
return 0
def get_tensor_and_expert_parallel_world_size():
"""Return world size for the expert model parallel group times model parallel group.
Currently, each expert will also be distributed across TP group by default.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
tensor_and_expert_parallel_world_size = torch.distributed.get_world_size(
group=get_tensor_and_expert_parallel_group()
)
return tensor_and_expert_parallel_world_size
else:
return 0
def get_expert_model_parallel_rank():
"""Return my rank for the expert parallel group"""
if _MPU_EXPERT_MODEL_PARALLEL_RANK:
return _MPU_EXPERT_MODEL_PARALLEL_RANK
if torch.distributed.is_available() and torch.distributed.is_initialized():
tensor_and_expert_parallel_rank = torch.distributed.get_rank(
group=get_tensor_and_expert_parallel_group()
)
return tensor_and_expert_parallel_rank // get_tensor_model_parallel_world_size()
else:
return 0
def get_data_modulo_expert_parallel_rank():
"""Return my rank for the context parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank(group=get_data_modulo_expert_parallel_group())
else:
return 0
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer():
"""Return the global GlobalMemoryBuffer object"""
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
return _GLOBAL_MEMORY_BUFFER
def destroy_global_memory_buffer():
"""Sets the global memory buffer to None"""
global _GLOBAL_MEMORY_BUFFER
_GLOBAL_MEMORY_BUFFER = None
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP_WITH_CP
_DATA_PARALLEL_GROUP_WITH_CP = None
global _CONTEXT_PARALLEL_GROUP
_CONTEXT_PARALLEL_GROUP = None
global _CONTEXT_PARALLEL_GLOBAL_RANKS
_CONTEXT_PARALLEL_GLOBAL_RANKS = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None
global _TENSOR_AND_DATA_PARALLEL_GROUP
_TENSOR_AND_DATA_PARALLEL_GROUP = None
global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
global _EXPERT_MODEL_PARALLEL_GROUP
_EXPERT_MODEL_PARALLEL_GROUP = None
global _TENSOR_AND_EXPERT_PARALLEL_GROUP
_TENSOR_AND_EXPERT_PARALLEL_GROUP = None
global _DATA_MODULO_EXPERT_PARALLEL_GROUP
_DATA_MODULO_EXPERT_PARALLEL_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
global _GLOBAL_MEMORY_BUFFER
_GLOBAL_MEMORY_BUFFER = None
global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_EXPERT_MODEL_PARALLEL_RANK
_MPU_EXPERT_MODEL_PARALLEL_RANK = None