Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed with stage
......@@ -11,8 +11,9 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to
from megatron.core.models.gpt import GPTModel
from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_num_image_embeddings
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.models.vision.radio import RADIOViTModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import get_context_parallel_group, get_context_parallel_world_size
from megatron.core.parallel_state import get_context_parallel_rank, get_context_parallel_world_size
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
......@@ -20,12 +21,17 @@ from megatron.core.utils import log_single_rank
try:
import transformer_engine # pylint: disable=unused-import
from transformer_engine.pytorch.distributed import gather_along_first_dim
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.utils import is_te_min_version
HAVE_TE = True
try:
import transformer_engine_torch as tex
HAVE_TEX = True
except:
HAVE_TEX = False
except:
HAVE_TE = False
if get_context_parallel_world_size() > 1:
......@@ -36,6 +42,53 @@ IGNORE_INDEX = -100 # ID for labels that should be ignored.
# Image token index can be tokenizer dependent so the default value does not work in all cases.
DEFAULT_IMAGE_TOKEN_INDEX = -200
IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
class _get_data_on_this_cp_rank(torch.autograd.Function):
"""Performs sharding for Context Parallelism in THD format
In the forward pass, indices are selected for each CP rank and remaining tokens are dropped.
In the backward pass, this class takes care of managing gradients for dropped tokens on each
CP rank.
"""
@staticmethod
def forward(ctx, batch, packed_seq_params):
"""Context Parallelism forward support for THD format"""
cp_size = get_context_parallel_world_size()
cp_rank = get_context_parallel_rank()
for key, data in batch.items():
index = tex.thd_get_partitioned_indices(
packed_seq_params.cu_seqlens_q_padded, data.size(1), cp_size, cp_rank
)
if key == "combined_embeddings":
ctx.decoder_emb_index = index
ctx.decoder_emb_seqlen = data.size(1)
batch[key] = data.index_select(1, index)
batch[key].requires_grad = data.requires_grad
return batch
@staticmethod
def backward(ctx, grad_out, grad_label, grad_loss):
"""Context Parallelism backward support for THD format"""
seqlen = ctx.decoder_emb_seqlen
index = ctx.decoder_emb_index
assert grad_out.size(1) == index.size(
0
), f"Shape mismatch in incoming gradient {grad_out.shape} and \
index from THD CP sharding {index.shape}"
grad_in = torch.zeros(
grad_out.size(0),
seqlen,
*grad_out.size()[2:],
dtype=grad_out.dtype,
device=grad_out.device,
)
grad_in[:, ctx.decoder_emb_index, :] = grad_out
return (grad_in, None, None, None)
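# Hedged usage sketch mirroring the call site later in this file: every tensor in
# `batch` keeps its sequence on dim=1 and is partitioned with the indices returned
# by tex.thd_get_partitioned_indices for this CP rank, e.g.
#   batch = {"combined_embeddings": emb, "new_labels": labels, "new_loss_mask": mask}
#   batch = _get_data_on_this_cp_rank.apply(batch, packed_seq_params)
#   # each entry now holds this rank's S/CP tokens along dim=1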
# Note: This is under development and may be missing features.
......@@ -57,6 +110,7 @@ class LLaVAModel(MegatronModule):
missing when loading a checkpoint. Default False.
parallel_output (bool): Keep outputs split across tensor parallel ranks.
This is typically True for training and False for inference.
share_embeddings_and_output_weights (bool): Input embedding and output layer share weights.
language_position_embedding_type (str): Language model position embedding type.
language_rotary_percent (float): RoPE percent. Defaults to 1.0.
pre_process (bool): Include embedding layer in the decoder (used with pipeline parallel).
......@@ -70,6 +124,7 @@ class LLaVAModel(MegatronModule):
patch_dim (int): The size of each image patch side.
language_rotary_base (int): RoPE base.
language_rope_scaling (bool): Toggle RoPE scaling.
language_rope_scaling_factor (float): RoPE scaling factor. Defaults to 8.
image_token_index (int): Token ID for image token such as <image>.
pixel_shuffle (bool): Enable pixel shuffle.
tile_tags (list): Optional tile tags.
......@@ -89,6 +144,7 @@ class LLaVAModel(MegatronModule):
vision_projection_type: str = "mlp",
allow_missing_vision_projection_checkpoint: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
language_position_embedding_type: str = 'learned_absolute',
language_rotary_percent: float = 1.0,
pre_process: bool = True,
......@@ -100,6 +156,7 @@ class LLaVAModel(MegatronModule):
patch_dim: int = 14,
language_rotary_base: int = 10000,
language_rope_scaling: bool = False,
language_rope_scaling_factor: float = 8.0,
image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX,
pixel_shuffle: bool = False,
tile_tags: Optional[list] = None,
......@@ -142,51 +199,102 @@ class LLaVAModel(MegatronModule):
# This attribute is needed to check if an all-reduce is required
# on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.
self.share_embeddings_and_output_weights = False
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
if self.add_decoder:
self.language_model = GPTModel(
config=language_transformer_config,
transformer_layer_spec=language_transformer_layer_spec,
vocab_size=language_vocab_size,
max_sequence_length=language_max_sequence_length,
parallel_output=parallel_output,
position_embedding_type=language_position_embedding_type,
rotary_percent=language_rotary_percent,
pre_process=self.pre_process,
post_process=self.post_process,
rotary_base=language_rotary_base,
rope_scaling=language_rope_scaling,
scatter_embedding_sequence_parallel=False,
)
self.share_embeddings_and_output_weights = (
self.language_model.share_embeddings_and_output_weights
)
if hasattr(
language_transformer_config, "language_model_type"
) and language_transformer_config.language_model_type.startswith("huggingface"):
from megatron.core.models.huggingface.module import build_hf_model
self.language_model = build_hf_model(language_transformer_config)
else:
self.language_model = GPTModel(
config=language_transformer_config,
transformer_layer_spec=language_transformer_layer_spec,
vocab_size=language_vocab_size,
max_sequence_length=language_max_sequence_length,
parallel_output=parallel_output,
position_embedding_type=language_position_embedding_type,
rotary_percent=language_rotary_percent,
pre_process=self.pre_process,
post_process=self.post_process,
rotary_base=language_rotary_base,
rope_scaling=language_rope_scaling,
scatter_embedding_sequence_parallel=False,
)
self.share_embeddings_and_output_weights = (
self.language_model.share_embeddings_and_output_weights
)
self._language_max_sequence_length = language_max_sequence_length
self._language_is_pipeline_parallel = (
language_transformer_config.pipeline_model_parallel_size > 1
)
# Newer Transformer Engine versions add _extra_state keys in state_dict when using FP8.
# Older models may not have _extra_state and can be ignored.
self.language_model.register_load_state_dict_post_hook(
_load_state_dict_hook_ignore_extra_state
)
class_token_len = 1
if self.add_encoder:
self._drop_vision_class_token = drop_vision_class_token
add_class_token = True
if vision_transformer_config.vision_model_type == "siglip":
class_token_len = 0
add_class_token = False
error_msg = (
"Siglip does not support vision class token, "
"set disable-vision-class-token to False."
if vision_transformer_config.vision_model_type.startswith(
("clip", "siglip", "internvit")
):
if vision_transformer_config.vision_model_type == "siglip":
class_token_len = 0
add_class_token = False
error_msg = (
"Siglip does not support vision class token, "
"set disable-vision-class-token to False."
)
assert not self._drop_vision_class_token, error_msg
self.vision_model = CLIPViTModel(
vision_transformer_config,
vision_transformer_layer_spec,
img_h=img_h,
img_w=img_w,
class_token_len=class_token_len,
patch_dim=patch_dim,
model_subtype=vision_transformer_config.vision_model_type,
add_class_token=add_class_token,
)
elif vision_transformer_config.vision_model_type in ("radio",):
# TODO: should refactor into model code itself?
class_token_len = 8
max_img_h = 2048
max_img_w = 2048
embedder_bias = False
use_mask_token = False
self.vision_model = RADIOViTModel(
vision_transformer_config,
vision_transformer_layer_spec,
img_h=img_h,
img_w=img_w,
max_img_h=max_img_h,
max_img_w=max_img_w,
class_token_len=class_token_len,
patch_dim=patch_dim,
add_class_token=add_class_token,
embedder_bias=embedder_bias,
use_mask_token=use_mask_token,
)
elif vision_transformer_config.vision_model_type.startswith("huggingface"):
from megatron.core.models.huggingface.module import build_hf_model
self.vision_model = build_hf_model(vision_transformer_config)
else:
raise ValueError(
"Vision model "
f"{vision_transformer_config.vision_model_type} is not "
"supported."
)
assert not self._drop_vision_class_token, error_msg
self.vision_model = CLIPViTModel(
vision_transformer_config,
vision_transformer_layer_spec,
img_h=img_h,
img_w=img_w,
class_token_len=class_token_len,
patch_dim=patch_dim,
model_subtype=vision_transformer_config.vision_model_type,
add_class_token=add_class_token,
self.vision_model.register_load_state_dict_post_hook(
_load_state_dict_hook_ignore_extra_state
)
vision_projection_input_size = vision_transformer_config.hidden_size
......@@ -212,7 +320,7 @@ class LLaVAModel(MegatronModule):
partial(_load_state_dict_hook_ignore_param_names, vision_projection_param_names)
)
self._img_seq_len = get_num_image_embeddings(
self.img_seq_len = get_num_image_embeddings(
img_h,
img_w,
patch_dim,
......@@ -286,7 +394,6 @@ class LLaVAModel(MegatronModule):
inference_params,
image_token_index,
num_image_tiles,
image_token_mask=None,
):
"""Preprocess input data before input to language model.
......@@ -334,11 +441,8 @@ class LLaVAModel(MegatronModule):
if use_inference_kv_cache:
return language_embeddings, loss_mask, labels
img_seq_len = self._img_seq_len
img_seq_len = self.img_seq_len
batch_size, text_seq_len = input_ids.shape
# input_ids seq len is expected to be sharded by CP size
if self.context_parallel_lm:
text_seq_len *= self.context_parallel_lm
has_labels = labels is not None
if has_labels:
......@@ -348,12 +452,7 @@ class LLaVAModel(MegatronModule):
# Create indices for new text and label positions.
with torch.no_grad():
if image_token_mask is None:
assert (
self.context_parallel_lm <= 1
), "image_token_mask cannot be inferred from input_ids if using \
Context Parallelism. Please provide in forward_step"
image_token_mask = input_ids == image_token_index
image_token_mask = input_ids == image_token_index
num_images_per_sample = torch.sum(image_token_mask, dim=-1)
# Number of tiles per sample.
......@@ -387,8 +486,10 @@ class LLaVAModel(MegatronModule):
new_position_ids = torch.cumsum((image_token_mask_lens + 1), dim=-1) - 1
text_position_ids = new_position_ids[batch_indices, non_image_indices]
label_batch_indices = None # dummy value to pass formatting
# Labels are shifted to left by one.
# So, shift text position ids and non-image indices to left by one.
label_batch_indices = None
if has_labels:
label_text_position_ids = text_position_ids - 1
valid_label_text_position_ids = label_text_position_ids >= 0
......@@ -433,9 +534,17 @@ class LLaVAModel(MegatronModule):
]
# Put image embeddings to image positions.
final_embedding[images_mask] = (
image_embeddings.permute(1, 0, 2).reshape(-1, embed_dim).contiguous()
)
# NOTE: FSDP can hang with text-only samples so we use a workaround to run a dummy image
# through the vision model and then zero-out the impact of the output here.
if num_image_tiles.shape[0] == 0 and image_embeddings.shape[0] > 0:
assert images_mask.sum() == 0 and getattr(
self.vision_model, "_is_fsdp_managed_module", False
), "expected FSDP and dummy image"
final_embedding[:1, :1, :1] += 0 * image_embeddings[:1, :1, :1]
else:
final_embedding[images_mask] = (
image_embeddings.permute(1, 0, 2).reshape(-1, embed_dim).contiguous()
)
# Create the final labels and loss mask (if this is the last language model stage).
final_labels, final_loss_mask = None, None
......@@ -488,7 +597,7 @@ class LLaVAModel(MegatronModule):
# Truncate if exceeding the language model's max sequence length.
if final_embedding.shape[1] > self._language_max_sequence_length:
final_embedding = final_embedding[:, : self._language_max_sequence_length]
# Transpose to [s,b,h] if not using CP because CP Sharding expects seq in dim=1
# Transpose to [s,b,h] only if not using CP because CP Sharding expects seq in dim=1
if self.context_parallel_lm == 1:
final_embedding = final_embedding.transpose(1, 0).contiguous()
......@@ -507,18 +616,12 @@ class LLaVAModel(MegatronModule):
"""Processes the input data for model parallelism support.
When using sequence parallelism (SP) or context parallelism (CP), the sequence is sharded
across different GPUs. This function helps ensure that the sharding is done correctly by
1. Calculates `padding_factor` which determines based on how many chunks we expect to shard
the sequence
2. Calculates and pads the inputs to necessary length to ensure equal sized chunks
3. Creates/Modifies PackedSeqParams which helps mask padded tokens during calculations
4. Performs any layout changes if necessary
5. Distributes the sequence across GPUs for SP and CP
across different GPUs. This function performs the sharding and distributes the sequence
across GPUs for SP and CP
Context Parallelism is a feature that helps improve memory efficiency for
long sequence training by distributing sequence across CP ranks.
It requires token length to be divisible by (CP size *2) to ensure proper load balance.
Please refer to `get_batch_on_this_cp_rank` function for more details.
Sequence Parallelism is a feature that helps improve memory efficiency for
long sequence training by distributing sequence across TP ranks.
......@@ -531,143 +634,62 @@ class LLaVAModel(MegatronModule):
packed_seq_params (PackedSeqParams): Dict with padded token information.
"""
# combined_embeddings - `s,b,h` if not using CP, `b,s,h` if using CP
batch_size = (
combined_embeddings.shape[0]
if self.context_parallel_lm > 1
else combined_embeddings.shape[1]
)
seq_dim = 1 if self.context_parallel_lm > 1 else 0
padding_mask_type = 'padding' in str(
self.language_model.transformer_layer_spec.submodules.self_attention.params.get(
'attn_mask_type', ''
)
)
if self.sequence_parallel_lm and self.tp_comm_overlap_lm:
assert (
combined_embeddings.shape[seq_dim] == self._language_max_sequence_length
) or padding_mask_type, f"TP Comm overlap either requires Vision+Text token length \
== language_max_sequence_length or mask type to be set to padding/padding_causal"
if padding_mask_type:
# Calculate the padded sequence length needed to support SP and CP
# SP and CP are used to distributed the sequence across GPUs to improve
# memory efficiency and enable very long context training.
# To distribute workload equally, we need to ensure that the sequence is
# divisible by the appropriate padding factor calculated below.
padding_factor = None
padded_seq_len = None
mp_padding_needed = 0
# No pre or post processing needed with PP middle chunks.
if not self.pre_process and not self.post_process:
return combined_embeddings, new_labels, new_loss_mask, packed_seq_params
shard_factor = seq_dim = None
if self.pre_process:
if self.context_parallel_lm > 1 and self.sequence_parallel_lm:
padding_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2
shard_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2
seq_dim = 1
elif self.context_parallel_lm > 1:
padding_factor = self.context_parallel_lm * 2
shard_factor = self.context_parallel_lm * 2
seq_dim = 1
elif self.sequence_parallel_lm:
padding_factor = self.tensor_model_parallel_size_lm
padded_seq_len = int(
(combined_embeddings.shape[seq_dim] + (padding_factor - 1))
// padding_factor
* padding_factor
)
shard_factor = self.tensor_model_parallel_size_lm
seq_dim = 0
assert (
padded_seq_len <= self._language_max_sequence_length
), f"Sequence length after padding {padded_seq_len} for SP/CP has exceeded \
language_max_sequence_length. Ensure language_max_sequence_length is \
divisible by SP/CP factor: {padding_factor}"
combined_embeddings.shape[seq_dim] % shard_factor == 0
), f"Sequence length should be divisible by {shard_factor} for \
Sequence/Context parallelism"
if self.sequence_parallel_lm and self.tp_comm_overlap_lm:
# TP Comm overlap initializes the user buffer shape used for communication
# at the beginning of training run and the same shape is expected to be
# used throughout the training.
# Pad to language_max_sequence_length to use TP Comm overlap.
assert (
self._language_max_sequence_length % padding_factor == 0
), f"TP Comm overlap uses language_max_sequence_length \
which needs to be divisible by SP/CP factor {padding_factor}"
padded_seq_len = self._language_max_sequence_length
assert (
packed_seq_params is not None
), "Please provide PackedSeqParams dict when using SP or CP with padding"
valid_seqlens = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1]
valid_seq_len = max(valid_seqlens)
assert (
padded_seq_len >= valid_seq_len
), f"Padded Seq Len calculated for model parallelism: {padded_seq_len} \
is shorter than expected valid token len {valid_seq_len} provided."
mp_padding_needed = padded_seq_len - combined_embeddings.shape[seq_dim]
if mp_padding_needed > 0:
new_labels = torch.nn.functional.pad(
new_labels, (0, mp_padding_needed), value=IGNORE_INDEX
)
new_loss_mask = torch.nn.functional.pad(new_loss_mask, (0, mp_padding_needed))
if self.context_parallel_lm > 1:
combined_embeddings = torch.nn.functional.pad(
combined_embeddings, (0, 0, 0, mp_padding_needed)
)
else:
combined_embeddings = torch.nn.functional.pad(
combined_embeddings, (0, 0, 0, 0, 0, mp_padding_needed)
)
# Update PackedSeqParams if padding needed beyond user provided PackedSeqParams
packed_seq_params.max_seqlen_q = padded_seq_len
packed_seq_params.max_seqlen_kv = padded_seq_len
cu_seqlens_padded = None
# We need cu_seqlens_q_padded/cu_seqlens_kv_padded when doing
# CP+Padding to support accurate Attention with THD format.
if self.context_parallel_lm > 1:
cu_seqlens_padded = torch.arange(
0,
(batch_size + 1) * (padded_seq_len),
step=(padded_seq_len),
dtype=torch.int32,
device=combined_embeddings.device,
)
packed_seq_params.cu_seqlens_q_padded = cu_seqlens_padded
packed_seq_params.cu_seqlens_kv_padded = cu_seqlens_padded
packed_seq_params.qkv_format = 'thd'
else:
packed_seq_params.qkv_format = 'sbhd'
combined_embeddings.shape[seq_dim] == self._language_max_sequence_length
), f"TP Comm overlap either requires Vision+Text token length \
== language_max_sequence_length"
if self.context_parallel_lm > 1:
batch = dict()
if self.pre_process:
batch["combined_embeddings"] = combined_embeddings
if self.post_process:
batch["new_labels"] = new_labels
batch["new_loss_mask"] = new_loss_mask
# Distribute sequence across CP ranks
from megatron.training.utils import get_batch_on_this_cp_rank
batch = get_batch_on_this_cp_rank(
{
"combined_embeddings": combined_embeddings,
"new_labels": new_labels,
"new_loss_mask": new_loss_mask,
}
)
if packed_seq_params is None or packed_seq_params.qkv_format == 'sbhd':
from megatron.training.utils import get_batch_on_this_cp_rank
combined_embeddings = batch["combined_embeddings"] # [B, S/CP, H]
new_labels = batch["new_labels"]
new_loss_mask = batch["new_loss_mask"]
if getattr(packed_seq_params, 'qkv_format', None) == 'thd':
# If PackedSeqParams requires THD format,
# reshape embedding from [B,S,H] to [T,1,H] where T=B*S
combined_embeddings = (
combined_embeddings.contiguous()
.view(combined_embeddings.shape[0] * combined_embeddings.shape[1], -1)
.unsqueeze(1)
)
new_labels = new_labels.view(new_labels.shape[0] * new_labels.shape[1]).unsqueeze(0)
new_loss_mask = new_loss_mask.view(
new_loss_mask.shape[0] * new_loss_mask.shape[1]
).unsqueeze(0)
batch = get_batch_on_this_cp_rank(batch)
else:
assert HAVE_TEX and is_te_min_version(
"1.10.0"
), "Please update Transformer Engine to >= 1.10 to use \
Context Parallel with THD format data"
batch = _get_data_on_this_cp_rank.apply(batch, packed_seq_params)
if self.pre_process:
combined_embeddings = batch["combined_embeddings"] # [B, S/CP, H]
combined_embeddings = combined_embeddings.transpose(
1, 0
).contiguous() # [B,S/CP,H] -> [S/CP,B,H]
if self.post_process:
new_labels = batch["new_labels"]
new_loss_mask = batch["new_loss_mask"]
if self.sequence_parallel_lm:
if self.sequence_parallel_lm and self.pre_process:
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
combined_embeddings
) # [S/(CP*TP),B,H]
......@@ -722,7 +744,6 @@ class LLaVAModel(MegatronModule):
num_image_tiles: Optional[List[int]] = None,
image_token_index: Optional[int] = None,
runtime_gather_output: Optional[bool] = None,
image_token_mask: Optional[torch.Tensor] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
) -> torch.Tensor:
"""Forward function of the LLaVA model.
......@@ -744,8 +765,6 @@ class LLaVAModel(MegatronModule):
arg in the constructor will be used.
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
image_token_mask (torch.Tensor): Tensor indicating the location of
image token index in input_ids.
packed_seq_params (PackedSeqParams): 1) If using sequence packing, must contain
subsample length information. 2) If using SP/CP with padding mask type,
must contain padded token information.
......@@ -817,18 +836,13 @@ class LLaVAModel(MegatronModule):
language_embeddings = self.language_model.embedding(
input_ids=input_ids_text, position_ids=position_ids
) # [text_seq_len, b, h_language]
# Gather the language embeddings back. We need the full embedding to insert
# image embeddings and then scatter again to avoid load imbalance.
if self.context_parallel_lm > 1:
cp_group = get_context_parallel_group()
language_embeddings, _ = gather_along_first_dim(language_embeddings, cp_group)
language_embeddings = language_embeddings.transpose(
1, 0
).contiguous() # [b, text_seq_len, h_language]
# Assume 1 tile per image if the number of tiles is not provided.
if num_image_tiles is None:
if num_image_tiles is None and images is not None:
num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device)
combined_embeddings, new_labels, new_loss_mask = self._preprocess_data(
......@@ -841,7 +855,6 @@ class LLaVAModel(MegatronModule):
inference_params,
image_token_index if image_token_index is not None else self.image_token_index,
num_image_tiles,
image_token_mask,
) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]
if self.context_parallel_lm > 1 or self.sequence_parallel_lm:
......@@ -889,6 +902,28 @@ def _load_state_dict_hook_ignore_param_names(
incompatible_keys.missing_keys.remove(param_name)
def _load_state_dict_hook_ignore_extra_state(
module: torch.nn.Module, incompatible_keys: namedtuple
):
"""Hook to ignore Transformer Engine _extra_state used for FP8.
This is for backwards-compatibility. Newer TE versions add _extra_state keys to the state dict,
while older models might not have those keys. Those keys can be ignored when not using FP8.
Args:
module (torch.nn.Module): The torch module this hook applies to. Required by the torch API.
incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys,
which collect the missing and unexpected keys, respectively.
"""
for name, keys in incompatible_keys._asdict().items():
for key in keys[::-1]:
if "extra_state" in key:
logging.getLogger(__name__).warning(
f"_extra_state key {key} being removed from {name}"
)
keys.remove(key)
# pylint: disable-next=line-too-long
# Based on https://github.com/OpenGVLab/InternVL/blob/c7c5af1a8930b4862afe8ed14672307082ef61fa/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py#L218
# Copyright (c) 2023 OpenGVLab.
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
from megatron.core.extensions.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
......@@ -6,7 +8,7 @@ from megatron.core.extensions.transformer_engine import (
TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
......@@ -27,15 +29,15 @@ except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def decoder_model_with_transformer_engine_default_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
"""LLava decoder TE spec (uses Transformer Engine components)."""
mlp = _get_mlp_module_spec(
mlp = get_mlp_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
......@@ -60,10 +62,10 @@ def decoder_model_with_transformer_engine_default_spec(
def decoder_model_with_local_default_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
"""LLava decoder local spec."""
mlp = _get_mlp_module_spec(
mlp = get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -201,6 +201,11 @@ def get_num_image_embeddings(
keep_class_token = False
elif vision_model_type in ("clip", "internvit"):
keep_class_token = not disable_vision_class_token
elif vision_model_type.startswith("radio"):
keep_class_token = not disable_vision_class_token
elif vision_model_type.startswith("huggingface"):
# TODO: Temp, what do we do in this situation?
keep_class_token = True
else:
raise ValueError(f"unsupported vision model: {vision_model_type}")
......
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
# RADIO reference code: https://github.com/NVlabs/RADIO
class RADIOViTModel(VisionModule):
"""RADIO ViT vision model.
Args:
transformer_config (TransformerConfig): Transformer config.
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.
ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre.
ln_post_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_post.
use_mask_token (bool, optional): Whether to use RADIO mask token. Default to False.
add_class_token (bool, optional): Include a class token. Defaults to True.
class_token_len (int): Class token length. Defaults to 8.
patch_dim (int): Image patch size.
img_h (int): Input image height.
img_w (int): Input image width.
max_img_h (int): Max input image height.
max_img_w (int): Max input image width.
pos_dropout (int): Positional encoding dropout value. Defaults to 0.
has_cpe (bool): Whether to use conditional positional encoding. Defaults to True.
embedder_bias (bool): Bias in embedder linear. Defaults to False.
"""
def __init__(
self,
transformer_config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
ln_pre_impl: Union[ModuleSpec, type] = None,
ln_post_impl: Union[ModuleSpec, type] = None,
use_mask_token: bool = False,
add_class_token: bool = True,
class_token_len: int = 8,
patch_dim: int = 16,
img_h: int = 224,
img_w: int = 224,
max_img_h: int = 2048,
max_img_w: int = 2048,
pos_dropout: int = 0,
has_cpe: bool = True,
embedder_bias: bool = False,
) -> None:
super().__init__(config=transformer_config)
if has_config_logger_enabled(transformer_config):
log_config_to_disk(transformer_config, locals(), prefix=type(self).__name__)
self.class_token_len = class_token_len
self.visual_hidden_size = transformer_config.hidden_size
self.patch_dim = patch_dim
self.img_h = img_h
self.img_w = img_w
assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
self.input_dims = (img_h // patch_dim, img_w // patch_dim)
# used for positional embedding
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_num_rows = max_img_h // patch_dim
self.max_num_cols = max_img_w // patch_dim
self.max_num_patches = self.max_num_rows * self.max_num_cols
# TODO: are we actually going to use this anywhere?
self.use_mask_token = use_mask_token
if self.use_mask_token:
self.mask_token = nn.Parameter(torch.zeros(1, self.visual_hidden_size))
self.add_class_token = add_class_token
self.class_token_len = class_token_len
if self.add_class_token:
self.class_token = nn.Parameter(
torch.randn(self.class_token_len, self.visual_hidden_size)
)
self.seq_length = (img_h // self.patch_dim) * (img_w // self.patch_dim) + (
self.class_token_len if self.add_class_token else 0
)
pos_scale = self.visual_hidden_size**-0.5
self.position_embeddings = nn.Parameter(
torch.randn(1, self.max_num_patches, self.visual_hidden_size) * pos_scale
)
self.pos_dropout = pos_dropout
self.has_cpe = has_cpe
# Using non-TE version so we can force gather_output
self.embedder = ColumnParallelLinear(
input_size=3 * self.patch_dim * self.patch_dim,
output_size=self.visual_hidden_size,
bias=embedder_bias,
config=transformer_config,
gather_output=True,
init_method=lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=1.0),
)
self.model_type = ModelType.encoder_or_decoder
self.ln_pre = None
self.ln_post = None
if ln_pre_impl is not None:
self.ln_pre = build_module(
ln_pre_impl,
config=transformer_config,
hidden_size=self.visual_hidden_size,
eps=transformer_config.layernorm_epsilon,
)
if ln_post_impl is not None:
self.ln_post = build_module(
ln_post_impl,
config=transformer_config,
hidden_size=self.visual_hidden_size,
eps=transformer_config.layernorm_epsilon,
)
self.decoder = TransformerBlock(
config=transformer_config,
spec=transformer_layer_spec,
pre_process=True,
post_process=False,
)
def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
"""Sets input tensor to the model.
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
self.decoder.set_input_tensor(input_tensor)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward function of the RADIO ViT Model. This function passes the input tensors
through the embedding layer and then the transformer.
Args:
x (torch.Tensor): input data of shape [batch, channels, img_h, img_w]
attention_mask (torch.Tensor with dtype=bool): Attention mask to use.
Returns:
x (torch.Tensor): output after final transformer block of shape [b, s, h].
"""
input_size = x.shape[2:]
py = x.shape[-2] // self.patch_dim
px = x.shape[-1] // self.patch_dim
x = rearrange(
x,
'b c (py yy) (px xx) -> b (py px) (c yy xx)',
py=py,
yy=self.patch_dim,
px=px,
xx=self.patch_dim,
)
x, _ = self.embedder(x) # [batch, seq_length, hidden_size]
x, _ = self.apply_pos_enc(x, input_size=input_size)
if self.add_class_token:
class_token = self.class_token.expand(
x.shape[0], -1, -1
) # [batch, class_token_len, hidden_size]
x = torch.cat(
[class_token, x], dim=1
) # [batch, seq_length + class_token_len, hidden_size]
assert x.shape[1] == self.seq_length, f"{x.shape[1]} != {self.seq_length}"
if self.ln_pre:
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # [b, s, h] -> [s, b, h]
x = x.contiguous()
x = self.decoder(x, attention_mask=attention_mask)
x = x.permute(1, 0, 2) # [s, b, h] -> [b, s, h]
x = x.contiguous()
if self.ln_post:
x = self.ln_post(x)
return x
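# Hedged usage sketch (cfg/layer_spec are placeholders) assuming a 224x224 RGB input
# and the defaults above (patch_dim=16, class_token_len=8):
#   vit = RADIOViTModel(cfg, layer_spec, img_h=224, img_w=224, patch_dim=16)
#   out = vit(torch.randn(2, 3, 224, 224))
#   # out: [2, (224 // 16) ** 2 + 8, hidden_size] == [2, 204, hidden_size]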
def apply_pos_enc(
self,
patches: torch.Tensor,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply positional encoding to patches"""
pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
if self.training and self.pos_dropout > 0:
keeps = (
torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device)
> self.pos_dropout
)
pos_enc_drop = torch.where(keeps, pos_enc, 0)
else:
pos_enc_drop = pos_enc
return patches + pos_enc_drop, pos_enc
def get_pos_enc(
self,
batch_size: int,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
"""Get positional encoding for certain input size"""
if input_size is None:
input_dims = self.input_dims
else:
input_dims = tuple(d // self.patch_dim for d in input_size)
pos_embed = self._get_pos_embeddings(batch_size, input_dims)
if patch_idxs is None:
return pos_embed
exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])
pos_embed = torch.gather(
pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
)
return pos_embed
def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):
"""Get RADIO absolute positional embeddings"""
if (self.max_num_rows, self.max_num_cols) == input_dims:
return self.position_embeddings
pos_embed = self.position_embeddings.reshape(
1, self.max_num_rows, self.max_num_cols, -1
).permute(0, 3, 1, 2)
def window_select(pos_embed):
if input_dims[0] < pos_embed.shape[-2]:
pos_embed = pos_embed[..., : input_dims[0], :]
if input_dims[1] < pos_embed.shape[-1]:
pos_embed = pos_embed[..., :, : input_dims[1]]
return pos_embed
if self.has_cpe:
if self.training:
min_scale = math.sqrt(0.1)
scale = (
torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale)
+ min_scale
)
aspect_min = math.log(3 / 4)
aspect_max = -aspect_min
aspect = torch.exp(
torch.rand(batch_size, 1, 1, device=pos_embed.device)
* (aspect_max - aspect_min)
+ aspect_min
)
scale_x = scale * aspect
scale_y = scale * (1 / aspect)
scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)
lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[
None, None
].expand(batch_size, input_dims[0], -1)
lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[
None, :, None
].expand(batch_size, -1, input_dims[1])
lin_xy = torch.stack([lin_x, lin_y], dim=-1)
grid_xy = lin_xy * scale_xy + pos_xy
# Convert to [-1, 1] range
grid_xy.mul_(2).sub_(1)
pos_embed = F.grid_sample(
pos_embed.float().expand(batch_size, -1, -1, -1),
grid=grid_xy,
mode='bilinear',
padding_mode='zeros',
align_corners=True,
).to(pos_embed.dtype)
else:
max_dim = max(input_dims)
pos_embed = F.interpolate(
pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear'
).to(pos_embed.dtype)
pos_embed = window_select(pos_embed)
else:
pos_embed = window_select(pos_embed)
if pos_embed.shape[-2:] != input_dims:
pos_embed = F.interpolate(
pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear'
).to(pos_embed.dtype)
pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
return pos_embed
......@@ -27,7 +27,7 @@ except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
......
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import warnings
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch.optim import SGD as CPUSGD
from torch.optim import AdamW as CPUAdam
try:
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
......@@ -12,8 +15,6 @@ except ImportError:
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
except ImportError:
import warnings
warnings.warn(
f'Transformer Engine and Apex are not installed. Falling back to Torch optimizers.'
)
......@@ -24,10 +25,11 @@ except ImportError:
from torch.optim import AdamW as Adam, SGD
from megatron.core import mpu
from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer
from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer
from ..transformer.module import MegatronModule
from ..utils import log_single_rank
from ..utils import is_te_min_version, log_single_rank
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import (
......@@ -81,7 +83,12 @@ def _get_param_groups(
# Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
params_map = {}
for model_chunk in model_chunks:
for name, param in model_chunk.named_parameters():
if model_chunk.ddp_config.use_custom_fsdp:
named_parameters = model_chunk.optimizer_named_parameters()
else:
named_parameters = model_chunk.named_parameters()
for name, param in named_parameters:
if not param.requires_grad:
continue
......@@ -262,32 +269,97 @@ def _get_megatron_optimizer_based_on_param_groups(
Returns:
Instance of MegatronOptimizer.
"""
if config.optimizer == 'adam':
optimizer = Adam(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.adam_beta1, config.adam_beta2),
eps=config.adam_eps,
)
# when freezing sub-models we may have no trainable parameters on a rank and
# hence an empty param_groups. However, we still need to create an optimizer
# for the purposes of grad stats reductions
if param_groups:
if config.optimizer_cpu_offload:
if torch.__version__ < '2.3.0':
warnings.warn(
"CPU offload is recommended for PyTorch >= 2.3.0, "
"untested versions below this may have convergence issues."
)
gpu_optimizer_cls = Adam if config.optimizer == 'adam' else SGD
cpu_optimizer_cls = CPUAdam if config.optimizer == 'adam' else CPUSGD
if config.use_torch_optimizer_for_cpu_offload:
gpu_optimizer_cls = cpu_optimizer_cls
if config.optimizer == 'adam':
gpu_optimizer_cls = Adam
cpu_optimizer_cls = CPUAdam
optimizer_defaults = dict(
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.adam_beta1, config.adam_beta2),
eps=config.adam_eps,
bias_correction=True,
fused=True, # this flag is used to improve the performance of the cpu optimizer
)
else:
gpu_optimizer_cls = SGD
cpu_optimizer_cls = CPUSGD
optimizer_defaults = dict(
lr=config.lr, weight_decay=config.weight_decay, momentum=config.sgd_momentum
)
optimizer = HybridDeviceOptimizer(
param_groups,
offload_fraction=config.optimizer_offload_fraction,
cpu_optimizer_cls=cpu_optimizer_cls,
gpu_optimizer_cls=gpu_optimizer_cls,
overlap_cpu_optimizer_d2h_h2d=config.overlap_cpu_optimizer_d2h_h2d,
pin_cpu_grads=config.pin_cpu_grads,
pin_cpu_params=config.pin_cpu_params,
param_update_in_fp32=True,
**optimizer_defaults,
)
init_state_fn = None
elif config.optimizer == 'adam':
kwargs = {
"params": param_groups,
"lr": config.lr,
"weight_decay": config.weight_decay,
"betas": (config.adam_beta1, config.adam_beta2),
"eps": config.adam_eps,
}
if config.use_precision_aware_optimizer:
kwargs.update(
{
"master_weights": True,
"use_decoupled_grad": True,
"master_weight_dtype": config.main_params_dtype,
"exp_avg_dtype": config.exp_avg_dtype,
"exp_avg_sq_dtype": config.exp_avg_sq_dtype,
}
)
def init_state_fn(opt):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
elif config.optimizer == 'sgd':
optimizer = SGD(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
momentum=config.sgd_momentum,
)
init_state_fn = None
if is_te_min_version("2.1.0.dev0"):
kwargs.update({"store_param_remainders": True})
optimizer = Adam(**kwargs)
def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
if config is None or not config.use_precision_aware_optimizer:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
else:
opt.initialize_state(p)
elif config.optimizer == 'sgd':
optimizer = SGD(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
momentum=config.sgd_momentum,
)
init_state_fn = None
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
optimizer = None
init_state_fn = None
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
......@@ -347,6 +419,7 @@ def get_megatron_optimizer(
no_weight_decay_cond: Optional[Callable] = None,
scale_lr_cond: Optional[Callable] = None,
lr_mult: float = 1.0,
use_gloo_process_groups: bool = True,
) -> MegatronOptimizer:
"""Retrieve the Megatron optimizer for model chunks.
......@@ -361,6 +434,8 @@ def get_megatron_optimizer(
should have a scaled learning rate. Defaults to None.
lr_mult (float, optional): learning rate multiplier for parameters that
satisfy scale_lr_cond. Defaults to 1.0.
use_gloo_process_groups (bool): if false, disable use of Gloo process groups
in underlying Megatron optimizers.
Returns:
Instance of MegatronOptimizer.
......@@ -390,6 +465,42 @@ def get_megatron_optimizer(
optimizers = []
model_chunk_offset = 0
ddp_config = model_chunks[0].ddp_config # Use the first model chunk's DDP config
if ddp_config.use_custom_fsdp:
for model_chunk, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
param_groups, buffers = _get_param_groups_and_buffers(
model_chunk,
model_chunk_offset=model_chunk_offset,
config=config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
filter_fn=lambda g: True,
buffer_name='buffers',
)
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
model_chunks=model_chunk,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mpu.get_model_parallel_group(),
data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(
with_context_parallel=True
),
data_parallel_group_idx=model_parallel_rank,
)
)
model_chunk_offset += 1
if len(optimizers) == 1:
return optimizers[0]
return ChainedOptimizer(optimizers)
for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
......@@ -407,6 +518,14 @@ def get_megatron_optimizer(
model_chunk.overlap_param_gather_with_optimizer_step = (
overlap_param_gather_with_optimizer_step
)
# Pass Gloo process groups into optimizer only if needed.
if use_gloo_process_groups:
data_parallel_group_gloo = mpu.get_data_parallel_group_gloo(
with_context_parallel=True, partial_data_parallel=True
)
else:
data_parallel_group_gloo = None
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
......@@ -417,9 +536,7 @@ def get_megatron_optimizer(
data_parallel_group=mpu.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
),
data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(
with_context_parallel=True, partial_data_parallel=True
),
data_parallel_group_gloo=data_parallel_group_gloo,
data_parallel_group_idx=model_parallel_rank,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
)
......@@ -440,6 +557,11 @@ def get_megatron_optimizer(
model_parallel_rank = torch.distributed.get_rank(
mpu.get_expert_tensor_model_pipeline_parallel_group()
)
# Pass Gloo process groups into optimizer only if needed.
if use_gloo_process_groups:
data_parallel_group_gloo = mpu.get_expert_data_parallel_group_gloo()
else:
data_parallel_group_gloo = None
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
......@@ -448,7 +570,7 @@ def get_megatron_optimizer(
per_model_buffers=moe_buffers,
model_parallel_group=mpu.get_expert_tensor_model_pipeline_parallel_group(),
data_parallel_group=mpu.get_expert_data_parallel_group(),
data_parallel_group_gloo=mpu.get_expert_data_parallel_group_gloo(),
data_parallel_group_gloo=data_parallel_group_gloo,
data_parallel_group_idx=model_parallel_rank,
)
)
......
......@@ -139,6 +139,7 @@ def clip_grad_by_total_norm_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
max_norm: Union[int, float],
total_norm: float,
use_decoupled_grad: bool = False,
):
"""Clips gradient of an iterable of parameters in fp32 by total norm.
......@@ -149,15 +150,23 @@ def clip_grad_by_total_norm_fp32(
single Tensor that will have gradients normalized.
max_norm (float or int): max norm of the gradients.
total_norm (float): total norm of the gradients.
use_decoupled_grad (bool, optional): whether to read grad from ".grad" or ".decoupled_grad",
default value is False.
"""
# Grads.
params = []
grads = []
for param in parameters:
if param.grad is not None:
assert param.grad.type() == 'torch.cuda.FloatTensor'
params.append(param)
grads.append(to_local_if_dtensor(param.grad).detach())
if use_decoupled_grad:
if hasattr(param, "decoupled_grad") and param.decoupled_grad is not None:
assert param.decoupled_grad.dtype in [torch.float32, torch.bfloat16]
params.append(param)
grads.append(to_local_if_dtensor(param.decoupled_grad).detach())
else:
if param.grad is not None:
assert param.grad.type() == 'torch.cuda.FloatTensor'
params.append(param)
grads.append(to_local_if_dtensor(param.grad).detach())
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
......@@ -171,6 +180,7 @@ def clip_grad_by_total_norm_fp32(
def count_zeros_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
grad_stats_parallel_group: torch.distributed.ProcessGroup,
use_decoupled_grad: bool = False,
) -> float:
"""Counts the number of zeros in gradients associated with the passed-in list of
parameters.
......@@ -182,6 +192,8 @@ def count_zeros_fp32(
grad_stats_parallel_group (group): Process group for reducing the num_zeros count. This is
generally the model-parallel group for non-distributed optimizers, and the entire
world for the distributed optimizer.
use_decoupled_grad (bool, optional): whether to read grad from ".grad" or ".decoupled_grad",
default value is False.
"""
if isinstance(parameters, torch.Tensor):
......@@ -194,14 +206,14 @@ def count_zeros_fp32(
total_num_zeros = torch.tensor([0.0], dtype=torch.float, device='cuda')
data_parallel_group = None
for param in parameters:
grad_not_none = param.grad is not None
grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
grad_not_none = hasattr(param, grad_attr) and getattr(param, grad_attr) is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
data_parallel_group = get_data_parallel_group_if_dtensor(
param.grad, data_parallel_group
)
grad = to_local_if_dtensor(param.grad).detach()
grad_obj = getattr(param, grad_attr)
data_parallel_group = get_data_parallel_group_if_dtensor(grad_obj, data_parallel_group)
grad = to_local_if_dtensor(grad_obj).detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
......
## How to use?
Add these flags to enable optimizer CPU offload in MCore.
```bash
--optimizer-cpu-offload
--optimizer-offload-fraction 1.0
--use-precision-aware-optimizer
```
## Configuration Recommendations
Copying gradients from GPU to CPU, stepping the optimizer on the CPU, and copying the updated parameters back from CPU to GPU can all be time-consuming; it is recommended to use the flag `--overlap-cpu-optimizer-d2h-h2d` to execute these operations concurrently.
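For reference, a minimal sketch of a launch fragment combining the flags above with this recommendation; only flags already listed in this document are used, and the script name and remaining arguments are placeholders:

```bash
# Hypothetical launch fragment; script name and elided arguments are placeholders.
torchrun --nproc_per_node=8 pretrain_gpt.py \
    ... \
    --optimizer-cpu-offload \
    --optimizer-offload-fraction 1.0 \
    --use-precision-aware-optimizer \
    --overlap-cpu-optimizer-d2h-h2d
```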