Commit a8a2bbea authored by dongcl's avatar dongcl

patch for megatron 4429e8ebe

parent 2ddbd4be
......@@ -21,7 +21,6 @@ class MegatronAdaptation:
for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
adaptation.execute()
MegatronAdaptation.apply()
# MegatronAdaptation.post_execute()
@classmethod
def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False, apply_wrapper=False, remove_origin_wrappers=False):
......@@ -87,47 +86,23 @@ class CoreAdaptation(MegatronAdaptationABC):
self.patch_miscellaneous()
def patch_core_distributed(self):
# Mtp share embedding
from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
_allreduce_word_embedding_grads)
pass
def patch_core_models(self):
from ..core.models.common.embeddings.language_model_embedding import (
language_model_embedding_forward,
language_model_embedding_init_func
)
from ..core.models.gpt.gpt_model import (
gpt_model_forward,
gpt_model_init_wrapper,
shared_embedding_or_mtp_embedding_weight
)
from ..training.utils import get_batch_on_this_tp_rank
# Embedding
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
language_model_embedding_init_func)
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
language_model_embedding_forward)
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
from megatron.core.models.gpt.gpt_model import GPTModel
setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
gpt_model_forward)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
# Transformer block
# Transformer block. If mtp_num_layers > 0, move final_layernorm outside
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper)
......@@ -141,9 +116,9 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
# torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
# apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
......@@ -157,6 +132,7 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
from megatron.core.extensions.transformer_engine import TEGroupedLinear
# kv channels, te_min_version 1.10.0 -> 1.9.0
MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
TEDotProductAttentionPatch.__init__)
......@@ -165,13 +141,10 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper
# VocabParallelEmbedding
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
vocab_parallel_embedding_forward)
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
vocab_parallel_embedding_init_wrapper,
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
# VocabParallelCrossEntropy
......@@ -202,6 +175,9 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
def patch_pipeline_parallel(self):
pass
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
......@@ -210,12 +186,14 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
build_tokenizer)
# specify init_method
MegatronAdaptation.register('megatron.training.initialize._initialize_distributed',
_initialize_distributed)
# remove fused_kernels
MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
_compile_dependencies)
# training.train
# add trace_handler
MegatronAdaptation.register('megatron.training.training.train',
train)
......@@ -245,6 +223,8 @@ class LegacyAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
parallel_mlp_init_wrapper,
apply_wrapper=True)
# ParallelAttention
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
parallel_attention_init_wrapper,
apply_wrapper=True)
......
......@@ -148,11 +148,29 @@ class MegatronPatchesManager:
patches_info = {}
@staticmethod
def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
def register_patch(
orig_func_or_cls_name,
new_func_or_cls=None,
force_patch=False,
create_dummy=False,
apply_wrapper=False,
remove_origin_wrappers=False
):
if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name, new_func_or_cls, create_dummy)
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
orig_func_or_cls_name,
new_func_or_cls,
create_dummy,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
else:
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
new_func_or_cls,
force_patch,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
@staticmethod
def apply_patches():
......
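# Illustrative sketch (not part of the patch): how a registry with an
# apply_wrapper flag could behave. This toy handles module-level attributes
# only and uses hypothetical names; the real Patch / MegatronPatchesManager
# classes are only partially shown above.
import importlib

class ToyPatchRegistry:
    """Minimal registry: plain registration replaces the target, apply_wrapper decorates it."""

    def __init__(self):
        self._patches = {}

    def register_patch(self, orig_func_name, new_func, apply_wrapper=False):
        self._patches[orig_func_name] = (new_func, apply_wrapper)

    def apply_patches(self):
        for dotted_name, (new_func, apply_wrapper) in self._patches.items():
            module_name, attr = dotted_name.rsplit('.', 1)
            module = importlib.import_module(module_name)
            original = getattr(module, attr)
            # apply_wrapper=True: treat new_func as a decorator of the original
            # (e.g. torch.compile or an __init__ wrapper); otherwise replace it outright.
            replacement = new_func(original) if apply_wrapper else new_func
            setattr(module, attr, replacement)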
from typing import List
import torch
from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _unshard_if_dtensor, _reshard_if_dtensor
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import get_attr_wrapped_model
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
sync.
"""
if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
if (
hasattr(model_module, "share_mtp_embedding_and_output_weight")
and model_module.share_mtp_embedding_and_output_weight
and config.num_nextn_predict_layers > 0
):
weight = model_module.shared_embedding_or_mtp_embedding_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
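# Hedged single-process illustration of the grad-attribute selection above
# (assumption: no distributed setup; the all-reduce is only indicated in a comment).
import torch

weight = torch.nn.Parameter(torch.zeros(4, 8))
weight.grad = torch.ones_like(weight)  # pretend a backward pass has already run
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"  # DDP buckets expose main_grad
grad = getattr(weight, grad_attr)
# In the patch this tensor is then summed across the embedding group:
# torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())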
import os
import torch
import dataclasses
import transformer_engine as te
......@@ -11,6 +12,7 @@ from megatron.core.utils import get_te_version, is_te_min_version
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
......@@ -32,6 +34,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
k_channels: Optional[int] = None,
v_channels: Optional[int] = None,
cp_comm_type: str = "p2p",
model_comm_pgs: ModelCommProcessGroups = None,
):
self.config = config
self.te_forward_mask_type = False
......@@ -58,6 +61,26 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if model_comm_pgs is None:
# For backward compatibility, remove in v0.14 and raise error
# raise ValueError("TEDotProductAttention was called without ModelCommProcessGroups")
model_comm_pgs = ModelCommProcessGroups(
tp=get_tensor_model_parallel_group(check_initialized=False),
cp=get_context_parallel_group(check_initialized=False),
hcp=get_hierarchical_context_parallel_groups(check_initialized=False),
)
else:
assert hasattr(
model_comm_pgs, 'tp'
), "TEDotProductAttention model_comm_pgs must have tp pg"
assert hasattr(
model_comm_pgs, 'cp'
), "TEDotProductAttention model_comm_pgs must have cp pg"
if cp_comm_type == "a2a+p2p":
assert hasattr(
model_comm_pgs, 'hcp'
), "TEDotProductAttention model_comm_pgs must have hierarchical cp pg"
if is_te_min_version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older version don't need attention_type
......@@ -73,9 +96,9 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
extra_kwargs["cp_group"] = model_comm_pgs.cp
extra_kwargs["cp_global_ranks"] = torch.distributed.get_process_group_ranks(
model_comm_pgs.cp
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
if is_te_min_version("1.10.0"):
......@@ -149,7 +172,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_group=model_comm_pgs.tp,
layer_number=layer_number,
**extra_kwargs,
)
from typing import Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
def language_model_embedding_init_func(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
num_tokentypes: int = 0,
scatter_to_sequence_parallel: bool = True,
skip_weight_param_allocation: bool = False
):
"""Patch language model embeddings init."""
super(LanguageModelEmbedding, self).__init__(config=config)
self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
self.num_tokentypes = num_tokentypes
self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
self.reduce_scatter_embeddings = (
(not self.add_position_embedding)
and self.num_tokentypes <= 0
and self.config.sequence_parallel
and self.scatter_to_sequence_parallel
)
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
skip_weight_param_allocation=skip_weight_param_allocation
)
# Position embedding (serial).
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
self.max_sequence_length, self.config.hidden_size
)
# Initialize the position embeddings.
if self.config.perform_initialization:
self.config.init_method(self.position_embeddings.weight)
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.config.hidden_size
)
# Initialize the token-type embeddings.
if self.config.perform_initialization:
self.config.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)
def language_model_embedding_forward(self,
input_ids: Tensor,
position_ids: Tensor,
tokentype_ids: int = None,
weight: Tensor = None) -> Tensor:
"""Pacth forward pass of the embedding module.
Args:
input_ids (Tensor): The input tokens
position_ids (Tensor): The position id's used to calculate position embeddings
tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
set to True. Defaults to None
weight (Tensor): embedding weight
Returns:
Tensor: The output embeddings
"""
if weight is None:
if self.word_embeddings.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.word_embeddings.weight
word_embeddings = self.word_embeddings(input_ids, weight)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
else:
embeddings = word_embeddings
if not self.reduce_scatter_embeddings:
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
# [b s h] -> [s b h] (So that it can be added with embeddings)
tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
embeddings = embeddings + tokentype_embedding
else:
assert self.tokentype_embeddings is None
# If the input flag for fp32 residual connection is set, convert for float.
if self.config.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.config.sequence_parallel:
if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
# `scatter_to_sequence_parallel_region` returns a view, which prevents
# the original tensor from being garbage collected. Clone to facilitate GC.
# Has a small runtime cost (~0.5%).
if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
embeddings = embeddings.clone()
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
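# Hedged, self-contained toy (hypothetical ToyEmbedding, not the Megatron class)
# showing the weight-injection pattern that skip_weight_param_allocation enables
# in the patched embedding above.
import torch
import torch.nn.functional as F

class ToyEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, hidden, skip_weight_param_allocation=False):
        super().__init__()
        if skip_weight_param_allocation:
            self.weight = None  # the weight is owned elsewhere (e.g. shared with another module)
        else:
            self.weight = torch.nn.Parameter(torch.randn(vocab_size, hidden))

    def forward(self, input_ids, weight=None):
        if weight is None:
            if self.weight is None:
                raise RuntimeError("no weight supplied and weight allocation was skipped")
            weight = self.weight
        return F.embedding(input_ids, weight)

shared_weight = torch.nn.Parameter(torch.randn(16, 8))  # e.g. the main model's word embeddings
emb = ToyEmbedding(16, 8, skip_weight_param_allocation=True)
out = emb(torch.tensor([[1, 2, 3]]), weight=shared_weight)  # shape [1, 3, 8]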
import warnings
from typing import Optional
from typing import Optional, Union
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
......@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
)
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
from megatron.core.utils import is_te_min_version
try:
......@@ -36,6 +36,17 @@ try:
except ImportError:
warnings.warn('Apex is not installed.')
from dcu_megatron.core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from dcu_megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlockSubmodules,
get_mtp_layer_offset,
get_mtp_layer_spec,
get_mtp_num_layers_to_build,
)
def get_gpt_layer_with_flux_spec(
num_experts: Optional[int] = None,
......
import os
import logging
from typing import Literal, Optional
from functools import wraps
from collections import OrderedDict
from typing import Optional
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
from dcu_megatron.core.transformer.transformer_config import TransformerConfig
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
......@@ -30,11 +18,13 @@ def gpt_model_init_wrapper(fn):
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
if (
self.post_process
and int(os.getenv("USE_FLUX_OVERLAP", "0"))
):
self.output_layer = FluxColumnParallelLinear(
# Output
if self.post_process or self.mtp_process:
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
parallel_linear_impl = FluxColumnParallelLinear
else:
parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = parallel_linear_impl(
self.config.hidden_size,
self.vocab_size,
config=self.config,
......@@ -48,174 +38,12 @@ def gpt_model_init_wrapper(fn):
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
# add mtp
self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
if self.num_nextn_predict_layers:
assert hasattr(self.config, "mtp_spec")
self.mtp_spec: ModuleSpec = self.config.mtp_spec
self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
self.recompute_mtp_norm = self.config.recompute_mtp_norm
self.recompute_mtp_layer = self.config.recompute_mtp_layer
self.mtp_loss_scale = self.config.mtp_loss_scale
if self.post_process and self.training:
self.mtp_layers = torch.nn.ModuleList(
[
MultiTokenPredictor(
self.config,
self.mtp_spec.submodules,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
layer_number=i,
pre_process=self.pre_process,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=kwargs.get("seq_len_interpolation_factor", None),
share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
recompute_mtp_norm=self.recompute_mtp_norm,
recompute_mtp_layer=self.recompute_mtp_layer,
add_output_layer_bias=False
)
for i in range(self.num_nextn_predict_layers)
]
)
if self.pre_process or self.post_process:
setup_mtp_embeddings(self)
return wrapper
def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
"""Gets the embedding weight when share embedding and mtp embedding weights set to True.
Returns:
Tensor: During pre processing it returns the input embeddings weight while during post processing it returns
mtp embedding layers weight
"""
assert self.num_nextn_predict_layers > 0
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.mtp_layers[0].embedding.word_embeddings.weight
return None
def setup_mtp_embeddings(self):
"""
Share the main model's embedding weights with the MTP layers.
"""
if self.pre_process:
self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
# Set `is_embedding_or_output_parameter` attribute.
for i in range(self.num_nextn_predict_layers):
if self.post_process and self.mtp_layers[i].embedding.word_embeddings.weight is not None:
self.mtp_layers[i].embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
if not self.share_mtp_embedding_and_output_weight:
return
if self.pre_process and self.post_process:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self.shared_embedding_or_mtp_embedding_weight().zero_out_wgrad = True
return
if self.pre_process and not self.post_process:
assert parallel_state.is_pipeline_first_stage()
self.shared_embedding_or_mtp_embedding_weight().shared_embedding = True
if self.post_process and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
for i in range(self.num_nextn_predict_layers):
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.mtp_layers[i].embedding.word_embeddings.weight.data.fill_(0)
self.mtp_layers[i].embedding.word_embeddings.weight.shared = True
self.mtp_layers[i].embedding.word_embeddings.weight.shared_embedding = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, do an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_mtp_embedding_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule.embedding_warning_printed = True
def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
if self.num_nextn_predict_layers == 0:
return (
[input_ids],
[labels],
[position_ids],
[attention_mask],
)
return (
tensor_slide(input_ids, self.num_nextn_predict_layers),
tensor_slide(labels, self.num_nextn_predict_layers),
generate_nextn_position_ids(position_ids, self.num_nextn_predict_layers),
# not compatible with ppo attn_mask
tensor_slide(attention_mask, self.num_nextn_predict_layers, dims=[-2, -1]),
)
def generate_nextn_position_ids(tensor, slice_num):
slides = tensor_slide(tensor, slice_num)
if slides[0] is None:
return slides
for idx in range(1, len(slides)):
slides[idx] = regenerate_position_ids(slides[idx], idx)
return slides
def regenerate_position_ids(tensor, offset):
if tensor is None:
return None
tensor = tensor.clone()
for i in range(tensor.size(0)):
row = tensor[i]
zero_mask = (row == 0)  # case where two sequences are packed together
if zero_mask.any():
first_zero_idx = torch.argmax(zero_mask.int()).item()
tensor[i, :first_zero_idx] = torch.arange(first_zero_idx)
else:
tensor[i] = tensor[i] - offset
return tensor
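# Worked example (toy tensors) of the position-id regeneration above; the packed
# case assumes two sequences were concatenated so position ids restart at 0.
import torch

packed = torch.tensor([[1, 2, 3, 0, 1, 2]])       # a window slid by 1 over [0,1,2,3,0,1,2]
regenerate_position_ids(packed, offset=1)         # -> tensor([[0, 1, 2, 0, 1, 2]])

single = torch.tensor([[1, 2, 3, 4]])             # a window slid by 1 over a single sequence
regenerate_position_ids(single, offset=1)         # -> tensor([[0, 1, 2, 3]])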
def gpt_model_forward(
self,
input_ids: Tensor,
......@@ -227,9 +55,10 @@ def gpt_model_forward(
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
through the embedding layer, and then the decoder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
......@@ -241,20 +70,11 @@ def gpt_model_forward(
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# generate inputs for main and mtps
input_ids, labels, position_ids, attention_mask = slice_inputs(
self,
input_ids,
labels,
position_ids,
attention_mask
)
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids[0], position_ids=position_ids[0])
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
......@@ -296,7 +116,7 @@ def gpt_model_forward(
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask[0],
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
......@@ -306,46 +126,43 @@ def gpt_model_forward(
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
loss = 0
# Multi token prediction module
if self.num_nextn_predict_layers and self.training:
if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight:
output_weight = self.output_layer.weight
output_weight.zero_out_wgrad = True
embedding_weight = self.shared_embedding_or_mtp_embedding_weight() if self.share_mtp_embedding_and_output_weight else None
mtp_hidden_states = hidden_states
for i in range(self.num_nextn_predict_layers):
mtp_hidden_states, mtp_loss = self.mtp_layers[i](
mtp_hidden_states, # [s,b,h]
input_ids[i + 1],
position_ids[i + 1] if position_ids[0] is not None else None,
attention_mask[i + 1] if attention_mask[0] is not None else None,
labels[i + 1] if labels[0] is not None else None,
inference_params,
packed_seq_params,
extra_block_kwargs,
embeding_weight=embedding_weight,
output_weight=output_weight,
)
if self.mtp_process:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
labels=labels,
loss_mask=loss_mask,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=self.embedding,
output_layer=self.output_layer,
output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
compute_language_model_loss=self.compute_language_model_loss,
**(extra_block_kwargs or {}),
)
loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss
if (
self.num_nextn_predict_layers
self.mtp_process is not None
and getattr(self.decoder, "main_final_layernorm", None) is not None
):
# apply the main model's final layernorm here (moved out of the transformer block)
hidden_states = self.decoder.main_final_layernorm(hidden_states)
if not self.post_process:
return hidden_states
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
......@@ -353,19 +170,19 @@ def gpt_model_forward(
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids[0],
'position_ids': position_ids[0],
'attention_mask': attention_mask[0],
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels[0] is None:
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss += self.compute_language_model_loss(labels[0], logits)
loss = self.compute_language_model_loss(labels, logits)
return loss
from .layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear,
vocab_parallel_embedding_forward,
vocab_parallel_embedding_init_wrapper,
)
\ No newline at end of file
import os
import copy
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
try:
import flux
from dcu_megatron.core.utils import is_flux_min_version
except ImportError:
raise ImportError("flux is NOT installed")
try:
import flux
except ImportError:
raise ImportError("flux is NOT installed")
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.training import print_rank_0
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.utils import (
is_torch_min_version,
prepare_input_tensors_for_wgrad_compute
)
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
VocabParallelEmbedding,
)
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.mappings import (
_reduce,
copy_to_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
_reduce_scatter_along_first_dim,
_gather_along_first_dim,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel import (
ColumnParallelLinear,
RowParallelLinear,
......@@ -50,9 +30,9 @@ from megatron.core.tensor_parallel.layers import (
custom_fwd,
custom_bwd,
dist_all_gather_func,
linear_with_frozen_weight,
linear_with_grad_accumulation_and_async_allreduce
)
from dcu_megatron.core.utils import is_flux_min_version
_grad_accum_fusion_available = True
try:
......@@ -61,74 +41,6 @@ except ImportError:
_grad_accum_fusion_available = False
def vocab_parallel_embedding_init_wrapper(fn):
@wraps(fn)
def wrapper(self,
*args,
skip_weight_param_allocation: bool = False,
**kwargs
):
if (
skip_weight_param_allocation
and "config" in kwargs
and hasattr(kwargs["config"], "perform_initialization")
):
config = copy.deepcopy(kwargs["config"])
config.perform_initialization = False
kwargs["config"] = config
fn(self, *args, **kwargs)
if skip_weight_param_allocation:
self.weight = None
return wrapper
@torch.compile(mode='max-autotune-no-cudagraphs')
def vocab_parallel_embedding_forward(self, input_, weight=None):
"""Forward.
Args:
input_ (torch.Tensor): Input tensor.
"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
if self.deterministic_mode:
output_parallel = weight[masked_input]
else:
# F.embedding currently has a non-deterministic backward function
output_parallel = F.embedding(masked_input, weight)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
if self.reduce_scatter_embeddings:
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
output_parallel = output_parallel.transpose(0, 1).contiguous()
output = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
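# Worked example with toy numbers of the vocab-range masking above, for a
# hypothetical rank that owns vocabulary ids [4, 8) of a 16-token vocabulary
# (no tensor parallelism actually running here).
import torch
import torch.nn.functional as F

vocab_start_index, vocab_end_index = 4, 8
weight = torch.randn(4, 8)                          # this rank's 4 rows of the embedding table
input_ = torch.tensor([[2, 5, 7, 11]])
input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_.clone() - vocab_start_index
masked_input[input_mask] = 0                        # out-of-range ids point at a dummy row
output_parallel = F.embedding(masked_input, weight)
output_parallel[input_mask, :] = 0.0                # zero them so the cross-rank reduction sums correctly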
def get_tensor_model_parallel_node_size(group=None):
""" 获取节点数
"""
......
import torch
from torch.utils.checkpoint import detach_variable
from megatron.core.tensor_parallel.random import (
get_cuda_rng_tracker,
_set_cuda_rng_state
)
class CheckpointFunctionWithoutOutput(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, checkpoint, *args):
with torch.no_grad():
outputs = run_function(*args)
# Store everything
ctx.save_for_backward(*detach_variable(args))
checkpoint.ctx = ctx
return outputs
@staticmethod
def backward(ctx, *args):
inputs = ctx.saved_tensors
outputs = ctx.outputs
torch.autograd.backward(outputs, args)
ctx.outputs = None
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
return (None, None) + grads
class CheckpointWithoutOutput:
def __init__(self):
self.run_function = None
self.fwd_cpu_rng_state = None
self.fwd_cuda_rng_state = None
self.fwd_cuda_rng_state_tracker = None
self.outputs = None
def checkpoint(self, run_function, distribute_saved_activations, *args):
self.run_function = run_function
if distribute_saved_activations:
raise RuntimeError(
"CheckpointFunctionWithoutOutput does not support "
"distribute_saved_activations"
)
# Copy the rng states.
self.fwd_cpu_rng_state = torch.get_rng_state()
self.fwd_cuda_rng_state = torch.cuda.get_rng_state()
self.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
outputs = CheckpointFunctionWithoutOutput.apply(run_function, self, *args)
self.outputs = outputs
if isinstance(self.outputs, torch.Tensor):
self.outputs = (self.outputs,)
return outputs
def discard_output(self):
for output in self.outputs:
output.untyped_storage().resize_(0)
def recompute(self, _):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
# Store the current states.
cur_cpu_rng_state = torch.get_rng_state()
cur_cuda_rng_state = torch.cuda.get_rng_state()
cur_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(self.fwd_cpu_rng_state)
_set_cuda_rng_state(self.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(self.fwd_cuda_rng_state_tracker)
with torch.enable_grad():
outputs = self.run_function(*self.ctx.saved_tensors)
self.run_function = None
self.fwd_cpu_rng_state = None
self.fwd_cuda_rng_state = None
self.fwd_cuda_rng_state_tracker = None
# Set the states back to what it was at the start of this function.
torch.set_rng_state(cur_cpu_rng_state)
_set_cuda_rng_state(cur_cuda_rng_state)
get_cuda_rng_tracker().set_states(cur_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
for output, recomputation_output in zip(self.outputs, outputs):
output_size = recomputation_output.untyped_storage().size()
output.untyped_storage().resize_(output_size)
with torch.no_grad():
output.untyped_storage().copy_(recomputation_output.untyped_storage())
self.ctx.outputs = outputs
self.outputs = None
self.ctx = None
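# Hedged usage sketch of CheckpointWithoutOutput, mirroring how the MTP
# predictor below uses it for enorm/hnorm. Assumes a CUDA device (the class
# saves CUDA RNG state); module and tensor names here are illustrative only.
import torch

hidden = torch.randn(4, 2, 8, device='cuda', requires_grad=True)
norm = torch.nn.LayerNorm(8).cuda()
proj = torch.nn.Linear(8, 8).cuda()

norm_ckpt = CheckpointWithoutOutput()
normed = norm_ckpt.checkpoint(norm, False, hidden)  # run norm without saving activations for backward
out = proj(normed)                                  # consume the activation while it is still materialized
norm_ckpt.discard_output()                          # free normed's storage immediately
out.register_hook(norm_ckpt.recompute)              # recompute normed when backward reaches `out`
out.sum().backward()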
import warnings
from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer import ModuleSpec
from .multi_token_predictor import (
MultiTokenPredicationSubmodules,
MultiTokenPredictor
)
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TENorm
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_mtp_spec(transformer_layer, use_te=False):
"""
Multi Token Prediction Layer Specification.
"""
use_te = use_te & HAVE_TE
mtp_spec = ModuleSpec(
module=MultiTokenPredictor,
submodules=MultiTokenPredicationSubmodules(
embedding=None,
enorm=TENorm if use_te else LNImpl,
hnorm=TENorm if use_te else LNImpl,
eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
transformer_layer=transformer_layer,
final_layernorm=TENorm if use_te else LNImpl,
output_layer=None,
)
)
return mtp_spec
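# Hedged usage sketch, assuming the standard TE GPT layer spec builder that is
# referenced elsewhere in this patch:
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec

layer_spec = get_gpt_layer_with_transformer_engine_spec(num_experts=None)  # reuse the decoder layer spec
mtp_spec = get_mtp_spec(layer_spec, use_te=True)
# mtp_spec.module is MultiTokenPredictor; mtp_spec.submodules selects TENorm /
# TEColumnParallelLinear (or the non-TE fallbacks) for enorm, hnorm, eh_proj and final_layernorm.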
import os
import logging
from dataclasses import dataclass
from typing import Union, Optional, Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel, InferenceParams
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
from ...tensor_parallel.random import CheckpointWithoutOutput
from ...tensor_parallel import FluxColumnParallelLinear
@dataclass
class MultiTokenPredicationSubmodules:
embedding: Union[ModuleSpec, type] = None
output_layer: Union[ModuleSpec, type] = None
eh_proj: Union[ModuleSpec, type] = None
enorm: Union[ModuleSpec, type] = None
hnorm: Union[ModuleSpec, type] = None
transformer_layer: Union[ModuleSpec, type] = None
final_layernorm: Union[ModuleSpec, type] = None
class MultiTokenPredictor(MegatronModule):
def __init__(
self,
config: TransformerConfig,
submodules: MultiTokenPredicationSubmodules,
vocab_size: int,
max_sequence_length: int,
layer_number: int = 1,
hidden_dropout: float = None,
pre_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
share_mtp_embedding_and_output_weight=True,
recompute_mtp_norm=False,
recompute_mtp_layer=False,
add_output_layer_bias=False
):
super().__init__(config=config)
self.config = config
self.submodules = submodules
self.layer_number = layer_number
self.hidden_dropout = hidden_dropout
self.hidden_size = self.config.hidden_size
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.position_embedding_type = position_embedding_type
# share with main model
self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
self.recompute_layer_norm = recompute_mtp_norm
self.recompute_mtp_layer = recompute_mtp_layer
self.add_output_layer_bias = add_output_layer_bias
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=self.position_embedding_type,
skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
use_cpu_initialization=self.config.use_cpu_initialization,
)
self.enorm = build_module(
self.submodules.enorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.hnorm = build_module(
self.submodules.hnorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.eh_proj = build_module(
self.submodules.eh_proj,
self.hidden_size + self.hidden_size,
self.hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='eh',
)
self.transformer_layer = build_module(
self.submodules.transformer_layer,
config=self.config,
)
if self.submodules.final_layernorm:
self.final_layernorm = build_module(
self.submodules.final_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = None
if self.config.defer_embedding_wgrad_compute:
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
column_parallel_linear_impl = FluxColumnParallelLinear
else:
column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = column_parallel_linear_impl(
self.config.hidden_size,
self.vocab_size,
config=self.config,
init_method=self.config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
def forward(
self,
hidden_input_ids: Tensor,
embed_input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
embeding_weight: Optional[torch.Tensor] = None,
output_weight: Optional[torch.Tensor] = None,
):
"""Forward function of the MTP module"""
# Decoder embedding.
decoder_input = self.embedding(
input_ids=embed_input_ids,
position_ids=position_ids,
weight=embeding_weight,
)
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
rotary_seq_len = decoder_input.size(0)
if self.config.sequence_parallel:
rotary_seq_len *= self.config.tensor_model_parallel_size
rotary_seq_len *= self.config.context_parallel_size
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
if self.recompute_layer_norm:
self.enorm_ckpt = CheckpointWithoutOutput()
enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
self.hnorm_ckpt = CheckpointWithoutOutput()
hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids)
else:
enorm_output = self.enorm(decoder_input)
hnorm_output = self.hnorm(hidden_input_ids)
# [s, b, h] -> [s, b, 2h]
hidden_states = torch.concat(
[hnorm_output,
enorm_output],
dim=-1
)
if self.recompute_layer_norm:
self.enorm_ckpt.discard_output()
self.hnorm_ckpt.discard_output()
hidden_states.register_hook(self.enorm_ckpt.recompute)
hidden_states.register_hook(self.hnorm_ckpt.recompute)
# hidden_states -> [s, b, h]
hidden_states, _ = self.eh_proj(hidden_states)
if self.config.tensor_model_parallel_size > 1:
hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states)
if self.config.sequence_parallel:
hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states)
if self.recompute_mtp_layer:
hidden_states, context = tensor_parallel.checkpoint(
self.transformer_layer,
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
None,
None,
rotary_pos_emb,
inference_params,
packed_seq_params,
)
else:
hidden_states, _ = self.transformer_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
# Final layer norm.
if self.final_layernorm is not None:
if self.recompute_layer_norm:
self.finalnorm_ckpt = CheckpointWithoutOutput()
finalnorm_output = self.finalnorm_ckpt.checkpoint(self.final_layernorm, False, hidden_states)
else:
finalnorm_output = self.final_layernorm(hidden_states)
else:
finalnorm_output = hidden_states
logits, _ = self.output_layer(finalnorm_output, weight=output_weight)
if self.recompute_layer_norm:
self.finalnorm_ckpt.discard_output()
logits.register_hook(self.finalnorm_ckpt.recompute)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return hidden_states, loss
def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
"""Computes the language model loss (Cross entropy across vocabulary)
Args:
labels (Tensor): The labels of dimension [batch size, seq length]
logits (Tensor): The final logits returned by the output layer of the transformer model
Returns:
Tensor: Loss tensor of dimensions [batch size, sequence_length]
"""
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
if self.config.cross_entropy_loss_fusion:
loss = fused_vocab_parallel_cross_entropy(logits, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
\ No newline at end of file
......@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
# MTP requires separate layernorms for the main model and the MTP modules, so the final norm is moved out of the block
config = args[0] if len(args) > 1 else kwargs['config']
if getattr(config, "num_nextn_predict_layers", 0) > 0:
if getattr(config, "mtp_num_layers", 0) > 0:
self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None
......
from typing import Optional
from functools import wraps
from dataclasses import dataclass
from megatron.training import get_args
from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
@dataclass
class ExtraTransformerConfig:
##################
# multi-token prediction
##################
num_nextn_predict_layers: int = 0
"""The number of multi-token prediction layers"""
def transformer_config_post_init_wrapper(fn):
@wraps(fn)
def wrapper(self):
fn(self)
args = get_args()
mtp_loss_scale: float = 0.3
"""Multi-token prediction loss scale"""
"""Number of Multi-Token Prediction (MTP) Layers."""
self.mtp_num_layers = args.mtp_num_layers
recompute_mtp_norm: bool = False
"""Whether to recompute mtp normalization"""
"""Weighting factor of Multi-Token Prediction (MTP) loss."""
self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor
recompute_mtp_layer: bool = False
"""Whether to recompute mtp layer"""
##################
# flux
##################
self.flux_transpose_weight = args.flux_transpose_weight
share_mtp_embedding_and_output_weight: bool = False
"""share embedding and output weight with mtp layer."""
return wrapper
@dataclass
class ExtraTransformerConfig:
##################
# flux
##################
......
......@@ -30,33 +30,3 @@ def is_flux_min_version(version, check_equality=True):
if check_equality:
return get_flux_version() >= PkgVersion(version)
return get_flux_version() > PkgVersion(version)
def tensor_slide(
tensor: Optional[torch.Tensor],
num_slice: int,
dims: Union[int, List[int]] = -1,
step: int = 1,
return_first=False,
) -> List[Union[torch.Tensor, None]]:
"""通用滑动窗口函数,支持任意维度"""
if tensor is None:
# return `List[None]` to avoid NoneType Error
return [None] * (num_slice + 1)
if num_slice == 0:
return [tensor]
window_size = tensor.shape[-1] - num_slice
dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)
# Slide contiguously over multiple dimensions
slices = []
for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
slice_obj = [slice(None)] * tensor.dim()
for dim in dims:
slice_obj[dim] = slice(i, i + window_size)
slices.append(tensor[tuple(slice_obj)])
if return_first:
return slices
return slices
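# Worked example (toy tensor) of tensor_slide as used by the MTP input slicing:
import torch

tokens = torch.arange(6).unsqueeze(0)        # shape [1, 6]: [[0, 1, 2, 3, 4, 5]]
windows = tensor_slide(tokens, num_slice=2)  # window size = 6 - 2 = 4
# windows -> [ [[0,1,2,3]], [[1,2,3,4]], [[2,3,4,5]] ]:
# the main model's window plus one shifted window per additional predicted token.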
......@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):
def _add_mtp_args(parser):
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Multi-Token prediction layer num')
group.add_argument('--mtp-loss-scale', type=float, default=0.3, help='Multi-Token prediction loss scale')
group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
help='Multi-Token prediction recompute norm')
group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
help='Multi-Token prediction recompute layer')
group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
help='Main model share embedding and output weight with mtp layer.')
group.add_argument('--mtp-num-layers', type=int, default=None,
help='Number of Multi-Token Prediction (MTP) Layers.'
'MTP extends the prediction scope to multiple future tokens at each position.'
'This MTP implementation sequentially predict additional tokens '
'by using D sequential modules to predict D additional tokens.')
group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
help='Scaling factor of Multi-Token Prediction (MTP) loss. '
'We compute the average of the MTP losses across all depths, '
'and multiply it by the scaling factor to obtain the overall MTP loss, '
'which serves as an additional training objective.')
return parser
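# Numeric sketch of the scaling described in the help text above
# (the per-depth losses are hypothetical):
mtp_losses = [2.10, 2.25, 2.40]                # one loss per MTP depth
mtp_loss_scaling_factor = 0.3
overall_mtp_loss = mtp_loss_scaling_factor * sum(mtp_losses) / len(mtp_losses)
# overall_mtp_loss ≈ 0.3 * 2.25 = 0.675, added to the main LM loss as an extra objective.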
......
......@@ -7,6 +7,72 @@ from megatron.training import get_args
from megatron.core import mpu
def _compile_dependencies():
args = get_args()
# =========================
# Compile dataset C++ code.
# =========================
# TODO: move this to ninja
if torch.distributed.get_rank() == 0:
start_time = time.time()
print("> compiling dataset index builder ...")
from megatron.core.datasets.utils import compile_helpers
compile_helpers()
print(
">>> done with dataset index builder. Compilation time: {:.3f} "
"seconds".format(time.time() - start_time),
flush=True,
)
# ==================
# Load fused kernels
# ==================
# Custom kernel constraints check.
seq_len = args.seq_length
attn_batch_size = (
args.num_attention_heads / args.tensor_model_parallel_size
) * args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = (
seq_len > 16 and seq_len <= 16384 and seq_len % 4 == 0 and attn_batch_size % 4 == 0
)
# Print a warning.
if not ((args.fp16 or args.bf16) and custom_kernel_constraint and args.masked_softmax_fusion):
if args.rank == 0:
print(
"WARNING: constraints for invoking optimized"
" fused softmax kernel are not met. We default"
" back to unfused kernel invocations.",
flush=True,
)
# Always build on rank zero first.
if torch.distributed.get_rank() == 0:
start_time = time.time()
print("> compiling and loading fused kernels ...", flush=True)
#fused_kernels.load(args)
torch.distributed.barrier()
else:
torch.distributed.barrier()
#fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that
# the lock is released.
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(
">>> done with compiling and loading fused kernels. "
"Compilation time: {:.3f} seconds".format(time.time() - start_time),
flush=True,
)
def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
"""Initialize torch.distributed and core model parallel."""
args = get_args()
......@@ -16,8 +82,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
if args.rank == 0:
print(
"torch distributed is already initialized, "
"skipping initialization ...",
"torch distributed is already initialized, " "skipping initialization ...",
flush=True,
)
args.rank = torch.distributed.get_rank()
......@@ -34,6 +99,10 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
else:
device_id = None
# Set to non-default stream for cudagraph capturing.
if args.external_cuda_graph:
torch.cuda.set_stream(torch.cuda.Stream())
# Call the init process
init_process_group_kwargs = {
'backend' : args.distributed_backend,
......@@ -56,6 +125,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank,
pipeline_model_parallel_comm_backend=args.pipeline_model_parallel_comm_backend,
context_parallel_size=args.context_parallel_size,
hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
expert_model_parallel_size=args.expert_model_parallel_size,
......@@ -68,6 +138,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size,
get_embedding_ranks=get_embedding_ranks,
get_position_embedding_ranks=get_position_embedding_ranks,
create_gloo_process_groups=args.enable_gloo_process_groups,
)
if args.rank == 0:
print(
......@@ -78,76 +149,3 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
f"> initialized pipeline model parallel with size "
f"{mpu.get_pipeline_model_parallel_world_size()}"
)
def _compile_dependencies():
args = get_args()
# =========================
# Compile dataset C++ code.
# =========================
# TODO: move this to ninja
if torch.distributed.get_rank() == 0:
start_time = time.time()
print("> compiling dataset index builder ...")
from megatron.core.datasets.utils import compile_helpers
compile_helpers()
print(
">>> done with dataset index builder. Compilation time: {:.3f} "
"seconds".format(time.time() - start_time),
flush=True,
)
# ==================
# Load fused kernels
# ==================
# Custom kernel constraints check.
seq_len = args.seq_length
attn_batch_size = (
args.num_attention_heads / args.tensor_model_parallel_size
) * args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = (
seq_len > 16
and seq_len <= 16384
and seq_len % 4 == 0
and attn_batch_size % 4 == 0
)
# Print a warning.
if not (
(args.fp16 or args.bf16)
and custom_kernel_constraint
and args.masked_softmax_fusion
):
if args.rank == 0:
print(
"WARNING: constraints for invoking optimized"
" fused softmax kernel are not met. We default"
" back to unfused kernel invocations.",
flush=True,
)
# Always build on rank zero first.
if torch.distributed.get_rank() == 0:
start_time = time.time()
print("> compiling and loading fused kernels ...", flush=True)
#fused_kernels.load(args)
torch.distributed.barrier()
else:
torch.distributed.barrier()
#fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that
# the lock is released.
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(
">>> done with compiling and loading fused kernels. "
"Compilation time: {:.3f} seconds".format(time.time() - start_time),
flush=True,
)
......@@ -50,14 +50,34 @@ from megatron.training.training import (
stimer = StragglerDetector()
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
def train(
forward_step_func,
model,
optimizer,
opt_param_scheduler,
train_data_iterator,
valid_data_iterator,
process_non_loss_data_func,
config,
checkpointing_context,
non_loss_data_func,
):
"""Training function: run train_step desired number of times, run validation, checkpoint."""
args = get_args()
timers = get_timers()
one_logger = get_one_logger()
if args.run_workload_inspector_server:
try:
from workload_inspector.utils.webserver import run_server
import threading
threading.Thread(
target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
).start()
except ModuleNotFoundError:
print_rank_0("workload inspector module not found.")
# Write args to tensorboard
write_args_to_tensorboard()
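Side note (hedged): the block above is the usual optional-dependency pattern, so a missing workload_inspector install only prints a message instead of aborting training, and the server runs on a daemon thread so it cannot block shutdown. A generic sketch of the same pattern with a placeholder package name:

import threading

def start_optional_server(rank: int) -> None:
    """Start a monitoring web server in the background if its package is installed."""
    try:
        from some_inspector.webserver import run_server  # placeholder for any optional package
    except ModuleNotFoundError:
        print("optional inspector not installed; continuing without it")
        return
    # daemon=True: the thread will not keep the process alive at exit.
    threading.Thread(target=run_server, daemon=True, args=(rank,)).start()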
......@@ -70,23 +90,35 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Iterations.
iteration = args.iteration
# Make sure rerun_state_machine has the right iteration loaded from checkpoint.
rerun_state_machine = get_rerun_state_machine()
if rerun_state_machine.current_iteration != iteration:
print_rank_0(f"Setting rerun_state_machine.current_iteration to {iteration}...")
rerun_state_machine.current_iteration = iteration
# Track E2E metrics at the start of training.
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples, seq_length=args.seq_length,
train_iters=args.train_iters, save=args.save, async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)
one_logger_utils.on_train_start(
iteration=iteration,
consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples,
seq_length=args.seq_length,
train_iters=args.train_iters,
save=args.save,
async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
)
num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
# Setup some training config params.
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], DDP) and args.overlap_grad_reduce:
assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
assert config.no_sync_func is None, (
'When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce'
)
config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
if len(model) == 1:
config.no_sync_func = config.no_sync_func[0]
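Background (hedged analogy, not megatron's code): config.no_sync_func collects each model chunk's no_sync context manager so the schedule can defer the gradient all-reduce until the last microbatch, which is what --overlap-grad-reduce relies on. Plain PyTorch DDP exposes the same idea; hypothetical illustration with made-up names (ddp_model, microbatches, loss_fn):

from contextlib import ExitStack

def run_microbatches(ddp_model, microbatches, loss_fn):
    # Defer gradient synchronization for every microbatch except the last.
    for i, batch in enumerate(microbatches):
        is_last = (i == len(microbatches) - 1)
        with ExitStack() as stack:
            if not is_last:
                stack.enter_context(ddp_model.no_sync())  # grads accumulate locally
            loss_fn(ddp_model(batch)).backward()          # all-reduce fires only on the last call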
......@@ -110,8 +142,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if args.manual_gc:
# Disable the default garbage collector and perform the collection manually.
# This is to align the timing of garbage collection across ranks.
assert args.manual_gc_interval >= 0, \
'Manual garbage collection interval should be larger than or equal to 0'
assert (
args.manual_gc_interval >= 0
), 'Manual garbage collection interval should be larger than or equal to 0'
gc.disable()
gc.collect()
......@@ -121,10 +154,13 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
world = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
mmcnt = args.straggler_minmax_count
stimer.configure(world, rank,
mmcnt = mmcnt,
enabled = not args.disable_straggler_on_startup,
port = args.straggler_ctrlr_port)
stimer.configure(
world,
rank,
mmcnt=mmcnt,
enabled=not args.disable_straggler_on_startup,
port=args.straggler_ctrlr_port,
)
num_floating_point_operations_since_last_log_event = 0.0
num_microbatches = get_num_microbatches()
......@@ -132,10 +168,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
eval_iterations = 0
def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics.
"""
num_floating_point_operations_since_current_train_start = \
"""Get base metrics values for one-logger to calculate E2E tracking metrics."""
num_floating_point_operations_since_current_train_start = (
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
)
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
......@@ -145,7 +181,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
'seq_length': args.seq_length
'seq_length': args.seq_length,
}
# Cache into one-logger for callback.
if one_logger:
......@@ -153,7 +189,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)
prof = None
if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
and args.use_pytorch_profiler
):
def trace_handler(p):
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
......@@ -178,9 +218,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end-args.profile_step_start,
repeat=1),
on_trace_ready=trace_handler,
record_shapes=True,
#on_trace_ready=torch.profiler.tensorboard_trace_handler('./torch_prof_data'))
on_trace_ready=trace_handler)
)
prof.start()
elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
import ctypes
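For reference (standalone miniature with made-up step numbers, not the patch's exact wiring): the profiler above is driven by torch.profiler.schedule, so tracing only covers the [profile_step_start, profile_step_end) window and trace_handler is invoked once the active phase ends:

import torch
from torch.profiler import ProfilerActivity, profile, schedule

profile_step_start, profile_step_end = 3, 5   # hypothetical values

def trace_handler(p):
    # The patch writes traces into args.profile_dir; here we just print a summary.
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(
        wait=max(profile_step_start - 1, 0),
        warmup=1 if profile_step_start > 0 else 0,
        active=profile_step_end - profile_step_start,
        repeat=1,
    ),
    on_trace_ready=trace_handler,
    record_shapes=True,
) as prof:
    for _ in range(profile_step_end + 1):
        torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in for one training step
        prof.step()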
......@@ -190,7 +230,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Disable forward pre-hook to start training to ensure that errors in checkpoint loading
# or random initialization don't propagate to all ranks in first all-gather (which is a
# no-op if things work correctly).
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model, param_sync=False)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
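Hedged note: these hunks replace the literal `args.use_distributed_optimizer and args.overlap_param_gather` check with a should_disable_forward_pre_hook(args) helper. Its body is not shown in this diff; a plausible reading, spelled out purely as an assumption, is:

def should_disable_forward_pre_hook(args) -> bool:
    # Assumed behaviour, inferred only from the call sites in this diff: the forward
    # pre-hook (parameter all-gather) exists only when the distributed optimizer
    # overlaps parameter gathering with forward compute.
    return bool(args.use_distributed_optimizer and args.overlap_param_gather)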
......@@ -199,8 +239,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
assert check_param_hashes_across_dp_replicas(
model, cross_check=True
), "Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
......@@ -226,33 +267,60 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# to make sure training configuration is still valid.
update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}")
assert get_num_microbatches() > num_microbatches, (
f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}"
)
if args.save is not None:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
# Completely skip iteration if needed.
if iteration in args.iterations_to_skip:
# Dummy train_step to fast forward train_data_iterator.
dummy_train_step(train_data_iterator)
iteration += 1
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
args.consumed_train_samples += batch_size
args.skipped_train_samples += batch_size
continue
# Run training step.
args.curr_iteration = iteration
ft_integration.on_training_step_start()
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
(
loss_dict,
skipped_iter,
should_checkpoint,
should_exit,
exit_code,
grad_norm,
num_zeros_in_grad,
) = train_step(
forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config
)
ft_integration.on_training_step_end()
if should_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
if should_exit:
break
......@@ -269,18 +337,19 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
# `forward_backward_func`.
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
config.param_sync_func = param_sync_func
pre_hook_enabled = True
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size())
num_skipped_samples_in_batch = (
get_current_global_batch_size() - get_current_running_global_batch_size()
)
if args.decrease_batch_size_if_needed:
assert num_skipped_samples_in_batch >= 0
else:
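Worked example with made-up sizes: the bookkeeping above multiplies data-parallel world size, micro-batch size, and the current microbatch count, and counts any gap between the nominal and the actually running global batch size as skipped samples.

# Hypothetical values for illustration only.
data_parallel_world_size = 8
micro_batch_size = 2
num_microbatches = 16

batch_size = data_parallel_world_size * micro_batch_size * num_microbatches  # 256 samples per iteration

current_global_batch_size = 256          # what the schedule asked for
current_running_global_batch_size = 240  # what actually ran (e.g. batch size was decreased)
num_skipped_samples_in_batch = current_global_batch_size - current_running_global_batch_size  # 16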
......@@ -306,18 +375,24 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
decoupled_learning_rate = param_group['lr']
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
report_memory_flag = training_log(
loss_dict,
total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration,
loss_scale,
report_memory_flag,
skipped_iter,
grad_norm,
params_norm,
num_zeros_in_grad,
)
# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
timers('interval-time').stop()
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
pre_hook_enabled = False
if args.manual_gc and args.manual_gc_eval:
......@@ -325,11 +400,18 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
evaluate_and_print_results(
prefix,
forward_step_func,
valid_data_iterator,
model,
iteration,
process_non_loss_data_func,
config,
verbose=False,
write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func,
)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
......@@ -338,20 +420,32 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if args.manual_gc and args.manual_gc_eval:
# Collect only the objects created and used in evaluation.
gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
pre_hook_enabled = True
timers('interval-time', log_level=0).start(barrier=True)
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event)
post_training_step_callbacks(
model,
optimizer,
opt_param_scheduler,
iteration,
prof,
num_floating_point_operations_since_last_log_event,
)
# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
should_exit = checkpoint_and_decide_exit(
model,
optimizer,
opt_param_scheduler,
iteration,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator,
)
if should_exit:
break
......@@ -367,8 +461,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
disable_forward_pre_hook(model)
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=True)
# This will finalize all unfinalized async requests and terminate
# a persistent async worker if persistent ckpt worker is enabled
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client().shutdown_workload_monitoring()
# If any exit conditions (signal handler, duration, iterations) have been reached, exit.
if should_exit:
......@@ -376,6 +474,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
one_logger_utils.finish()
sys.exit(exit_code)
return iteration, num_floating_point_operations_so_far
import torch
from megatron.core import mpu
from megatron.training import get_args
def get_batch_on_this_tp_rank(data_iterator):
args = get_args()
def _broadcast(item):
if item is not None:
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
if mpu.get_tensor_model_parallel_rank() == 0:
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
batch = {
'tokens': data["tokens"].cuda(non_blocking = True),
'labels': data["labels"].cuda(non_blocking = True),
'loss_mask': data["loss_mask"].cuda(non_blocking = True),
'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
'position_ids': data["position_ids"].cuda(non_blocking = True)
}
if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_last_stage():
if args.num_nextn_predict_layers:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(batch['position_ids'])
else:
tokens=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
loss_mask=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.float32,
device = torch.cuda.current_device())
if args.create_attention_mask_in_dataloader:
attention_mask=torch.empty(
(args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
args.seq_length + args.num_nextn_predict_layers), dtype = torch.bool,
device = torch.cuda.current_device()
)
else:
attention_mask=None
position_ids=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_first_stage():
labels=None
loss_mask=None
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_last_stage():
if args.num_nextn_predict_layers:
_broadcast(tokens)
else:
tokens = None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(position_ids)
else:
position_ids = None
batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids
}
return batch
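Shape note with hypothetical numbers: only tensor-parallel rank 0 reads from the data iterator; the other TP ranks allocate receive buffers whose sequence dimension is stretched by num_nextn_predict_layers (the extra MTP target tokens) and then join the broadcasts. For example:

# Hypothetical configuration.
micro_batch_size = 1
seq_length = 4096
num_nextn_predict_layers = 1   # one extra next-token-prediction (MTP) target

# Buffers allocated on non-source TP ranks in the function above:
tokens_shape = (micro_batch_size, seq_length + num_nextn_predict_layers)
attention_mask_shape = (
    micro_batch_size, 1,
    seq_length + num_nextn_predict_layers,
    seq_length + num_nextn_predict_layers,
)
print(tokens_shape, attention_mask_shape)  # (1, 4097) (1, 1, 4097, 4097)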
for para in $*
do
if [[ $para == --profiling* ]];then
profiling=${para#*=}
export GPU_FLUSH_ON_EXECUTION=1
export HIP_DIRECT_DISPATCH=0
fi
done
mpirun -np 8 --allow-run-as-root \
train_deepseek_v3_1node.sh localhost --profiling=$profiling > output.log 2>&1
wait
rm -rf CKPT