Commit 688448db authored by silencealiang

Update code

parent a02a5490
Pipeline #2503 passed
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from .module import HuggingFaceModule, build_hf_model
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from transformers import AutoModel
from megatron.core.models.huggingface import HuggingFaceModule
class ClipHuggingFaceModel(HuggingFaceModule):
"""
Wrapper for CLIP HuggingFace models
"""
def __init__(self, config):
super().__init__(config)
self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path)
def forward(self, *args, **kwargs):
"""Forward function"""
x = self.model(*args, **kwargs)
x = x['last_hidden_state']
return x
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from transformers import AutoConfig, AutoModel
from megatron.core.transformer.module import MegatronModule
class HuggingFaceModule(MegatronModule):
"""
Base module for HuggingFace model wrappers
"""
def __init__(self, config):
super().__init__(config=config)
def set_input_tensor(self, input_tensor):
"""Dummy function for set_input_tensor"""
self.input_tensor = input_tensor
class AutoHuggingFaceModel(HuggingFaceModule):
"""
Wrapper for HuggingFace AutoModel
"""
def __init__(self, config):
super().__init__(config)
self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path)
def forward(self, *args, **kwargs):
"""Forward function"""
return self.model(*args, **kwargs)
def build_hf_model(config):
"""Builds huggingface wrapper model given config"""
hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path)
if "qwen" in hf_config.model_type:
from megatron.core.models.huggingface.qwen_model import QwenHuggingFaceModel
model = QwenHuggingFaceModel(config)
elif "vit" in hf_config.model_type:
from megatron.core.models.huggingface.clip_model import ClipHuggingFaceModel
model = ClipHuggingFaceModel(config)
else:
raise NotImplementedError(f"Huggingface model type {hf_config.model_type} is not supported")
return model
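# Illustrative usage sketch: build_hf_model only needs a config object that exposes
# `huggingface_model_name_or_path`; the wrapper class is selected from the HuggingFace
# model_type. The SimpleNamespace config and checkpoint name below are hypothetical
# stand-ins for the TransformerConfig and model path normally supplied by Megatron.
#
#     from types import SimpleNamespace
#     from megatron.core.models.huggingface import build_hf_model
#
#     cfg = SimpleNamespace(huggingface_model_name_or_path="Qwen/Qwen2-0.5B")
#     lm = build_hf_model(cfg)  # "qwen" in model_type -> QwenHuggingFaceModel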
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from transformers.models.qwen2 import Qwen2ForCausalLM
from megatron.core.models.huggingface import HuggingFaceModule
class QwenHuggingFaceModel(HuggingFaceModule):
"""
Wrapper for Qwen LM HuggingFace models
"""
def __init__(self, config):
super().__init__(config)
self.model = Qwen2ForCausalLM.from_pretrained(config.huggingface_model_name_or_path)
def forward(self, *args, **kwargs):
"""Forward function"""
combined_embeddings = kwargs['decoder_input'].permute(1, 0, 2)
x = self.model(
position_ids=None, # TODO: I guess we're just assuming no custom pos ids
attention_mask=kwargs['attention_mask'],
inputs_embeds=combined_embeddings,
labels=kwargs['labels'],
)
if kwargs['labels'] is not None:
x = x["loss"]
else:
x = x["logits"]
return x
def embedding(self, input_ids, position_ids=None):
"""Function to run process tokens with input embeddings"""
return self.model.get_input_embeddings()(input_ids).transpose(1, 0).contiguous()
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Multimodal Sequence Parallel (SP) and Context Parallel (CP) functionality."""
import torch
from megatron.core.packed_seq_params import PackedSeqParams
def get_padding(
seq_len, cp_size, tp_size, has_sp, decoder_tp_comm_overlap=False, decoder_seq_len=None
):
"""Calculate padding needed for SP and/or CP.
Args:
seq_len (int): Model sequence length.
cp_size (int): Context parallel size.
tp_size (int): Tensor parallel size.
has_sp (bool): Model uses sequence parallelism.
decoder_tp_comm_overlap (bool): Decoder (LLM) uses tensor parallel communication overlap.
decoder_seq_len (int): Decoder (LLM) maximum sequence length.
Returns:
padding (int): Padding needed given model configuration.
"""
padding = 0
# TP Comm overlap is performed with combined text+image embeddings.
if has_sp and decoder_tp_comm_overlap:
# If TP Comm Overlap is enabled for combined text+image embedding in LM backbone,
# user needs to provide decoder_seq_len with any potential padding needed for SP+CP
assert (
decoder_seq_len is not None
), "Please provide decoder seq length when using TP comm overlap for LM backbone"
padding = decoder_seq_len - seq_len
elif has_sp or cp_size > 1:
padding_factor = 1
if has_sp and cp_size > 1:
# Padding to multiple of tp_size * cp_size * 2 when using CP + SP.
padding_factor = tp_size * cp_size * 2
elif cp_size > 1:
padding_factor = cp_size * 2
elif has_sp:
padding_factor = tp_size
padding = int((seq_len + padding_factor - 1) // padding_factor * padding_factor) - seq_len
return padding
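# Worked example (illustrative values): with SP enabled, cp_size=2 and tp_size=4, the
# padding factor is tp_size * cp_size * 2 = 16, so a 1000-token sequence is padded up
# to 1008 tokens:
#
#     get_padding(seq_len=1000, cp_size=2, tp_size=4, has_sp=True)  # -> 8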
def get_packed_seq_params(tokens, img_seq_len, padding_needed, cp_size, use_packed_sequence=False):
"""Get PackedSeqParams for CP.
Args:
tokens (torch.Tensor): [batch, seq_len] input tokens.
img_seq_len (int): Image sequence length.
padding_needed (int): Padding to add.
cp_size (int): Context parallel size.
use_packed_sequence (bool): Uses sequence packing.
Returns:
packed_seq_params (PackedSeqParams): Parameters to be sent to Transformer Engine.
"""
batch_size = tokens.shape[0]
# Calculate the valid token seq len that LM backbone should compute on
combined_valid_seqlen = tokens.shape[1] + img_seq_len - padding_needed
cu_seqlens = torch.arange(
0,
(batch_size + 1) * (combined_valid_seqlen),
step=(combined_valid_seqlen),
dtype=torch.int32,
device=tokens.device,
)
# Calculate the total padded token seq len
combined_padded_seqlen = tokens.shape[1] + img_seq_len
cu_seqlens_padded = None
qkv_format = 'sbhd'
if cp_size > 1 and (padding_needed > 0 or use_packed_sequence):
# Provide cu_seqlens_<q/kv>_padded for CP support
cu_seqlens_padded = torch.arange(
0,
(batch_size + 1) * (combined_padded_seqlen),
step=(combined_padded_seqlen),
dtype=torch.int32,
device=tokens.device,
)
# CP with padding mask type requires THD format
qkv_format = 'thd'
packed_seq_params = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
cu_seqlens_q_padded=cu_seqlens_padded,
cu_seqlens_kv_padded=cu_seqlens_padded,
max_seqlen_q=combined_padded_seqlen,
max_seqlen_kv=combined_padded_seqlen,
qkv_format=qkv_format,
)
return packed_seq_params
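# Worked example (illustrative values): tokens of shape [2, 32] with img_seq_len=16,
# padding_needed=8 and cp_size=2 give a valid combined length of 32 + 16 - 8 = 40 and
# a padded combined length of 48, so the returned PackedSeqParams contain
#
#     cu_seqlens        = [0, 40, 80]
#     cu_seqlens_padded = [0, 48, 96]
#     qkv_format        = 'thd'   # CP with padding switches to THD layout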
......@@ -11,8 +11,9 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to
from megatron.core.models.gpt import GPTModel
from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_num_image_embeddings
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.models.vision.radio import RADIOViTModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import get_context_parallel_group, get_context_parallel_world_size
from megatron.core.parallel_state import get_context_parallel_rank, get_context_parallel_world_size
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
......@@ -20,12 +21,17 @@ from megatron.core.utils import log_single_rank
try:
import transformer_engine # pylint: disable=unused-import
from transformer_engine.pytorch.distributed import gather_along_first_dim
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.utils import is_te_min_version
HAVE_TE = True
try:
import transformer_engine_torch as tex
HAVE_TEX = True
except ImportError:
HAVE_TEX = False
except ImportError:
HAVE_TE = False
if get_context_parallel_world_size() > 1:
......@@ -39,6 +45,52 @@ IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
class _get_data_on_this_cp_rank(torch.autograd.Function):
"""Performs sharding for Context Parallelism in THD format
In the forward pass, indices are selected for each CP rank and remaining tokens are dropped.
In the backward pass, this class takes care of managing gradients for dropped tokens on each
CP rank.
"""
@staticmethod
def forward(ctx, batch, packed_seq_params):
"""Context Parallelism forward support for THD format"""
cp_size = get_context_parallel_world_size()
cp_rank = get_context_parallel_rank()
for key, data in batch.items():
index = tex.thd_get_partitioned_indices(
packed_seq_params.cu_seqlens_q_padded, data.size(1), cp_size, cp_rank
)
if key == "combined_embeddings":
ctx.decoder_emb_index = index
ctx.decoder_emb_seqlen = data.size(1)
batch[key] = data.index_select(1, index)
batch[key].requires_grad = data.requires_grad
return batch
@staticmethod
def backward(ctx, grad_out, grad_label, grad_loss):
"""Context Parallelism backward support for THD format"""
seqlen = ctx.decoder_emb_seqlen
index = ctx.decoder_emb_index
assert grad_out.size(1) == index.size(
0
), f"Shape mismatch in incoming gradient {grad_out.shape} and \
index from THD CP sharding {index.shape}"
grad_in = torch.zeros(
grad_out.size(0),
seqlen,
*grad_out.size()[2:],
dtype=grad_out.dtype,
device=grad_out.device,
)
grad_in[:, ctx.decoder_emb_index, :] = grad_out
return (grad_in, None, None, None)
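# Minimal sketch of the sharding idea above, assuming cp_size=2 and a plain even split
# in place of tex.thd_get_partitioned_indices (the real call returns load-balanced
# indices): every [B, S, ...] tensor in the batch keeps only the indices owned by this
# CP rank along the sequence dimension (dim=1).
#
#     emb = torch.randn(2, 8, 16)               # [B, S, H]
#     index_rank0 = torch.arange(0, 4)          # indices kept by CP rank 0
#     shard = emb.index_select(1, index_rank0)  # [B, S/CP, H] == [2, 4, 16]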
# Note: This is under development and may be missing features.
class LLaVAModel(MegatronModule):
"""LLaVA multi-modal model.
......@@ -58,6 +110,7 @@ class LLaVAModel(MegatronModule):
missing when loading a checkpoint. Default False.
parallel_output (bool): Keep outputs split across tensor parallel ranks.
This is typically True for training and False for inference.
share_embeddings_and_output_weights (bool): Input embedding and output layer share weights.
language_position_embedding_type (str): Language model position embedding type.
language_rotary_percent (float): RoPE percent. Defaults to 1.0.
pre_process (bool): Include embedding layer in the decoder (used with pipeline parallel).
......@@ -71,6 +124,7 @@ class LLaVAModel(MegatronModule):
patch_dim (int): The size of each image patch side.
language_rotary_base (int): RoPE base.
language_rope_scaling (bool): Toggle RoPE scaling.
language_rope_scaling_factor (float): RoPE scaling factor. Defaults to 8.
image_token_index (int): Token ID for image token such as <image>.
pixel_shuffle (bool): Enable pixel shuffle.
tile_tags (list): Optional tile tags.
......@@ -90,6 +144,7 @@ class LLaVAModel(MegatronModule):
vision_projection_type: str = "mlp",
allow_missing_vision_projection_checkpoint: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
language_position_embedding_type: str = 'learned_absolute',
language_rotary_percent: float = 1.0,
pre_process: bool = True,
......@@ -101,6 +156,7 @@ class LLaVAModel(MegatronModule):
patch_dim: int = 14,
language_rotary_base: int = 10000,
language_rope_scaling: bool = False,
language_rope_scaling_factor: float = 8.0,
image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX,
pixel_shuffle: bool = False,
tile_tags: Optional[list] = None,
......@@ -143,8 +199,15 @@ class LLaVAModel(MegatronModule):
# This attribute is needed to check if an all-reduce is required
# on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.
self.share_embeddings_and_output_weights = False
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
if self.add_decoder:
if hasattr(
language_transformer_config, "language_model_type"
) and language_transformer_config.language_model_type.startswith("huggingface"):
from megatron.core.models.huggingface.module import build_hf_model
self.language_model = build_hf_model(language_transformer_config)
else:
self.language_model = GPTModel(
config=language_transformer_config,
transformer_layer_spec=language_transformer_layer_spec,
......@@ -159,6 +222,7 @@ class LLaVAModel(MegatronModule):
rope_scaling=language_rope_scaling,
scatter_embedding_sequence_parallel=False,
)
self.share_embeddings_and_output_weights = (
self.language_model.share_embeddings_and_output_weights
)
......@@ -167,10 +231,19 @@ class LLaVAModel(MegatronModule):
language_transformer_config.pipeline_model_parallel_size > 1
)
# Newer Transformer Engine versions add _extra_state keys in state_dict when using FP8.
# Older models may not have _extra_state and can be ignored.
self.language_model.register_load_state_dict_post_hook(
_load_state_dict_hook_ignore_extra_state
)
class_token_len = 1
if self.add_encoder:
self._drop_vision_class_token = drop_vision_class_token
add_class_token = True
if vision_transformer_config.vision_model_type.startswith(
("clip", "siglip", "internvit")
):
if vision_transformer_config.vision_model_type == "siglip":
class_token_len = 0
add_class_token = False
......@@ -189,6 +262,40 @@ class LLaVAModel(MegatronModule):
model_subtype=vision_transformer_config.vision_model_type,
add_class_token=add_class_token,
)
elif vision_transformer_config.vision_model_type in ("radio"):
# TODO: should refactor into model code itself?
class_token_len = 8
max_img_h = 2048
max_img_w = 2048
embedder_bias = False
use_mask_token = False
self.vision_model = RADIOViTModel(
vision_transformer_config,
vision_transformer_layer_spec,
img_h=img_h,
img_w=img_w,
max_img_h=max_img_h,
max_img_w=max_img_w,
class_token_len=class_token_len,
patch_dim=patch_dim,
add_class_token=add_class_token,
embedder_bias=embedder_bias,
use_mask_token=use_mask_token,
)
elif vision_transformer_config.vision_model_type.startswith("huggingface"):
from megatron.core.models.huggingface.module import build_hf_model
self.vision_model = build_hf_model(vision_transformer_config)
else:
raise ValueError(
"Vision model "
f"{vision_transformer_config.vision_model_type} is not "
"supported."
)
self.vision_model.register_load_state_dict_post_hook(
_load_state_dict_hook_ignore_extra_state
)
vision_projection_input_size = vision_transformer_config.hidden_size
vision_projection_input_size *= 4 if pixel_shuffle else 1
......@@ -213,7 +320,7 @@ class LLaVAModel(MegatronModule):
partial(_load_state_dict_hook_ignore_param_names, vision_projection_param_names)
)
self._img_seq_len = get_num_image_embeddings(
self.img_seq_len = get_num_image_embeddings(
img_h,
img_w,
patch_dim,
......@@ -287,7 +394,6 @@ class LLaVAModel(MegatronModule):
inference_params,
image_token_index,
num_image_tiles,
image_token_mask=None,
):
"""Preprocess input data before input to language model.
......@@ -335,11 +441,8 @@ class LLaVAModel(MegatronModule):
if use_inference_kv_cache:
return language_embeddings, loss_mask, labels
img_seq_len = self._img_seq_len
img_seq_len = self.img_seq_len
batch_size, text_seq_len = input_ids.shape
# input_ids seq len is expected to be sharded by CP size
if self.context_parallel_lm:
text_seq_len *= self.context_parallel_lm
has_labels = labels is not None
if has_labels:
......@@ -349,11 +452,6 @@ class LLaVAModel(MegatronModule):
# Create indices for new text and label positions.
with torch.no_grad():
if image_token_mask is None:
assert (
self.context_parallel_lm <= 1
), "image_token_mask cannot be inferred from input_ids if using \
Context Parallelism. Please provide in forward_step"
image_token_mask = input_ids == image_token_index
num_images_per_sample = torch.sum(image_token_mask, dim=-1)
......@@ -388,8 +486,10 @@ class LLaVAModel(MegatronModule):
new_position_ids = torch.cumsum((image_token_mask_lens + 1), dim=-1) - 1
text_position_ids = new_position_ids[batch_indices, non_image_indices]
label_batch_indices = None # dummy value to pass formatting
# Labels are shifted to the left by one,
# so shift the text position ids and non-image indices to the left by one as well.
label_batch_indices = None
if has_labels:
label_text_position_ids = text_position_ids - 1
valid_label_text_position_ids = label_text_position_ids >= 0
......@@ -434,6 +534,14 @@ class LLaVAModel(MegatronModule):
]
# Put image embeddings to image positions.
# NOTE: FSDP can hang with text-only samples so we use a workaround to run a dummy image
# through the vision model and then zero-out the impact of the output here.
if num_image_tiles.shape[0] == 0 and image_embeddings.shape[0] > 0:
assert images_mask.sum() == 0 and getattr(
self.vision_model, "_is_fsdp_managed_module", False
), "expected FSDP and dummy image"
final_embedding[:1, :1, :1] += 0 * image_embeddings[:1, :1, :1]
else:
final_embedding[images_mask] = (
image_embeddings.permute(1, 0, 2).reshape(-1, embed_dim).contiguous()
)
......@@ -489,7 +597,7 @@ class LLaVAModel(MegatronModule):
# Truncate if exceeding the language model's max sequence length.
if final_embedding.shape[1] > self._language_max_sequence_length:
final_embedding = final_embedding[:, : self._language_max_sequence_length]
# Transpose to [s,b,h] if not using CP because CP Sharding expects seq in dim=1
# Transpose to [s,b,h] only if not using CP because CP Sharding expects seq in dim=1
if self.context_parallel_lm == 1:
final_embedding = final_embedding.transpose(1, 0).contiguous()
......@@ -508,18 +616,12 @@ class LLaVAModel(MegatronModule):
"""Processes the input data for model parallelism support.
When using sequence parallelism (SP) or context parallelism (CP), the sequence is sharded
across different GPUs. This function helps ensure that the sharding is done correctly by
1. Calculates `padding_factor` which determines based on how many chunks we expect to shard
the sequence
2. Calculates and pads the inputs to necessary length to ensure equal sized chunks
3. Creates/Modifies PackedSeqParams which helps mask padded tokens during calculations
4. Performs any layout changes if necessary
5. Distributes the sequence across GPUs for SP and CP
across different GPUs. This function performs the sharding and distributes the sequence
across GPUs for SP and CP
Context Parallelism is a feature that helps improve memory efficiency for
long sequence training by distributing sequence across CP ranks.
It requires token length to be divisible by (CP size *2) to ensure proper load balance.
Please refer to `get_batch_on_this_cp_rank` function for more details.
Sequence Parallelism is a feature that helps improve memory efficiency for
long sequence training by distributing sequence across TP ranks.
......@@ -532,143 +634,62 @@ class LLaVAModel(MegatronModule):
packed_seq_params (PackedSeqParams): Dict with padded token information.
"""
# combined_embeddings - `s,b,h` if not using CP, `b,s,h` if using CP
batch_size = (
combined_embeddings.shape[0]
if self.context_parallel_lm > 1
else combined_embeddings.shape[1]
)
seq_dim = 1 if self.context_parallel_lm > 1 else 0
padding_mask_type = 'padding' in str(
self.language_model.transformer_layer_spec.submodules.self_attention.params.get(
'attn_mask_type', ''
)
)
if self.sequence_parallel_lm and self.tp_comm_overlap_lm:
assert (
combined_embeddings.shape[seq_dim] == self._language_max_sequence_length
) or padding_mask_type, f"TP Comm overlap either requires Vision+Text token length \
== language_max_sequence_length or mask type to be set to padding/padding_causal"
if padding_mask_type:
# Calculate the padded sequence length needed to support SP and CP
# SP and CP are used to distribute the sequence across GPUs to improve
# memory efficiency and enable very long context training.
# To distribute workload equally, we need to ensure that the sequence is
# divisible by the appropriate padding factor calculated below.
padding_factor = None
padded_seq_len = None
mp_padding_needed = 0
# No pre or post processing needed with PP middle chunks.
if not self.pre_process and not self.post_process:
return combined_embeddings, new_labels, new_loss_mask, packed_seq_params
shard_factor = seq_dim = None
if self.pre_process:
if self.context_parallel_lm > 1 and self.sequence_parallel_lm:
padding_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2
shard_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2
seq_dim = 1
elif self.context_parallel_lm > 1:
padding_factor = self.context_parallel_lm * 2
shard_factor = self.context_parallel_lm * 2
seq_dim = 1
elif self.sequence_parallel_lm:
padding_factor = self.tensor_model_parallel_size_lm
padded_seq_len = int(
(combined_embeddings.shape[seq_dim] + (padding_factor - 1))
// padding_factor
* padding_factor
)
shard_factor = self.tensor_model_parallel_size_lm
seq_dim = 0
assert (
padded_seq_len <= self._language_max_sequence_length
), f"Sequence length after padding {padded_seq_len} for SP/CP has exceeded \
language_max_sequence_length. Ensure language_max_sequence_length is \
divisible by SP/CP factor: {padding_factor}"
combined_embeddings.shape[seq_dim] % shard_factor == 0
), f"Sequence length should be divisible by {shard_factor} for \
Sequence/Context parallelism"
if self.sequence_parallel_lm and self.tp_comm_overlap_lm:
# TP Comm overlap initializes the user buffer shape used for communication
# at the beginning of training run and the same shape is expected to be
# used throughout the training.
# Pad to language_max_sequence_length to use TP Comm overlap.
assert (
self._language_max_sequence_length % padding_factor == 0
), f"TP Comm overlap uses language_max_sequence_length \
which needs to be divisible by SP/CP factor {padding_factor}"
padded_seq_len = self._language_max_sequence_length
assert (
packed_seq_params is not None
), "Please provide PackedSeqParams dict when using SP or CP with padding"
valid_seqlens = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1]
valid_seq_len = max(valid_seqlens)
assert (
padded_seq_len >= valid_seq_len
), f"Padded Seq Len calculated for model parallelism: {padded_seq_len} \
is shorter than expected valid token len {valid_seq_len} provided."
mp_padding_needed = padded_seq_len - combined_embeddings.shape[seq_dim]
if mp_padding_needed > 0:
new_labels = torch.nn.functional.pad(
new_labels, (0, mp_padding_needed), value=IGNORE_INDEX
)
new_loss_mask = torch.nn.functional.pad(new_loss_mask, (0, mp_padding_needed))
if self.context_parallel_lm > 1:
combined_embeddings = torch.nn.functional.pad(
combined_embeddings, (0, 0, 0, mp_padding_needed)
)
else:
combined_embeddings = torch.nn.functional.pad(
combined_embeddings, (0, 0, 0, 0, 0, mp_padding_needed)
)
# Update PackedSeqParams if padding needed beyond user provided PackedSeqParams
packed_seq_params.max_seqlen_q = padded_seq_len
packed_seq_params.max_seqlen_kv = padded_seq_len
cu_seqlens_padded = None
# We need cu_seqlens_q_padded/cu_seqlens_kv_padded when doing
# CP+Padding to support accurate Attention with THD format.
if self.context_parallel_lm > 1:
cu_seqlens_padded = torch.arange(
0,
(batch_size + 1) * (padded_seq_len),
step=(padded_seq_len),
dtype=torch.int32,
device=combined_embeddings.device,
)
packed_seq_params.cu_seqlens_q_padded = cu_seqlens_padded
packed_seq_params.cu_seqlens_kv_padded = cu_seqlens_padded
packed_seq_params.qkv_format = 'thd'
else:
packed_seq_params.qkv_format = 'sbhd'
combined_embeddings.shape[seq_dim] == self._language_max_sequence_length
), f"TP Comm overlap either requires Vision+Text token length \
== language_max_sequence_length"
if self.context_parallel_lm > 1:
batch = dict()
if self.pre_process:
batch["combined_embeddings"] = combined_embeddings
if self.post_process:
batch["new_labels"] = new_labels
batch["new_loss_mask"] = new_loss_mask
# Distribute sequence across CP ranks
if packed_seq_params is None or packed_seq_params.qkv_format == 'sbhd':
from megatron.training.utils import get_batch_on_this_cp_rank
batch = get_batch_on_this_cp_rank(
{
"combined_embeddings": combined_embeddings,
"new_labels": new_labels,
"new_loss_mask": new_loss_mask,
}
)
batch = get_batch_on_this_cp_rank(batch)
else:
assert HAVE_TEX and is_te_min_version(
"1.10.0"
), "Please update Transformer Engine to >= 1.10 to use \
Context Parallel with THD format data"
batch = _get_data_on_this_cp_rank.apply(batch, packed_seq_params)
if self.pre_process:
combined_embeddings = batch["combined_embeddings"] # [B, S/CP, H]
new_labels = batch["new_labels"]
new_loss_mask = batch["new_loss_mask"]
if getattr(packed_seq_params, 'qkv_format', None) == 'thd':
# If PackedSeqParams requires THD format,
# reshape embedding from [B,S,H] to [T,1,H] where T=B*S
combined_embeddings = (
combined_embeddings.contiguous()
.view(combined_embeddings.shape[0] * combined_embeddings.shape[1], -1)
.unsqueeze(1)
)
new_labels = new_labels.view(new_labels.shape[0] * new_labels.shape[1]).unsqueeze(0)
new_loss_mask = new_loss_mask.view(
new_loss_mask.shape[0] * new_loss_mask.shape[1]
).unsqueeze(0)
else:
combined_embeddings = combined_embeddings.transpose(
1, 0
).contiguous() # [B,S/CP,H] -> [S/CP,B,H]
if self.post_process:
new_labels = batch["new_labels"]
new_loss_mask = batch["new_loss_mask"]
if self.sequence_parallel_lm:
if self.sequence_parallel_lm and self.pre_process:
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
combined_embeddings
) # [S/(CP*TP),B,H]
......@@ -723,7 +744,6 @@ class LLaVAModel(MegatronModule):
num_image_tiles: Optional[List[int]] = None,
image_token_index: Optional[int] = None,
runtime_gather_output: Optional[bool] = None,
image_token_mask: Optional[torch.Tensor] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
) -> torch.Tensor:
"""Forward function of the LLaVA model.
......@@ -745,8 +765,6 @@ class LLaVAModel(MegatronModule):
arg in the constructor will be used.
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
image_token_mask (torch.Tensor): Tensor indicating the location of
image token index in input_ids.
packed_seq_params (PackedSeqParams): 1) If using sequence packing, must contain
subsample length information. 2) If using SP/CP with padding mask type,
must contain padded token information.
......@@ -818,11 +836,6 @@ class LLaVAModel(MegatronModule):
language_embeddings = self.language_model.embedding(
input_ids=input_ids_text, position_ids=position_ids
) # [text_seq_len, b, h_language]
# Gather the language embeddings back. We need the full embedding to insert
# image embeddings and then scatter again to avoid load imbalance.
if self.context_parallel_lm > 1:
cp_group = get_context_parallel_group()
language_embeddings, _ = gather_along_first_dim(language_embeddings, cp_group)
language_embeddings = language_embeddings.transpose(
1, 0
......@@ -842,7 +855,6 @@ class LLaVAModel(MegatronModule):
inference_params,
image_token_index if image_token_index is not None else self.image_token_index,
num_image_tiles,
image_token_mask,
) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]
if self.context_parallel_lm > 1 or self.sequence_parallel_lm:
......@@ -890,6 +902,28 @@ def _load_state_dict_hook_ignore_param_names(
incompatible_keys.missing_keys.remove(param_name)
def _load_state_dict_hook_ignore_extra_state(
module: torch.nn.Module, incompatible_keys: namedtuple
):
"""Hook to ignore Transformer Engine _extra_state used for FP8.
This is for backwards-compatibility. Newer TE versions add _extra_state keys to the state dict,
while older models might not have those keys. Those keys can be ignored when not using FP8.
Args:
module (torch.nn.Module): The torch module this hook applies to. Required by the torch API.
incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys,
which collect the missing and unexpected keys, respectively.
"""
for name, keys in incompatible_keys._asdict().items():
for key in keys[::-1]:
if "extra_state" in key:
logging.getLogger(__name__).warning(
f"_extra_state key {key} being removed from {name}"
)
keys.remove(key)
# pylint: disable-next=line-too-long
# Based on https://github.com/OpenGVLab/InternVL/blob/c7c5af1a8930b4862afe8ed14672307082ef61fa/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py#L218
# Copyright (c) 2023 OpenGVLab.
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
from megatron.core.extensions.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
......@@ -6,7 +8,7 @@ from megatron.core.extensions.transformer_engine import (
TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
......@@ -27,15 +29,15 @@ except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def decoder_model_with_transformer_engine_default_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
"""LLava decoder TE spec (uses Transformer Engine components)."""
mlp = _get_mlp_module_spec(
mlp = get_mlp_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
......@@ -60,10 +62,10 @@ def decoder_model_with_transformer_engine_default_spec(
def decoder_model_with_local_default_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
"""LLava decoder local spec."""
mlp = _get_mlp_module_spec(
mlp = get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
......
......@@ -201,6 +201,11 @@ def get_num_image_embeddings(
keep_class_token = False
elif vision_model_type in ("clip", "internvit"):
keep_class_token = not disable_vision_class_token
elif vision_model_type.startswith("radio"):
keep_class_token = not disable_vision_class_token
elif vision_model_type.startswith("huggingface"):
# TODO: Temp, what do we do in this situation?
keep_class_token = True
else:
raise ValueError(f"unsupported vision model: {vision_model_type}")
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
# RADIO reference code: https://github.com/NVlabs/RADIO
class RADIOViTModel(VisionModule):
"""RADIO ViT vision model.
Args:
transformer_config (TransformerConfig): Transformer config.
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.
ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre.
ln_post_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_post.
use_mask_token (bool, optional): Whether to use RADIO mask token. Default to False.
add_class_token (bool, optional): Include a class token. Defaults to True.
class_token_len (int): Class token length. Defaults to 1 but 8 may be faster.
patch_dim (int): Image patch size.
img_h (int): Input image height.
img_w (int): Input image width.
max_img_h (int): Max input image height.
max_img_w (int): Max input image width.
pos_dropout (int): Positional encoding dropout value. Defaults to 0.
has_cpe (bool): Whether to use conditional positional encoding. Defaults to True.
embedder_bias (bool): Bias in embedder linear. Defaults to False.
"""
def __init__(
self,
transformer_config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
ln_pre_impl: Union[ModuleSpec, type] = None,
ln_post_impl: Union[ModuleSpec, type] = None,
use_mask_token: bool = False,
add_class_token: bool = True,
class_token_len: int = 8,
patch_dim: int = 16,
img_h: int = 224,
img_w: int = 224,
max_img_h: int = 2048,
max_img_w: int = 2048,
pos_dropout: int = 0,
has_cpe: bool = True,
embedder_bias: bool = False,
) -> None:
super().__init__(config=transformer_config)
if has_config_logger_enabled(transformer_config):
log_config_to_disk(transformer_config, locals(), prefix=type(self).__name__)
self.class_token_len = class_token_len
self.visual_hidden_size = transformer_config.hidden_size
self.patch_dim = patch_dim
self.img_h = img_h
self.img_w = img_w
assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
self.input_dims = (img_h // patch_dim, img_w // patch_dim)
# used for positional embedding
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_num_rows = max_img_h // patch_dim
self.max_num_cols = max_img_w // patch_dim
self.max_num_patches = self.max_num_rows * self.max_num_cols
# TODO: are we actually going to use this anywhere?
self.use_mask_token = use_mask_token
if self.use_mask_token:
self.mask_token = nn.Parameter(torch.zeros(1, self.visual_hidden_size))
self.add_class_token = add_class_token
self.class_token_len = class_token_len
if self.add_class_token:
self.class_token = nn.Parameter(
torch.randn(self.class_token_len, self.visual_hidden_size)
)
self.seq_length = (img_h // self.patch_dim) * (img_w // self.patch_dim) + (
self.class_token_len if self.add_class_token else 0
)
pos_scale = self.visual_hidden_size**-0.5
self.position_embeddings = nn.Parameter(
torch.randn(1, self.max_num_patches, self.visual_hidden_size) * pos_scale
)
self.pos_dropout = pos_dropout
self.has_cpe = has_cpe
# Using non-TE version so we can force gather_output
self.embedder = ColumnParallelLinear(
input_size=3 * self.patch_dim * self.patch_dim,
output_size=self.visual_hidden_size,
bias=embedder_bias,
config=transformer_config,
gather_output=True,
init_method=lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=1.0),
)
self.model_type = ModelType.encoder_or_decoder
self.ln_pre = None
self.ln_post = None
if ln_pre_impl is not None:
self.ln_pre = build_module(
ln_pre_impl,
config=transformer_config,
hidden_size=self.visual_hidden_size,
eps=transformer_config.layernorm_epsilon,
)
if ln_post_impl is not None:
self.ln_post = build_module(
ln_post_impl,
config=transformer_config,
hidden_size=self.visual_hidden_size,
eps=transformer_config.layernorm_epsilon,
)
self.decoder = TransformerBlock(
config=transformer_config,
spec=transformer_layer_spec,
pre_process=True,
post_process=False,
)
def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
"""Sets input tensor to the model.
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
self.decoder.set_input_tensor(input_tensor)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward function of the RADIO ViT Model. This function passes the input tensors
through the embedding layer and then the transformer.
Args:
x (torch.Tensor): input data of shape [batch, img_h, img_w]
attention_mask (torch.Tensor with dtype=bool): Attention mask to use.
Returns:
x (torch.Tensor): output after final transformer block of shape [b, s, h].
"""
input_size = x.shape[2:]
py = x.shape[-2] // self.patch_dim
px = x.shape[-1] // self.patch_dim
x = rearrange(
x,
'b c (py yy) (px xx) -> b (py px) (c yy xx)',
py=py,
yy=self.patch_dim,
px=px,
xx=self.patch_dim,
)
x, _ = self.embedder(x) # [batch, seq_length, hidden_size]
x, _ = self.apply_pos_enc(x, input_size=input_size)
if self.add_class_token:
class_token = self.class_token.expand(
x.shape[0], -1, -1
) # [batch, class_token_len, hidden_size]
x = torch.cat(
[class_token, x], dim=1
) # [batch, seq_length + class_token_len, hidden_size]
assert x.shape[1] == self.seq_length, f"{x.shape[1]} != {self.seq_length}"
if self.ln_pre:
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # [b, s, h] -> [s, b, h]
x = x.contiguous()
x = self.decoder(x, attention_mask=attention_mask)
x = x.permute(1, 0, 2) # [s, b, h] -> [b, s, h]
x = x.contiguous()
if self.ln_post:
x = self.ln_post(x)
return x
def apply_pos_enc(
self,
patches: torch.Tensor,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply positional encoding to patches. Returns the encoded patches and the positional encoding."""
pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
if self.training and self.pos_dropout > 0:
keeps = (
torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device)
> self.pos_dropout
)
pos_enc_drop = torch.where(keeps, pos_enc, 0)
else:
pos_enc_drop = pos_enc
return patches + pos_enc_drop, pos_enc
def get_pos_enc(
self,
batch_size: int,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
"""Get positional encoding for certain input size"""
if input_size is None:
input_dims = self.input_dims
else:
input_dims = tuple(d // self.patch_dim for d in input_size)
pos_embed = self._get_pos_embeddings(batch_size, input_dims)
if patch_idxs is None:
return pos_embed
exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])
pos_embed = torch.gather(
pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
)
return pos_embed
def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):
"""Get RADIO absolute positional embeddings"""
if (self.max_num_rows, self.max_num_cols) == input_dims:
return self.position_embeddings
pos_embed = self.position_embeddings.reshape(
1, self.max_num_rows, self.max_num_cols, -1
).permute(0, 3, 1, 2)
def window_select(pos_embed):
if input_dims[0] < pos_embed.shape[-2]:
pos_embed = pos_embed[..., : input_dims[0], :]
if input_dims[1] < pos_embed.shape[-1]:
pos_embed = pos_embed[..., :, : input_dims[1]]
return pos_embed
if self.has_cpe:
if self.training:
min_scale = math.sqrt(0.1)
scale = (
torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale)
+ min_scale
)
aspect_min = math.log(3 / 4)
aspect_max = -aspect_min
aspect = torch.exp(
torch.rand(batch_size, 1, 1, device=pos_embed.device)
* (aspect_max - aspect_min)
+ aspect_min
)
scale_x = scale * aspect
scale_y = scale * (1 / aspect)
scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)
lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[
None, None
].expand(batch_size, input_dims[0], -1)
lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[
None, :, None
].expand(batch_size, -1, input_dims[1])
lin_xy = torch.stack([lin_x, lin_y], dim=-1)
grid_xy = lin_xy * scale_xy + pos_xy
# Convert to [-1, 1] range
grid_xy.mul_(2).sub_(1)
pos_embed = F.grid_sample(
pos_embed.float().expand(batch_size, -1, -1, -1),
grid=grid_xy,
mode='bilinear',
padding_mode='zeros',
align_corners=True,
).to(pos_embed.dtype)
else:
max_dim = max(input_dims)
pos_embed = F.interpolate(
pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear'
).to(pos_embed.dtype)
pos_embed = window_select(pos_embed)
else:
pos_embed = window_select(pos_embed)
if pos_embed.shape[-2:] != input_dims:
pos_embed = F.interpolate(
pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear'
).to(pos_embed.dtype)
pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
return pos_embed
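# Shape walk-through of forward() with the defaults above (img 224x224, patch_dim=16,
# class_token_len=8): the rearrange yields 14 * 14 = 196 patch vectors of size
# 3 * 16 * 16 = 768, the embedder projects them to hidden_size, 8 class tokens are
# prepended (seq_length = 204), and the transformer runs in [s, b, h] layout before
# the output is permuted back to [b, s, h]. Hypothetical call, assuming valid configs:
#
#     vit = RADIOViTModel(transformer_config, transformer_layer_spec)
#     out = vit(torch.randn(4, 3, 224, 224))   # -> [4, 204, hidden_size]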
......@@ -27,7 +27,7 @@ except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import warnings
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch.optim import SGD as CPUSGD
from torch.optim import AdamW as CPUAdam
try:
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
......@@ -12,8 +15,6 @@ except ImportError:
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
except ImportError:
import warnings
warnings.warn(
f'Transformer Engine and Apex are not installed. Falling back to Torch optimizers.'
)
......@@ -24,10 +25,11 @@ except ImportError:
from torch.optim import AdamW as Adam, SGD
from megatron.core import mpu
from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer
from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer
from ..transformer.module import MegatronModule
from ..utils import log_single_rank
from ..utils import is_te_min_version, log_single_rank
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import (
......@@ -81,7 +83,12 @@ def _get_param_groups(
# Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
params_map = {}
for model_chunk in model_chunks:
for name, param in model_chunk.named_parameters():
if model_chunk.ddp_config.use_custom_fsdp:
named_parameters = model_chunk.optimizer_named_parameters()
else:
named_parameters = model_chunk.named_parameters()
for name, param in named_parameters:
if not param.requires_grad:
continue
......@@ -262,12 +269,50 @@ def _get_megatron_optimizer_based_on_param_groups(
Returns:
Instance of MegatronOptimizer.
"""
# when freezing sub-models we may have no trainable parameters on a rank and
# hence an empty param_groups. However, we still need to create an optimizer
# for the purposes of grad stats reductions
if param_groups:
if config.optimizer_cpu_offload:
if torch.__version__ < '2.3.0':
warnings.warn(
"CPU offload is recommended for PyTorch >= 2.3.0, "
"untested versions below this may have convergence issues."
)
gpu_optimizer_cls = Adam if config.optimizer == 'adam' else SGD
cpu_optimizer_cls = CPUAdam if config.optimizer == 'adam' else CPUSGD
if config.use_torch_optimizer_for_cpu_offload:
gpu_optimizer_cls = cpu_optimizer_cls
if config.optimizer == 'adam':
gpu_optimizer_cls = Adam
cpu_optimizer_cls = CPUAdam
optimizer_defaults = dict(
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.adam_beta1, config.adam_beta2),
eps=config.adam_eps,
bias_correction=True,
fused=True, # this flag is used to improve the performance of the cpu optimizer
)
else:
gpu_optimizer_cls = SGD
cpu_optimizer_cls = CPUSGD
optimizer_defaults = dict(
lr=config.lr, weight_decay=config.weight_decay, momentum=config.sgd_momentum
)
optimizer = HybridDeviceOptimizer(
param_groups,
offload_fraction=config.optimizer_offload_fraction,
cpu_optimizer_cls=cpu_optimizer_cls,
gpu_optimizer_cls=gpu_optimizer_cls,
overlap_cpu_optimizer_d2h_h2d=config.overlap_cpu_optimizer_d2h_h2d,
pin_cpu_grads=config.pin_cpu_grads,
pin_cpu_params=config.pin_cpu_params,
param_update_in_fp32=True,
**optimizer_defaults,
)
init_state_fn = None
elif config.optimizer == 'adam':
kwargs = {
"params": param_groups,
"lr": config.lr,
......@@ -287,6 +332,9 @@ def _get_megatron_optimizer_based_on_param_groups(
}
)
if is_te_min_version("2.1.0.dev0"):
kwargs.update({"store_param_remainders": True})
optimizer = Adam(**kwargs)
def init_state_fn(opt, config=None):
......@@ -371,6 +419,7 @@ def get_megatron_optimizer(
no_weight_decay_cond: Optional[Callable] = None,
scale_lr_cond: Optional[Callable] = None,
lr_mult: float = 1.0,
use_gloo_process_groups: bool = True,
) -> MegatronOptimizer:
"""Retrieve the Megatron optimizer for model chunks.
......@@ -385,6 +434,8 @@ def get_megatron_optimizer(
should have a scaled learning rate. Defaults to None.
lr_mult (float, optional): learning rate multiplier for parameters that
satisfy scale_lr_cond. Defaults to 1.0.
use_gloo_process_groups (bool): if false, disable use of Gloo process groups
in underlying Megatron optimizers.
Returns:
Instance of MegatronOptimizer.
......@@ -414,6 +465,42 @@ def get_megatron_optimizer(
optimizers = []
model_chunk_offset = 0
ddp_config = model_chunks[0].ddp_config # Use the first model chunk's DDP config
if ddp_config.use_custom_fsdp:
for model_chunk, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
param_groups, buffers = _get_param_groups_and_buffers(
model_chunk,
model_chunk_offset=model_chunk_offset,
config=config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
filter_fn=lambda g: True,
buffer_name='buffers',
)
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
model_chunks=model_chunk,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mpu.get_model_parallel_group(),
data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(
with_context_parallel=True
),
data_parallel_group_idx=model_parallel_rank,
)
)
model_chunk_offset += 1
if len(optimizers) == 1:
return optimizers[0]
return ChainedOptimizer(optimizers)
for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
......@@ -432,6 +519,13 @@ def get_megatron_optimizer(
overlap_param_gather_with_optimizer_step
)
# Pass Gloo process groups into optimizer only if needed.
if use_gloo_process_groups:
data_parallel_group_gloo = mpu.get_data_parallel_group_gloo(
with_context_parallel=True, partial_data_parallel=True
)
else:
data_parallel_group_gloo = None
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
......@@ -442,9 +536,7 @@ def get_megatron_optimizer(
data_parallel_group=mpu.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
),
data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(
with_context_parallel=True, partial_data_parallel=True
),
data_parallel_group_gloo=data_parallel_group_gloo,
data_parallel_group_idx=model_parallel_rank,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
)
......@@ -465,6 +557,11 @@ def get_megatron_optimizer(
model_parallel_rank = torch.distributed.get_rank(
mpu.get_expert_tensor_model_pipeline_parallel_group()
)
# Pass Gloo process groups into optimizer only if needed.
if use_gloo_process_groups:
data_parallel_group_gloo = mpu.get_expert_data_parallel_group_gloo()
else:
data_parallel_group_gloo = None
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
......@@ -473,7 +570,7 @@ def get_megatron_optimizer(
per_model_buffers=moe_buffers,
model_parallel_group=mpu.get_expert_tensor_model_pipeline_parallel_group(),
data_parallel_group=mpu.get_expert_data_parallel_group(),
data_parallel_group_gloo=mpu.get_expert_data_parallel_group_gloo(),
data_parallel_group_gloo=data_parallel_group_gloo,
data_parallel_group_idx=model_parallel_rank,
)
)
......
## How to use?
Add these flags to enable optimizer CPU offload in MCore.
```bash
--optimizer-cpu-offload
--optimizer-offload-fraction 1.0
--use-precision-aware-optimizer
```
## Configuration Recommendations
Copying gradients from GPU to CPU, running the CPU optimizer step, and copying updated parameters back from CPU to GPU can be time-consuming. It is recommended to use the flag `--overlap-cpu-optimizer-d2h-h2d` so that these operations are executed concurrently.
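For example, combining the flags above with the recommended overlap option:
```bash
# Fully offload optimizer states to CPU and overlap gradient D2H copies,
# CPU optimizer steps, and parameter H2D copies.
--optimizer-cpu-offload
--optimizer-offload-fraction 1.0
--use-precision-aware-optimizer
--overlap-cpu-optimizer-d2h-h2d
```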
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from .hybrid_optimizer import HybridDeviceOptimizer
# Copyright (c) 2025, NVIDIA CORPORATION and Alibaba PAI. All rights reserved.
from collections import defaultdict
from typing import Dict
import torch
def _param_generator(cpu_optimizer):
for group in cpu_optimizer.param_groups:
for param in group["params"]:
yield param
class HybridDeviceOptimizer(torch.optim.Optimizer):
"""
HybridDeviceOptimizer is a custom optimizer designed to facilitate
hybrid parameter updates across GPU and CPU. This optimizer allows
users to adjust the fraction of parameters updated on the CPU and
GPU through the `offload_fraction` parameter.
It supports bf16 mixed-precision training. Additionally, the optimizer
implements overlapping operations for improved performance, including
gradient transfer from device to host (D2H) and parameter transfer
from host to device (H2D).
Example:
from transformer_engine.pytorch.optimizers import FusedAdam as GPUAdam
from torch.optim import AdamW as CPUAdam
optimizer = HybridDeviceOptimizer(
param_groups,
cpu_optimizer_cls=CPUAdam,
gpu_optimizer_cls=GPUAdam,
offload_fraction=0.5,
param_update_in_fp32=True,
overlap_cpu_optimizer_d2h_h2d=True,
)
optimizer.step()
Note:
This optimizer is particularly useful in scenarios where memory
constraints are present or when leveraging both CPU and GPU resources
can lead to performance improvements.
"""
def __init__(
self,
params,
offload_fraction=0.5,
cpu_optimizer_cls=None,
gpu_optimizer_cls=None,
param_update_in_fp32: bool = False,
pin_cpu_grads: bool = True,
pin_cpu_params: bool = True,
overlap_cpu_optimizer_d2h_h2d: bool = True,
**kwargs
):
super(HybridDeviceOptimizer, self).__init__(
params,
defaults={
"offload_fraction": offload_fraction,
"cpu_optimizer_cls": cpu_optimizer_cls,
"gpu_optimizer_cls": gpu_optimizer_cls,
"param_update_in_fp32": param_update_in_fp32,
"pin_cpu_grads": pin_cpu_grads,
"pin_cpu_params": pin_cpu_params,
"overlap_cpu_optimizer_d2h_h2d": overlap_cpu_optimizer_d2h_h2d,
**kwargs,
},
)
self.offload_fraction = offload_fraction
self.cpu_optimizer_cls = cpu_optimizer_cls
self.gpu_optimizer_cls = gpu_optimizer_cls
self.pin_cpu_grads = pin_cpu_grads
self.pin_cpu_params = pin_cpu_params
self.overlap_cpu_optimizer_d2h_h2d = overlap_cpu_optimizer_d2h_h2d
self.param_update_in_fp32 = param_update_in_fp32
self.sub_optimizer_kwargs = kwargs
self._init_sub_optimizers()
self._register_load_state_dict_hooks()
def _set_sub_optimizer_grads(self):
if self.param_update_in_fp32:
for param in self.param_to_fp32_param:
if param in self.gpu_params_map_cpu_copy:
# Skip if the param is offloaded to CPU, it should be handled
# in the following part.
continue
fp32_param = self.param_to_fp32_param[param]
grad = getattr(param, "decoupled_grad", param.grad)
if grad is not None:
fp32_param.grad = grad.to(fp32_param.dtype)
fp32_param.requires_grad = True
else:
fp32_param.requires_grad = False
# Sync the grads from GPU to CPU.
for optimizer in self.cpu_optimizers:
for param in _param_generator(optimizer):
gpu_param = self.cpu_copys_map_gpu_param[param]
grad = getattr(gpu_param, "decoupled_grad", gpu_param.grad)
if grad is None:
param.requires_grad = False
continue
param.requires_grad = False
if param not in self.cpu_copy_map_grad:
self.cpu_copy_map_grad[param] = torch.empty(
param.shape, dtype=param.dtype, pin_memory=self.pin_cpu_grads, device="cpu"
)
param.grad = self.cpu_copy_map_grad[param]
self.cpu_copy_map_grad[param].data.copy_(grad, non_blocking=True)
self._cpu_optimizer_map_data_event[optimizer] = self._d2h_stream.record_event()
def _register_param_copy_back_gpu_hook(self):
def param_copy_back_gpu_hook_closure():
def param_copy_back_gpu_hook(optimizer, args, kwargs):
self._h2d_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._h2d_stream):
for param in _param_generator(optimizer):
gpu_param = self.cpu_copys_map_gpu_param[param]
gpu_param.data.copy_(param.data, non_blocking=True)
self._d2h_stream.record_event().wait(torch.cuda.current_stream())
return param_copy_back_gpu_hook
def fp32_param_copy_back_gpu_hook_closure():
def fp32_param_copy_back_gpu_hook(optimizer, args, kwargs):
for group in self.param_groups:
for param in group["params"]:
if param in self.gpu_params_map_cpu_copy:
# Skip if the param is offloaded to CPU; it has already been
# copied back by the previous hook.
continue
if param in self.param_to_fp32_param:
fp32_param = self.param_to_fp32_param[param]
param.data.copy_(fp32_param.data)
return fp32_param_copy_back_gpu_hook
for optimizer in self.sub_optimizers:
if optimizer is not self.gpu_optimizer:
optimizer.register_step_post_hook(param_copy_back_gpu_hook_closure())
elif self.param_update_in_fp32:
optimizer.register_step_post_hook(fp32_param_copy_back_gpu_hook_closure())
def step(self, closure=None):
"""
Override the step method to perform the following operations:
1. Sync the HDO param_groups to sub-optimizers.
2. Sync the grads from GPU to CPU.
3. Step the sub-optimizers.
4. Sync the sub-optimizers state to HDO.
"""
# Sync param_groups to sub-optimizers before each step to make sure
# the lr, wd, etc. are up-to-date.
self._sync_hdo_param_groups_to_sub_optimizers()
self._d2h_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._d2h_stream):
self._set_sub_optimizer_grads()
# Step the sub-optimizers.
if self.gpu_optimizer:
self.gpu_optimizer.step(closure)
for cpu_optimizer in self.cpu_optimizers:
d2h_event = self._cpu_optimizer_map_data_event.pop(cpu_optimizer, None)
if d2h_event is not None:
d2h_event.synchronize()
cpu_optimizer.step(closure)
# Sync state and param_groups to HDO after each step.
# NOTE: It is possible for the optimizer to change the properties
# in param_groups.
self._sync_sub_optimizers_state_to_hdo()
def _init_sub_optimizers(self):
(
self.cpu_param_groups,
self.gpu_param_groups,
self.gpu_params_map_cpu_copy,
self.cpu_copys_map_gpu_param,
self.param_to_fp32_param,
) = self._get_sub_optimizer_param_groups(self.offload_fraction)
self.param_to_inner_param = {}
self.inner_param_to_orig_param = {}
for group in self.param_groups:
for param in group["params"]:
if param in self.param_to_fp32_param:
inner_param = self.param_to_fp32_param[param]
elif param in self.gpu_params_map_cpu_copy:
inner_param = self.gpu_params_map_cpu_copy[param]
else:
inner_param = param
self.param_to_inner_param[param] = inner_param
self.inner_param_to_orig_param[inner_param] = param
self.fp32_param_to_orig_param = {v: k for k, v in self.param_to_fp32_param.items()}
self.cpu_optimizers = []
if self.overlap_cpu_optimizer_d2h_h2d:
self.cpu_optimizers = self.build_cpu_optimizer_list(
self.cpu_optimizer_cls, self.cpu_param_groups
)
elif len(self.cpu_param_groups) > 0:
self.cpu_optimizers = [self.cpu_optimizer_cls(self.cpu_param_groups)]
if len(self.gpu_param_groups) > 0:
self.gpu_optimizer = self.gpu_optimizer_cls(self.gpu_param_groups)
else:
self.gpu_optimizer = None
self.cpu_copy_map_grad: Dict[torch.Tensor, torch.Tensor] = defaultdict(torch.Tensor)
self._d2h_stream = torch.cuda.current_stream()
self._h2d_stream = torch.cuda.current_stream()
if self.overlap_cpu_optimizer_d2h_h2d:
self._d2h_stream = torch.cuda.Stream()
self._h2d_stream = torch.cuda.Stream()
self._cpu_optimizer_map_data_event = dict()
self._register_param_copy_back_gpu_hook()
@staticmethod
def build_cpu_optimizer_list(cpu_optimizer_cls, cpu_param_groups):
"""Build several cpu optimizers to enable overlap. Currently we naively
assign each parameter to an individual optimizer.
Args:
cpu_optimizer_cls (Type[torch.optim.Optimizer]): A torch optimizer class
cpu_param_groups (List[Dict[str, Any]]): The CPU parameter groups
"""
cpu_optimizers = []
if len(cpu_param_groups) == 0:
return cpu_optimizers
for group in cpu_param_groups:
group_defaults = group.copy()
params = group_defaults.pop("params")
if isinstance(params, torch.Tensor):
params = [params]
for param in params:
_cpu_param_group = group_defaults.copy()
_cpu_param_group["params"] = [param]
cpu_optimizers.append(cpu_optimizer_cls([_cpu_param_group]))
return cpu_optimizers
def _get_sub_optimizer_param_groups(self, offload_fraction: float):
params = []
for group in self.param_groups:
params.extend(group["params"])
params_total_numel = sum([param.numel() for param in params])
gpu_params_total_numel = sum([param.numel() for param in params if param.is_cuda])
cpu_params_total_numel = params_total_numel - gpu_params_total_numel
offload_threshold = gpu_params_total_numel * offload_fraction
offload_params_numel = 0
cpu_param_groups = []
gpu_param_groups = []
gpu_params_map_cpu_copy = {}
cpu_copys_map_gpu_param = {}
param_to_fp32_param = {}
for group in self.param_groups:
gpu_group = group.copy()
cpu_group = group.copy()
gpu_group["params"] = []
cpu_group["params"] = []
for param in group["params"]:
orig_param = param
cpu_copy = False
if offload_params_numel < offload_threshold and param.is_cuda:
param = param.detach().clone().cpu().pin_memory()
offload_params_numel += param.numel()
cpu_copy = True
if self.param_update_in_fp32 and param.dtype != torch.float32:
param = param.detach().clone().float()
param_to_fp32_param[orig_param] = param
if cpu_copy:
gpu_params_map_cpu_copy[orig_param] = param
cpu_copys_map_gpu_param[param] = orig_param
if param.is_cuda:
gpu_group["params"].append(param)
else:
cpu_group["params"].append(param)
if len(gpu_group["params"]) != 0:
gpu_param_groups.append(gpu_group)
if len(cpu_group["params"]) != 0:
cpu_param_groups.append(cpu_group)
return (
cpu_param_groups,
gpu_param_groups,
gpu_params_map_cpu_copy,
cpu_copys_map_gpu_param,
param_to_fp32_param,
)
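# Worked example (illustrative numbers): with two CUDA params of 6M and 4M
# elements and offload_fraction=0.5, offload_threshold = 10M * 0.5 = 5M.
# The first param is cloned to pinned CPU memory (offload_params_numel
# becomes 6M, which is >= 5M), so the second param stays on GPU; the result
# is one CPU param group (6M elements) and one GPU param group (4M
# elements). Because offloading is decided per whole parameter, the actual
# offloaded fraction can overshoot the requested one.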
def _sync_sub_optimizers_state_to_hdo(self):
"""
Sync the sub-optimizers' state into the HDO `state` attribute.
"""
# optimizer.state:
# {
# torch.nn.Parameter: {
# str: Any,
# },
# ...
# }
new_state = defaultdict(dict)
for optimizer in self.sub_optimizers:
for param in optimizer.state:
orig_param = self.inner_param_to_orig_param[param]
new_state[orig_param] = optimizer.state[param]
if self.param_update_in_fp32:
new_state[orig_param]["master_param"] = param
self.state = new_state
def _sync_hdo_state_to_sub_optimizers(self):
for optimizer in self.sub_optimizers:
new_state = defaultdict(dict)
for group in optimizer.param_groups:
for param in group["params"]:
orig_param = self.inner_param_to_orig_param[param]
new_state[param] = self.state[orig_param]
optimizer.state = new_state
self._update_fp32_params_by_new_state()
self._move_new_state_to_right_device()
def _sync_hdo_param_groups_to_sub_optimizers(self):
"""Sync HDO new param_groups attribute (e.g. lr, wd, etc.) to sub-optimizers."""
param_in_param_group_index = {}
for i, group in enumerate(self.param_groups):
for p_id, param in enumerate(group["params"]):
inner_param = self.param_to_inner_param[param]
param_in_param_group_index[inner_param] = (i, p_id)
for optimizer in self.sub_optimizers:
new_param_groups = []
for group in optimizer.param_groups:
new_group = group.copy()
# After syncing the sub-optimizer's last update, propagate the HDO's
# updated param_groups attributes back to the sub-optimizer.
assert len(group["params"]) > 0, "param_groups should not be empty"
group_id, _ = param_in_param_group_index[group["params"][0]]
update_group_attrs = self.param_groups[group_id].copy()
del update_group_attrs["params"]
new_group.update(update_group_attrs)
new_param_groups.append(new_group)
optimizer.param_groups = new_param_groups
def _move_new_state_to_right_device(self):
for optimizer in self.sub_optimizers:
for param, state in optimizer.state.items():
for k, v in state.items():
if not isinstance(v, torch.Tensor):
continue
orig_param = self.inner_param_to_orig_param.get(param, param)
if isinstance(optimizer, self.defaults["cpu_optimizer_cls"]):
self.state[orig_param][k] = state[k] = v.to("cpu")
else:
self.state[orig_param][k] = state[k] = v.to("cuda")
def _update_fp32_params_by_new_state(self):
if not self.param_update_in_fp32:
return
for param, v in self.state.items():
fp32_param = self.param_to_fp32_param[param]
fp32_param.data.copy_(v["master_param"])
def _register_load_state_dict_hooks(self):
def pre_load_state_dict_hook(self, state_dict):
"""
Pre-load state dictionary hook to prevent loss of precision in
mixed-precision training.
When loading a state dictionary with `Optimizer.load_state_dict`,
optimizer states are cast to the parameter dtype (`bfloat16`/`float16`),
potentially losing precision. This hook temporarily replaces the
parameters with their `float32` copies to mitigate this issue.
Args:
state_dict (dict): The state dictionary to be loaded.
Returns:
dict: The modified state dictionary with `float32` parameters.
"""
if not self.param_update_in_fp32:
return state_dict
new_state = {}
for param, v in self.state.items():
param = self.param_to_fp32_param.get(param, param)
new_state[param] = v
self.state = new_state
for group in self.param_groups:
for i, param in enumerate(group["params"]):
group["params"][i] = self.param_to_fp32_param.get(param, param)
return state_dict
self.register_load_state_dict_pre_hook(pre_load_state_dict_hook)
def post_load_state_dict_hook(self):
# 1. Replace the temporarily replaced fp32 parameters back. Please
# refer to the documentation in `pre_load_state_dict_hook`.
if self.param_update_in_fp32:
new_state = {}
for param, v in self.state.items():
orig_param = self.fp32_param_to_orig_param.get(param, param)
new_state[orig_param] = v
self.state = new_state
for group in self.param_groups:
for i, param in enumerate(group["params"]):
group["params"][i] = self.fp32_param_to_orig_param.get(param, param)
# 2. After loading state_dict, the parameters may change, and we need to
# reinitialize the sub-optimizers to regenerate the new parameters and
# cpu copy pairs.
self._init_sub_optimizers()
self._sync_hdo_param_groups_to_sub_optimizers()
self._sync_hdo_state_to_sub_optimizers()
self.register_load_state_dict_post_hook(post_load_state_dict_hook)
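# Round-trip sketch (illustrative): for a bf16 parameter `p` with fp32 copy
# `p32 = param_to_fp32_param[p]`, the pre-hook re-keys `state` and
# `param_groups` from `p` to `p32` so that `load_state_dict` casts the
# incoming state against an fp32 tensor (no precision loss); the post-hook
# re-keys everything back to `p` and re-initializes the sub-optimizers so
# the CPU copies and fp32 copies are rebuilt from the loaded values.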
def zero_grad(self, set_to_none: bool = True):
"""
Zero the gradients of all parameters, or set them to None if `set_to_none` is True.
"""
super(HybridDeviceOptimizer, self).zero_grad(set_to_none)
for group in self.param_groups:
for param in group["params"]:
if hasattr(param, "decoupled_grad"):
if set_to_none:
param.decoupled_grad = None
else:
param.decoupled_grad.zero_()
def dummy_step(self):
"""
Run a dummy optimizer step to materialize any lazily created entries in
`optimizer.state`. This is needed before in-place checkpoint loading
(e.g. loading a torch distributed checkpoint), which expects the state
tensors to already exist.
"""
for group in self.param_groups:
for param in group["params"]:
param.grad = torch.randn_like(param)
self.step()
self.zero_grad()
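# Usage sketch (illustrative, assumed call pattern): materialize optimizer
# states before an in-place distributed-checkpoint load, e.g.
#
#     optimizer = HybridDeviceOptimizer(params=param_groups, **defaults)
#     optimizer.dummy_step()      # allocates state (e.g. Adam's exp_avg)
#     optimizer.load_state_dict(checkpoint_state_dict)
#
# The spurious update applied by the random gradients is harmless, because
# the subsequent checkpoint load overwrites both parameters and states.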
@property
def sub_optimizers(self):
"""
Return the list of sub-optimizers.
"""
if self.gpu_optimizer is not None:
return self.cpu_optimizers + [self.gpu_optimizer]
return self.cpu_optimizers
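# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the public API): the
# keyword names below mirror the attributes referenced in the class
# (offload_fraction, cpu_optimizer_cls, gpu_optimizer_cls,
# overlap_cpu_optimizer_d2h_h2d, param_update_in_fp32) and may not match the
# real constructor signature exactly.
def _example_hybrid_device_optimizer_usage(model: torch.nn.Module):
    """Sketch: build a HybridDeviceOptimizer and run one update step."""
    optimizer = HybridDeviceOptimizer(
        params=[{"params": list(model.parameters()), "lr": 1e-4}],
        offload_fraction=0.5,                 # offload ~half of the states to CPU
        cpu_optimizer_cls=torch.optim.AdamW,  # host-side sub-optimizer
        gpu_optimizer_cls=torch.optim.AdamW,  # device-side sub-optimizer
        overlap_cpu_optimizer_d2h_h2d=True,   # overlap copies with CPU updates
        param_update_in_fp32=True,            # keep fp32 master copies
    )
    # ... forward/backward pass populates parameter gradients ...
    optimizer.step()
    optimizer.zero_grad()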
......@@ -21,6 +21,8 @@ except ImportError:
HAVE_APEX_OR_TE = False
from megatron.core.optimizer.cpu_offloading import HybridDeviceOptimizer
from .. import tensor_parallel
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..dist_checkpointing import ShardedTensor
......@@ -384,6 +386,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# When using precision-aware optimizer, main params are held by FusedAdam.
shard_main_param = None
# Store handle to main_param.
model_param.main_param = shard_main_param
model_param.main_param_sharded = True
# Add to group.
model_float16_params_this_group.append(model_param)
shard_float16_params_this_group.append(shard_model_param)
......@@ -438,7 +444,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
model_chunks: List[MegatronModule],
per_model_buffers: Dict[int, List[_ParamAndGradBuffer]],
data_parallel_group: torch.distributed.ProcessGroup,
data_parallel_group_gloo: torch.distributed.ProcessGroup,
data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup],
data_parallel_group_idx: int,
distributed_optimizer_instance_id: int,
):
......@@ -482,10 +488,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.ddp_config = self.model_chunks[0].ddp_config
for model_chunk in self.model_chunks:
assert self.ddp_config == model_chunk.ddp_config
self.distributed_optimizer_instance_id = distributed_optimizer_instance_id
assert (
isinstance(optimizer, Adam) or optimizer is None
), "Only Adam currently supported, due to checkpointing requirements."
assert isinstance(optimizer, (Adam, HybridDeviceOptimizer)) or optimizer is None, (
"Only Adam and HybridDeviceOptimizer currently supported, "
"due to checkpointing requirements."
)
# when freezing sub-models we have no real optimizer
# but still need a stub DistributedOptimizer class
......@@ -493,6 +501,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.is_stub_optimizer = True
return
self.is_stub_optimizer = False
if self.ddp_config.use_custom_fsdp:
return
# Model grad buffer ranges.
assert per_model_buffers is not None, "per_model_buffers must be provided"
self.buffers = list(itertools.chain(*per_model_buffers.values()))
......@@ -500,7 +512,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.data_parallel_group = data_parallel_group
self.data_parallel_group_gloo = data_parallel_group_gloo
self.data_parallel_group_idx = data_parallel_group_idx
self.distributed_optimizer_instance_id = distributed_optimizer_instance_id
self.gbuf_idx_to_model_idx_map = {}
gbuf_idx = 0
......@@ -535,6 +546,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.gbuf_ranges.append(self._build_gbuf_range_map(buffer))
self.model_param_gbuf_map = self._build_model_param_gbuf_map(self.gbuf_ranges)
# Add main_param field to each parameter. We will use this fp32 copy to compute
# the param norm.
# For parameters with optimizer state on this rank, None will be overwritten by
# the corresponding sharded main_param tensor.
for param_group in self.optimizer.param_groups:
# For all the parameters in this group.
for param in param_group['params']:
if param.requires_grad:
# fp32 copy only needed for 16-bit parameters.
if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
param.main_param = None
param.main_param_sharded = True
# Optimizer ranges.
(self.model_param_group_index_map, self.opt_group_ranges) = (
self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges)
......@@ -551,14 +575,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges, config
)
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
if isinstance(self.optimizer, HybridDeviceOptimizer):
self.optimizer = HybridDeviceOptimizer(
params=[g["orig_group"] for g in self.opt_group_ranges], **self.optimizer.defaults
)
else:
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())
self.is_stub_optimizer = False
def _get_model_param_range_map(self, param: torch.nn.Parameter):
"""
Given a model param, get the index sub-range of the param that this
......@@ -593,6 +617,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
steps = list(set([s["step"].item() for s in inner_state_dict["state"].values()]))
assert len(steps) == 1
step = steps[0]
elif isinstance(self.optimizer, HybridDeviceOptimizer):
step = None
for optimizer in self.optimizer.sub_optimizers:
if isinstance(optimizer, torch.optim.AdamW):
if len(optimizer.state) == 0:
continue
steps = list(set([s["step"].item() for s in optimizer.state.values()]))
assert len(steps) == 1, f"steps: {steps}"
step = steps[0]
break
# Optimizer state (do not store parameter state here).
state_dict['optimizer'] = {k: v for k, v in inner_state_dict.items() if k != "state"}
......@@ -601,6 +635,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
if not HAVE_APEX_OR_TE:
# Native PyTorch param group requires step (i.e., iteration).
param_group["step"] = step
elif isinstance(self.optimizer, HybridDeviceOptimizer) and step is not None:
param_group["step"] = int(step)
# Grad scaler state.
if self.grad_scaler:
......@@ -637,6 +673,23 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
- state_order : The index of a parameter within the shared parameter
list.
"""
if len(self.optimizer.state) == 0:
if isinstance(self.optimizer, HybridDeviceOptimizer):
self.optimizer.dummy_step()
elif self.ddp_config.use_custom_fsdp:
# Initialize optimizer states with dummy values so that the state
# tensors exist; they will be replaced in place when the distributed
# checkpoint is loaded.
for group in self.optimizer.param_groups:
for param in group["params"]:
if param.numel() == 0:
# Avoid FusedAdam errors on empty tensor input.
continue
param.grad = torch.randn_like(param)
self.optimizer.step()
self.optimizer.zero_grad()
# Get the Torch optimizer's state dict.
# - This 'inner' optimizer at this point is unallocated, and only
......@@ -699,6 +752,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for s in state_dict_state.values():
# Native PyTorch state dict requires step (i.e., iteration).
s["step"] = step
elif isinstance(self.optimizer, HybridDeviceOptimizer):
# Handle the torch AdamW special case: unlike FusedAdam, torch AdamW
# keeps an extra per-parameter optimizer state "step".
steps = list(
set([g["step"] for g in state_dict["optimizer"]["param_groups"] if "step" in g])
)
if len(steps) != 0:
assert len(steps) == 1, f"steps: {steps}"
step = torch.tensor(steps[0], dtype=torch.float32, device="cpu")
for v in self.optimizer.state.values():
v["step"] = step.detach().clone()
# Optimizer.
self.optimizer.load_state_dict(
......@@ -725,6 +789,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
assert 'param_state_sharding_type' in state_dict, state_dict.keys()
param_state = state_dict['param_state']
sharding_type = state_dict['param_state_sharding_type']
if self.ddp_config.use_custom_fsdp:
assert (
sharding_type == "fully_sharded_model_space"
), "Only fully sharded model space is supported"
logger.info(f'Loading distributed optimizer sharded state of type {sharding_type}')
if sharding_type == 'dp_zero_gather_scatter':
self.load_parameter_state_from_dp_zero(param_state)
......@@ -746,11 +814,26 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
"exp_avg_sq": torch.Tensor
}
"""
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
if model_param not in pg_buffer.param_to_name:
continue
param_name = pg_buffer.param_to_name[model_param]
assert param_name is not None, "main_param not found for model_param"
main_param = dict(pg_buffer.optimizer_named_parameters)[param_name]
return {"param": main_param, **self.optimizer.state[main_param]}
group_index, group_order = self.model_param_group_index_map[model_param]
if self.config.use_precision_aware_optimizer:
sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
tensors = {}
for k in self.optimizer.state[sharded_model_param]:
if isinstance(self.optimizer, HybridDeviceOptimizer):
tensors[k] = self.optimizer.state[sharded_model_param][k]
continue
tensors[k] = self.optimizer.get_unscaled_state(sharded_model_param, k)
tensors["param"] = tensors.pop("master_param")
else:
......@@ -769,10 +852,32 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
"exp_avg_sq": torch.Tensor
}
"""
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
if model_param not in pg_buffer.param_to_name:
continue
param_name = pg_buffer.param_to_name[model_param]
assert param_name is not None, "parameter name not found for model_param"
main_param = dict(pg_buffer.optimizer_named_parameters)[param_name]
for key in tensors:
if key == "param":
main_param.copy_(tensors[key])
else:
self.optimizer.state[main_param][key] = tensors[key]
return
group_index, group_order = self.model_param_group_index_map[model_param]
if self.config.use_precision_aware_optimizer:
sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
for k, v in tensors.items():
if isinstance(self.optimizer, HybridDeviceOptimizer):
if k == "param":
k = "master_param"
self.optimizer.state[sharded_model_param][k] = v
continue
if k == "param":
self.optimizer.set_scaled_state(sharded_model_param, "master_param", v)
else:
......@@ -829,8 +934,35 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
- Gather contiguous buffers on DP rank 0 and concatenate to world
buffers.
"""
if self.ddp_config.use_custom_fsdp:
state = {"buckets_coalesced": True}
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
for group_id, group in enumerate(pg_buffer.parameter_groups):
this_group_state = {}
mbuf = group.master_weight_buffer
for item_id, _ in enumerate(group.params):
main_param = mbuf.get_item(item_id)
optim_state = self.optimizer.state[main_param]
object_list = [None] * mbuf.dp_world_size
torch.distributed.all_gather_object(
object_list, optim_state, group=mbuf.data_parallel_group
)
for rank, obj in enumerate(object_list):
for name, value in obj.items():
assert torch.is_tensor(value), f"Expected tensor, got {type(value)}"
this_group_state.setdefault(name, []).append(value)
for name, values in this_group_state.items():
this_group_state[name] = torch.cat(values).cpu()
state[f"group_{group_id}"] = this_group_state
return state
# Data parallelism variables.
assert self.data_parallel_group_gloo is not None
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
......@@ -956,6 +1088,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
' Please switch to `fully_sharded_model_space`.'
)
if self.ddp_config.use_custom_fsdp:
assert sharding_type == 'fully_sharded_model_space', (
f'For FSDP, only `fully_sharded_model_space` is supported. ' f'Got: {sharding_type}'
)
state_dict = self.state_dict()
if sharding_type != 'fully_sharded_model_space':
# State dict differs between different model parallel groups
......@@ -975,7 +1112,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# which conditionally skips re-allocating the optimizer's state if
# already initialized, which in turn reduces memory fragmentation.
self.load_state_dict(self.state_dict())
if sharding_type == 'fully_sharded_bucket_space':
param_state = self.sharded_param_state_fs_bucket_space(
model_sharded_state_dict, is_loading
......@@ -1017,7 +1153,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Fixed TPxPP. Save on DP rank 0 only
param_state = ShardedObject(
f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.param_state',
param_state_data,
param_state_data, # pylint: disable=E0606
(1,),
(0,),
)
......@@ -1164,17 +1300,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Not stored in the checkpoint, used only to identify params in
# `sharded_param_state_fs_model_space`.
param_idx = 0
for gbuf_range_maps in self.gbuf_ranges:
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for model_param, param_range_map in gbuf_range_map["param_map"].items():
param_range = param_range_map['param']
def _get_param_state_sharded_tensors(model_param, item_slice):
# Main param & optimizer states.
tensors = self._get_main_param_and_optimizer_states(model_param)
tensors["fp32_param"] = tensors.pop("param")
# Match optimizer parameter with model ShardedTensor (or
# ShardedTensorFactory).
if self.ddp_config.use_custom_fsdp:
model_param = getattr(model_param, "fully_shard_param_local_shard", model_param)
try:
sharded_metadata = param_to_sharded_metadata[model_param]
except KeyError as e:
......@@ -1186,25 +1320,61 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
assert (
len(sharded_metadata.replica_id) == 3
), f'Expected replica_id format (PP, TP, DP), got: {sharded_metadata}'
replica_id = (
*sharded_metadata.replica_id[:2],
self.distributed_optimizer_instance_id,
)
replica_id = (*sharded_metadata.replica_id[:2], self.distributed_optimizer_instance_id)
# Instantiate ShardedTensor (or ShardedTensorFactory) for optimizer
# params.
for state_key, state_ten in tensors.items():
if state_key == 'step':
# Note that `step` is a 0-dim tensor, unlike the other
# states, which have the same shape as the parameter.
# The `step` state is handled separately and is read
# from param_groups.
continue
replace_kwargs = dict(
key=f'{prefix}.{state_key}.{sharded_metadata.key}',
data=state_ten,
dtype=state_ten.dtype,
flattened_range=slice(param_range.start, param_range.end),
flattened_range=item_slice,
replica_id=replica_id,
)
if isinstance(sharded_metadata, ShardedTensorFactory):
replace_kwargs.pop('dtype')
tensors[state_key] = replace(sharded_metadata, **replace_kwargs)
tensors[state_key].validate_metadata_integrity()
return tensors
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
for pg in pg_buffer.parameter_groups:
gbuf = pg.main_grad_buffer
if gbuf is None:
continue
for model_param in gbuf.params:
item_id = gbuf.param_idx[model_param]
param_name = pg_buffer.param_to_name[model_param]
item_slice = gbuf._get_item_slice_in_shard(item_id)
if item_slice[0] == item_slice[1]:
# This param is not in this shard.
continue
state[param_name] = _get_param_state_sharded_tensors(
model_param, slice(*item_slice)
)
return state
# Not stored in the checkpoint, used only to identify params in
# `sharded_param_state_fs_model_space`.
param_idx = 0
for gbuf_range_maps in self.gbuf_ranges:
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for model_param, param_range_map in gbuf_range_map["param_map"].items():
param_range = param_range_map['param']
tensors = _get_param_state_sharded_tensors(
model_param, slice(param_range.start, param_range.end)
)
state[param_idx] = tensors
param_idx += 1
return state
......@@ -1254,6 +1424,23 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
Inverse of the `sharded_param_state_fs_model_space` method.
"""
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
for model_param in pg_buffer.params:
param_name = pg_buffer.param_to_name[model_param]
if param_name not in state_dict:
continue
src_tensors = {}
for k, v in state_dict[param_name].items():
if k == "fp32_param":
src_tensors["param"] = v
else:
src_tensors[k] = v
self._set_main_param_and_optimizer_states(model_param, src_tensors)
return
param_idx = 0 # matching order with `sharded_param_state_fs_model_space`
for gbuf_range_maps in self.gbuf_ranges:
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
......@@ -1261,12 +1448,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for model_param, param_range_map in gbuf_range_map["param_map"].items():
src_tensors = {}
for k, v in state_dict[param_idx].items():
if k == "step":
# Handle torch Adam "step" state separately.
continue
if k == "fp32_param":
src_tensors["param"] = v
else:
src_tensors[k] = v
self._set_main_param_and_optimizer_states(model_param, src_tensors)
param_idx += 1
if isinstance(self.optimizer, HybridDeviceOptimizer):
self.optimizer._sync_hdo_state_to_sub_optimizers()
@classmethod
def _update_legacy_world_tensors(cls, old_tensors, new_numels):
......@@ -1304,6 +1496,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
"""
# Data parallelism variables.
assert self.data_parallel_group_gloo is not None
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
......@@ -1419,6 +1612,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
return self.load_parameter_state_from_dp_zero_legacy(state_dict)
# Data parallelism variables.
assert self.data_parallel_group_gloo is not None
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
......@@ -1663,6 +1857,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
Args:
set_to_none (bool): if true, set grads to None.
"""
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
model_chunk.zero_grad_buffer()
return
if self.is_stub_optimizer:
return
total_groups = [
......@@ -1725,6 +1924,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
if self.is_stub_optimizer:
return
if self.ddp_config.use_custom_fsdp:
return
# Utility method for copying group grads.
def copy_group_grads(model_groups, shard_main_groups):
for model_group, shard_main_group in zip(model_groups, shard_main_groups):
......@@ -1765,6 +1967,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
if self.is_stub_optimizer:
return
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
model_chunk.param_and_grad_buffer.copy_main_weights_to_model_weights()
return
# Utility method for copying group params.
def copy_group_params(shard_main_groups, model_groups):
for shard_main_group, model_group in zip(shard_main_groups, model_groups):
......@@ -1820,6 +2027,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
the model params. This copy does not make use of the grad buffer as
an intermediary.
"""
if isinstance(self.optimizer, HybridDeviceOptimizer):
return
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
model_chunk.param_and_grad_buffer.copy_model_weights_to_main_weights()
return
# Utility method for copying group params.
def copy_group_params(model_groups, shard_main_groups):
......@@ -1849,13 +2063,22 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
"""
if self.is_stub_optimizer:
return
if self.ddp_config.use_custom_fsdp:
buffers = []
for m in self.model_chunks:
for group in m.param_and_grad_buffer.parameter_groups:
mbuf = group.model_weight_buffer
buffers.append(mbuf)
else:
buffers = self.buffers
# Iterate over all parameters inside this optimizer to find FP8 parameters.
for buffer in buffers:
amaxes = []
scales = []
scale_invs = []
# Iterate over all parameters inside this optimizer to find FP8 parameters.
for buffer in self.buffers:
for bucket in buffer.buckets:
for param in bucket.params_list:
for param in buffer.params:
if is_float8tensor(param):
fp8_meta = param._fp8_meta['scaling_fwd']
fp8_meta_index = param._fp8_meta_index
......@@ -1870,7 +2093,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
# Update scaling factors.
packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device)
packed_scales = torch.empty(
len(scales), dtype=torch.float32, device=scales[0].device
)
packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))]
_multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf)
torch.reciprocal(packed_scales, out=packed_scales)
......@@ -1878,11 +2103,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Reduce amaxes.
# Note: Assume each param has a separate amax.
packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
packed_amaxes = torch.empty(
len(amaxes), dtype=torch.float32, device=amaxes[0].device
)
packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))]
_multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf)
torch.distributed.all_reduce(
packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.data_parallel_group
packed_amaxes,
op=torch.distributed.ReduceOp.MAX,
group=buffer.data_parallel_group,
)
_multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf)
......@@ -1900,6 +2129,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
timers = self.config.timers
if timers is not None:
timers('params-all-gather', log_level=1).start(barrier=self.config.barrier_with_L1_time)
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
model_chunk.start_param_sync()
else:
# If not overlapping all-gather for parameters, launch synchronous all-gather
# communication calls here. If overlapping all-gather for parameters, the following
# the first all-gather is launched asynchronously in the next optimizer.zero_grad()
......
......@@ -298,7 +298,7 @@ class MegatronOptimizer(ABC):
"""
@staticmethod
def _extract_common_per_param_step(state_dict) -> Union[int, torch.Tensor]:
def _extract_common_per_param_step(state_dict) -> Union[int, torch.Tensor, None]:
common_step = None
for param_idx, param_state in state_dict['state'].items():
param_step = param_state.get('step', None)
......@@ -374,6 +374,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
return self.grad_scaler.scale
def reload_model_params(self):
if self.param_groups:
self._copy_model_params_to_main_params()
def _unscale_main_grads_and_check_for_nan(self):
......@@ -555,6 +556,9 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
# Store handle to main_param.
param.main_param = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
......@@ -708,6 +712,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
)
# save step as a shared step among all parameters. Separate per-parameter
# steps are not supported
if step:
state_dict['optimizer']['state']['common_step'] = step
return state_dict
......@@ -884,6 +889,7 @@ class FP32Optimizer(MegatronOptimizer):
optim_state_to_sharding_state(state_dict, id_to_sharded_param_map, exclude_keys="step")
# save step as a shared step among all parameters. Separate per-parameter
# steps are not supported
if step:
state_dict['state']['common_step'] = step
return state_dict
......
......@@ -114,6 +114,34 @@ class OptimizerConfig:
overlap_param_gather_with_optimizer_step: bool = False
"""If true, overlap param all-gather of first bucket with optimizer step."""
#######################
# Optimizer Offload
#######################
optimizer_cpu_offload: bool = False
"""If True, offload optimizer states tensor and compute to CPU."""
optimizer_offload_fraction: float = 0.0
"""Specifies the fraction of optimizer states to offload from GPU memory to CPU."""
use_torch_optimizer_for_cpu_offload: bool = False
"""If True, use torch.optim.Optimizer for CPU offload."""
overlap_cpu_optimizer_d2h_h2d: bool = False
"""
If True, overlap the CPU optimizer update with the device-to-host (D2H)
and host-to-device (H2D) transfers, so gradient and parameter copies run
while other parameters are being updated on the CPU. This reduces idle
time during data movement.
"""
pin_cpu_grads: bool = True
"""If True, pin the optimizer gradients to CPU memory."""
pin_cpu_params: bool = True
"""If True, pin the optimizer parameters to CPU memory."""
################
# Miscellaneous
################
......@@ -142,8 +170,11 @@ class OptimizerConfig:
self.use_distributed_optimizer
), '--use-precision-aware-optimizer only supported with distributed optimizer'
# Only the FusedAdam in TE supports --use-precision-aware-optimizer.
# Only the FusedAdam in TE and the HybridDeviceOptimizer support
# --use-precision-aware-optimizer.
# TODO: Remove this check when apex's FusedAdam is no longer used.
if self.optimizer_cpu_offload:
return
try:
import inspect
......
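# Illustrative configuration sketch (an assumption, not part of the diff
# above): enabling optimizer CPU offload through the new OptimizerConfig
# fields documented in this file. Field names come from the dataclass; the
# surrounding setup (get_megatron_optimizer, model_chunks) is only sketched.
#
#     from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
#
#     config = OptimizerConfig(
#         optimizer='adam',
#         lr=1e-4,
#         use_distributed_optimizer=True,
#         optimizer_cpu_offload=True,           # route to HybridDeviceOptimizer
#         optimizer_offload_fraction=0.5,       # offload ~half of the states
#         overlap_cpu_optimizer_d2h_h2d=True,   # overlap copies with CPU updates
#         pin_cpu_grads=True,
#         pin_cpu_params=True,
#     )
#     optimizer = get_megatron_optimizer(config, model_chunks)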
......@@ -2,7 +2,7 @@
MAJOR = 0
MINOR = 10
MINOR = 12
PATCH = 0
PRE_RELEASE = 'rc0'
......