Commit f3ef5e1b authored by dongcl

patch for megatron commit 0595ef2b0c93f8d61f473c9f99f9ff73803ff919

parent bb6ab0fb
......@@ -49,6 +49,9 @@ class PipelineFeature(AbstractFeature):
_allreduce_embedding_grads_wrapper
)
from dcu_megatron.training.training import evaluate
from dcu_megatron.core.transformer.transformer_layer import get_transformer_layer_offset
from dcu_megatron.training.utils import get_batch_on_this_tp_rank
from dcu_megatron.training.training import build_train_valid_test_data_iterators_wrapper
patch_manager.register_patch(
'megatron.training.training.get_model', get_model)
......@@ -69,6 +72,20 @@ class PipelineFeature(AbstractFeature):
patch_manager.register_patch(
'megatron.training.training.evaluate', evaluate)
patch_manager.register_patch(
'megatron.core.transformer.transformer_layer.get_transformer_layer_offset', get_transformer_layer_offset)
# support dualpipev, two data iterators
patch_manager.register_patch(
'megatron.training.training.build_train_valid_test_data_iterators',
build_train_valid_test_data_iterators_wrapper,
apply_wrapper=True)
# support dualpipev, broadcast loss_mask and labels
patch_manager.register_patch(
'megatron.training.utils.get_batch_on_this_tp_rank',
get_batch_on_this_tp_rank)
if args.combined_1f1b:
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
......@@ -86,10 +103,10 @@ class PipelineFeature(AbstractFeature):
from dcu_megatron.core.transformer.moe.moe_layer import MoELayer
patch_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
MoEAlltoAllTokenDispatcher)
MoEAlltoAllTokenDispatcher)
patch_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
TransformerLayer)
TransformerLayer)
patch_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
GPTModel.build_schedule_plan,
......
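Note: the hunks above route every override through dcu_megatron's patch_manager.register_patch. As a rough mental model only (the project's real implementation is not shown here, and the apply_wrapper semantics are an assumption), such a registry resolves the dotted target path, swaps in the replacement, and, when apply_wrapper=True, treats the replacement as a wrapper applied to the original callable:

    import importlib

    class PatchManagerSketch:
        """Minimal sketch of a monkey-patch registry; names and behavior are assumptions."""

        def __init__(self):
            self._patches = []

        def register_patch(self, target, replacement, apply_wrapper=False):
            # Record the patch; application is usually deferred until all patches are known.
            self._patches.append((target, replacement, apply_wrapper))

        def apply(self):
            for target, replacement, apply_wrapper in self._patches:
                module_path, attr = target.rsplit('.', 1)
                module = importlib.import_module(module_path)
                original = getattr(module, attr)
                # apply_wrapper=True: the registered object wraps the original function,
                # e.g. build_train_valid_test_data_iterators_wrapper(original_fn).
                new_obj = replacement(original) if apply_wrapper else replacement
                setattr(module, attr, new_obj)

With that reading, register_patch('megatron.training.training.evaluate', evaluate) is a straight replacement, while the build_train_valid_test_data_iterators registration passes the upstream function into the wrapper before installing it.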
......@@ -163,7 +163,6 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
from ..core.transformer.transformer_layer import get_transformer_layer_offset
from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
# Transformer block. If mtp_num_layers > 0, move final_layernorm outside
......@@ -190,10 +189,6 @@ class CoreAdaptation(MegatronAdaptationABC):
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
# support dualpipev
MegatronAdaptation.register('megatron.core.transformer.transformer_layer.get_transformer_layer_offset',
get_transformer_layer_offset)
def patch_core_extentions(self):
import transformer_engine as te
......@@ -257,10 +252,10 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
from ..training.initialize import _compile_dependencies
from ..training.training import train, build_train_valid_test_data_iterators_wrapper
from ..training.training import train
from ..training.initialize import _set_random_seed
from ..training.utils import get_batch_on_this_tp_rank
# add Llama3Tokenizer, QwenTokenizer, DeepSeekV2Tokenizer
MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
build_tokenizer)
# specify init_method
......@@ -278,15 +273,6 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.training.training.train',
train)
# support dualpipev, two data iterators
MegatronAdaptation.register('megatron.training.training.build_train_valid_test_data_iterators',
build_train_valid_test_data_iterators_wrapper,
apply_wrapper=True)
# support dualpipev, broadcast loss_mask and labels
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
get_batch_on_this_tp_rank)
def patch_miscellaneous(self):
from ..training.arguments import parse_args
......
......@@ -160,6 +160,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
k_channels: Optional[int] = None,
v_channels: Optional[int] = None,
cp_comm_type: str = "p2p",
model_comm_pgs: ModelCommProcessGroups = None,
):
self.config = config
self.te_forward_mask_type = False
......@@ -186,6 +187,26 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if model_comm_pgs is None:
# For backward compatibility, remove in v0.14 and raise error
# raise ValueError("TEDotProductAttention was called without ModelCommProcessGroups")
model_comm_pgs = ModelCommProcessGroups(
tp=get_tensor_model_parallel_group(check_initialized=False),
cp=get_context_parallel_group(check_initialized=False),
hcp=get_hierarchical_context_parallel_groups(check_initialized=False),
)
else:
assert hasattr(
model_comm_pgs, 'tp'
), "TEDotProductAttention model_comm_pgs must have tp pg"
assert hasattr(
model_comm_pgs, 'cp'
), "TEDotProductAttention model_comm_pgs must have cp pg"
if cp_comm_type == "a2a+p2p":
assert hasattr(
model_comm_pgs, 'hcp'
), "TEDotProductAttention model_comm_pgs must have hierarchical cp pg"
if is_te_min_version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older version don't need attention_type
......@@ -201,9 +222,9 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
extra_kwargs["cp_group"] = model_comm_pgs.cp
extra_kwargs["cp_global_ranks"] = torch.distributed.get_process_group_ranks(
model_comm_pgs.cp
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
if is_te_min_version("1.10.0"):
......@@ -277,7 +298,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_group=model_comm_pgs.tp,
layer_number=layer_number,
**extra_kwargs,
)
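The cp_global_ranks and tp_group changes above lean on torch.distributed.get_process_group_ranks, which returns the global ranks backing a process group, replacing the parallel_state helpers. A single-process sanity check (gloo backend; the address and port are placeholders):

    import torch.distributed as dist

    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
    group = dist.new_group(ranks=[0])
    print(dist.get_process_group_ranks(group))  # [0]: the global ranks that form this group
    dist.destroy_process_group()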
......@@ -294,7 +315,6 @@ if is_te_min_version("1.9.0.dev0"):
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
num_gemms: int,
......@@ -308,6 +328,7 @@ if is_te_min_version("1.9.0.dev0"):
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
......@@ -329,6 +350,7 @@ if is_te_min_version("1.9.0.dev0"):
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
def backward_dw(self):
......
......@@ -288,7 +288,7 @@ class MoeAttnNode(TransformerLayerNode):
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
probs,
permuted_probs,
) = self.layer._submodule_attention_router_compound_forward(
hidden_states,
attention_mask=attention_mask,
......@@ -304,11 +304,10 @@ class MoeAttnNode(TransformerLayerNode):
self.common_state.tokens_per_expert = tokens_per_expert
# detached here
self.common_state.probs = self.detach(probs)
self.common_state.residual = self.detach(hidden_states)
self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
return permutated_local_input_tokens
return permutated_local_input_tokens, permuted_probs
def dw(self):
with torch.cuda.nvtx.range(f"{self.name} wgrad"):
......@@ -317,26 +316,26 @@ class MoeAttnNode(TransformerLayerNode):
class MoeDispatchNode(TransformerLayerNode):
def forward_impl(self, permutated_local_input_tokens):
def forward_impl(self, permutated_local_input_tokens, permuted_probs):
token_dispatcher = self.layer.mlp.token_dispatcher
with token_dispatcher.per_batch_state_context(self.common_state):
tokens_per_expert, global_input_tokens = token_dispatcher.dispatch_all_to_all(
self.common_state.tokens_per_expert, permutated_local_input_tokens
tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
self.common_state.tokens_per_expert, permutated_local_input_tokens, permuted_probs
)
# release tensor not used by backward
# inputs.untyped_storage().resize_(0)
self.common_state.tokens_per_expert = tokens_per_expert
return global_input_tokens
return global_input_tokens, global_probs
class MoeMlPNode(TransformerLayerNode):
def forward_impl(self, global_input_tokens):
def forward_impl(self, global_input_tokens, global_probs):
pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
token_dispatcher = self.layer.mlp.token_dispatcher
with token_dispatcher.per_batch_state_context(self.common_state):
expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
self.common_state.tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output
self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
)
assert mlp_bias is None
......@@ -363,9 +362,7 @@ class MoeCombineNode(TransformerLayerNode):
)
cur_stream = torch.cuda.current_stream()
self.common_state.residual.record_stream(cur_stream)
self.common_state.probs.record_stream(cur_stream)
self.common_state.residual = None
self.common_state.probs = None
return output
......
......@@ -125,8 +125,9 @@ def gpt_model_forward(
and inference_context.is_static_batching()
and not self.training
):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * inference_context.current_batch_size,
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
......@@ -156,12 +157,6 @@ def gpt_model_forward(
**(extra_block_kwargs or {}),
)
# Process inference output.
if inference_context and not inference_context.is_static_batching():
hidden_states = inference_context.last_token_logits(
hidden_states.squeeze(1).unsqueeze(0)
).unsqueeze(1)
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
......@@ -202,10 +197,17 @@ def gpt_model_forward(
if (
not self.training
and inference_context is not None
and inference_context.is_static_batching()
and inference_context.materialize_only_last_token_logits
):
hidden_states = hidden_states[-1:, :, :]
if inference_context.is_static_batching():
hidden_states = hidden_states[-1:, :, :]
else:
# Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden
# state ([B, H]) → unsqueeze back to [1, B, H]
# (so that the output layer, which expects S×B×H, receives only the final token)
hidden_states = inference_context.last_token_logits(
hidden_states.squeeze(1).unsqueeze(0)
).unsqueeze(1)
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
......
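In the dynamic-batching branch added above, the squeeze/unsqueeze pair only rearranges axes so the output layer sees a sequence-major tensor. A shape-only illustration with dummy sizes (B=4 and H=8 are arbitrary):

    import torch

    B, H = 4, 8
    hidden_states = torch.randn(B, 1, H)            # one decode token per request
    as_sbh = hidden_states.squeeze(1).unsqueeze(0)  # [B, 1, H] -> [B, H] -> [1, B, H]
    print(as_sbh.shape)                             # torch.Size([1, 4, 8])
    # last_token_logits(...) then picks each request's final-token hidden state,
    # and .unsqueeze(1) restores the singleton sequence dimension expected downstream.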
......@@ -12,8 +12,11 @@ from megatron.core.transformer.moe.moe_utils import (
permute,
sort_chunks_by_idxs,
unpermute,
pad_routing_map,
)
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
from megatron.core.fp8_utils import get_fp8_align_size
from megatron.core.fusions.fused_pad_routing_map import fused_pad_routing_map
from dcu_megatron.core.tensor_parallel import all_to_all
......@@ -101,6 +104,12 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
if self.config.moe_router_padding_for_fp8:
pad_multiple = get_fp8_align_size(self.config.fp8_recipe)
if experimental_config.ENABLE_EXPERIMENTAL and self.config.moe_permute_fusion:
self.routing_map = fused_pad_routing_map(self.routing_map, pad_multiple)
else:
self.routing_map = pad_routing_map(self.routing_map, pad_multiple)
tokens_per_expert = self.preprocess(self.routing_map)
return tokens_per_expert
......@@ -117,18 +126,20 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
self.hidden_shape_before_permute = hidden_states.shape
(
permutated_local_input_tokens,
permuted_probs,
self.reversed_local_input_permutation_mapping,
) = permute(
hidden_states,
routing_map,
self.probs,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
)
return tokens_per_expert, permutated_local_input_tokens
return tokens_per_expert, permutated_local_input_tokens, permuted_probs
def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens):
def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
# Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert
......@@ -136,10 +147,13 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
global_input_tokens = all_to_all(
self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
)
global_probs = all_to_all(
self.ep_group, permuted_probs, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
)
return tokens_per_expert, global_input_tokens
return tokens_per_expert, global_input_tokens, global_probs
def dispatch_postprocess(self, tokens_per_expert, global_input_tokens):
def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
if self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
......@@ -152,6 +166,9 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
global_input_tokens = gather_from_sequence_parallel_region(
global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
)
global_probs = gather_from_sequence_parallel_region(
global_probs, group=self.tp_group, output_split_sizes=output_split_sizes
)
# Permutation 2: Sort tokens by local expert.
tokens_per_expert = self._maybe_dtoh_and_synchronize(
......@@ -170,16 +187,28 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
.contiguous()
.flatten(start_dim=0, end_dim=2)
)
global_probs = (
global_probs.view(
self.tp_size * self.ep_size,
self.num_local_experts,
self.capacity,
*global_probs.size()[1:],
)
.transpose(0, 1)
.contiguous()
.flatten(start_dim=0, end_dim=2)
)
else:
global_input_tokens = sort_chunks_by_idxs(
global_input_tokens, global_probs = sort_chunks_by_idxs(
global_input_tokens,
self.num_global_tokens_per_local_expert.ravel(),
self.sort_input_by_local_experts,
probs=global_probs,
fused=self.config.moe_permute_fusion,
)
tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return global_input_tokens, tokens_per_expert
return global_input_tokens, tokens_per_expert, global_probs
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
......@@ -207,15 +236,15 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
# Preprocess: Get the metadata for communication, permutation and computation operations.
# Permutation 1: input to AlltoAll input
tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
# Perform expert parallel AlltoAll communication
tokens_per_expert, global_input_tokens = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens)
tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
# Permutation 2: Sort tokens by local expert.
global_input_tokens, tokens_per_expert = self.dispatch_postprocess(tokens_per_expert, global_input_tokens)
global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
return global_input_tokens, tokens_per_expert
return global_input_tokens, tokens_per_expert, global_probs
def combine_preprocess(self, hidden_states):
# Unpermutation 2: Unsort tokens by local expert.
......@@ -272,7 +301,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
restore_shape=self.hidden_shape_before_permute,
probs=self.probs,
routing_map=self.routing_map,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
......
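The dispatcher changes above thread the router probabilities through the same permutation and all-to-all path as the tokens, so unpermute() no longer receives self.probs. A toy, single-process illustration (not the fused kernel) of why the probs must follow the token permutation:

    import torch

    tokens = torch.arange(6.0).reshape(3, 2)   # 3 tokens, hidden size 2 (toy values)
    probs = torch.tensor([0.9, 0.5, 0.7])      # router probability per token
    order = torch.tensor([2, 0, 1])            # hypothetical expert-sorted order

    permuted_tokens = tokens[order]
    permuted_probs = probs[order]              # probs travel with their tokens
    # Downstream, the experts consume permuted_probs directly instead of the
    # combine path re-applying self.probs during unpermute().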
......@@ -8,6 +8,8 @@ from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import (
deprecate_inference_params,
make_viewless_tensor,
nvtx_range_pop,
nvtx_range_push,
)
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
......@@ -15,7 +17,7 @@ from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispa
from megatron.core.transformer.transformer_config import TransformerConfig
def get_transformer_layer_offset(config: TransformerConfig):
def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[int] = None):
"""Get the index offset of current pipeline stage, given the level of pipelining."""
args = get_args()
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
......@@ -67,9 +69,10 @@ def get_transformer_layer_offset(config: TransformerConfig):
- num_layers_in_last_pipeline_stage
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
assert (
vp_stage is not None
), "vp_stage must be provided if virtual pipeline model parallel size is set"
# Calculate number of layers in each virtual model chunk
# If the num_layers_in_first_pipeline_stage and
......@@ -100,10 +103,10 @@ def get_transformer_layer_offset(config: TransformerConfig):
# Calculate the layer offset with interleaved uneven pipeline parallelism
if pipeline_rank == 0:
offset = vp_rank * total_virtual_chunks
offset = vp_stage * total_virtual_chunks
else:
offset = (
vp_rank * total_virtual_chunks
vp_stage * total_virtual_chunks
+ num_layers_per_virtual_model_chunk_in_first_pipeline_stage
+ (pipeline_rank - 1)
* (
......@@ -151,20 +154,23 @@ def get_transformer_layer_offset(config: TransformerConfig):
if args.schedule_method == 'dualpipev':
num_layers_per_pipeline_rank = num_layers_per_pipeline_rank // 2
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
assert (
vp_stage is not None
), "vp_stage must be provided if virtual pipeline model parallel size is set"
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
total_virtual_chunks = num_layers // vp_size
offset = vp_rank * total_virtual_chunks + (
offset = vp_stage * total_virtual_chunks + (
pipeline_rank * num_layers_per_virtual_rank
)
# Reduce the offset of embedding layer from the total layer number
if (
config.account_for_embedding_in_pipeline_split
and not parallel_state.is_pipeline_first_stage()
and not parallel_state.is_pipeline_first_stage(
ignore_virtual=False, vp_stage=vp_stage
)
):
offset -= 1
else:
......@@ -176,7 +182,9 @@ def get_transformer_layer_offset(config: TransformerConfig):
# Reduce the offset of embedding layer from the total layer number
if (
config.account_for_embedding_in_pipeline_split
and not parallel_state.is_pipeline_first_stage()
and not parallel_state.is_pipeline_first_stage(
ignore_virtual=False, vp_stage=vp_stage
)
):
offset -= 1
else:
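As a worked example of the offset arithmetic visible in the get_transformer_layer_offset hunks above (dualpipev halves the per-rank layer count, then interleaving splits it across virtual chunks); all sizes below are hypothetical:

    def example_layer_offset(vp_stage, pipeline_rank,
                             num_layers=32, pp_size=4, vp_size=2, dualpipev=True):
        """Mirrors the formula shown above; the numbers are made up for illustration."""
        num_layers_per_pipeline_rank = num_layers // pp_size          # 8
        if dualpipev:
            num_layers_per_pipeline_rank //= 2                        # 4
        num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size  # 2
        total_virtual_chunks = num_layers // vp_size                  # 16
        return vp_stage * total_virtual_chunks + pipeline_rank * num_layers_per_virtual_rank

    # example_layer_offset(vp_stage=1, pipeline_rank=1) == 1 * 16 + 1 * 2 == 18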
......@@ -188,9 +196,9 @@ class TransformerLayer(MegatronCoreTransformerLayer):
def forward(
self,
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
context: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
......@@ -208,9 +216,9 @@ class TransformerLayer(MegatronCoreTransformerLayer):
):
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
......@@ -226,7 +234,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
_,
permuted_probs,
) = self._submodule_attention_router_compound_forward(
hidden_states,
attention_mask,
......@@ -240,14 +248,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
inference_params=inference_params,
)
(tokens_per_expert, global_input_tokens) = self._submodule_dispatch_forward(
(tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
)
(expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
tokens_per_expert,
global_input_tokens,
global_probs,
pre_mlp_layernorm_output
)
......@@ -292,6 +302,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
nvtx_range_push(suffix="self_attention")
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
......@@ -303,6 +314,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
nvtx_range_pop(suffix="self_attention")
if self.recompute_input_layernorm:
# discard the output of the input layernorm and register the recompute
......@@ -313,10 +325,12 @@ class TransformerLayer(MegatronCoreTransformerLayer):
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="self_attn_bda")
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
nvtx_range_pop(suffix="self_attn_bda")
return hidden_states
......@@ -363,7 +377,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
pre_mlp_layernorm_output, probs, routing_map
)
tokens_per_expert, permutated_local_input_tokens = self.mlp.token_dispatcher.dispatch_preprocess(
tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.mlp.token_dispatcher.dispatch_preprocess(
pre_mlp_layernorm_output, routing_map, tokens_per_expert
)
......@@ -372,18 +386,18 @@ class TransformerLayer(MegatronCoreTransformerLayer):
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
probs,
permuted_probs,
]
return tuple(outputs)
def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens):
def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
"""
Dispatches tokens to the appropriate experts based on the router output.
"""
tokens_per_expert, global_input_tokens = self.mlp.token_dispatcher.dispatch_all_to_all(
tokens_per_expert, permutated_local_input_tokens
tokens_per_expert, global_input_tokens, global_probs = self.mlp.token_dispatcher.dispatch_all_to_all(
tokens_per_expert, permutated_local_input_tokens, permuted_probs
)
return [tokens_per_expert, global_input_tokens]
return [tokens_per_expert, global_input_tokens, global_probs]
def _submodule_dense_forward(self, hidden_states):
residual = hidden_states
......@@ -399,18 +413,20 @@ class TransformerLayer(MegatronCoreTransformerLayer):
return output
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output):
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
"""
Performs a forward pass for the MLP submodule, including both expert-based
and optional shared-expert computations.
"""
shared_expert_output = None
(dispatched_input, tokens_per_expert) = (
self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens)
(dispatched_input, tokens_per_expert, permuted_probs) = (
self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
)
expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert)
expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
return expert_output, shared_expert_output, mlp_bias
......@@ -438,10 +454,19 @@ class TransformerLayer(MegatronCoreTransformerLayer):
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="mlp_bda")
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
nvtx_range_pop(suffix="mlp_bda")
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
......
......@@ -6,6 +6,7 @@ from functools import wraps
from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnType
from megatron.core.utils import deprecate_inference_params
from megatron.core.models.common.embeddings import apply_rotary_pos_emb
from megatron.legacy.model.module import MegatronModule
......@@ -86,26 +87,21 @@ def parallel_attention_init_wrapper(fn):
return wrapper
class ParallelAttentionPatch(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None,
rotary_pos_emb=None):
encoder_output=None, inference_context=None,
rotary_pos_emb=None, *, inference_params=None):
# hidden_states: [sq, b, h]
inference_context = deprecate_inference_params(inference_context, inference_params)
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
is_first_step = False
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
if inference_context:
if self.layer_number not in inference_context.key_value_memory_dict:
inf_max_seq_len = inference_context.max_sequence_length
inf_max_batch_size = inference_context.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size,
self.num_query_groups_per_partition)
......@@ -113,12 +109,12 @@ class ParallelAttentionPatch(MegatronModule):
inf_max_seq_len, inf_max_batch_size,
self.num_query_groups_per_partition)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_context.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory)
is_first_step = True
else:
inference_key_memory, inference_value_memory = \
inference_params.key_value_memory_dict[self.layer_number]
inference_context.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
......@@ -188,13 +184,14 @@ class ParallelAttentionPatch(MegatronModule):
else:
rotary_pos_emb = ((rotary_pos_emb,) * 2)
if inference_params:
batch_start = inference_params.batch_size_offset
if inference_context:
batch_start = inference_context.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_start = inference_context.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
assert sequence_end <= inference_key_memory.size(0), ("Current sequence length is "
"longer than expected maximum sequence length! Increase inference_max_seq_length.")
# Copy key and values.
inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = key_layer
......
......@@ -3,6 +3,7 @@ import argparse
from typing import Union
from megatron.training.arguments import add_megatron_arguments
from megatron.core.msc_utils import MultiStorageClientFeature
from dcu_megatron.adaptor.features_manager import ADAPTOR_FEATURES
......@@ -65,6 +66,12 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
# args.rank = int(os.getenv('RANK', '0'))
# args.world_size = int(os.getenv("WORLD_SIZE", '1'))
# Args to disable MSC
if not args.enable_msc:
MultiStorageClientFeature.disable()
assert MultiStorageClientFeature.is_enabled() is False
print('WARNING: The MSC feature is disabled.')
return args
......
......@@ -8,6 +8,7 @@ from datetime import timedelta
from megatron.training import get_args
from megatron.core import mpu, tensor_parallel
from megatron.training import inprocess_restart
def _compile_dependencies():
......@@ -76,7 +77,7 @@ def _compile_dependencies():
)
def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, store):
"""Initialize torch.distributed and core model parallel."""
args = get_args()
......@@ -109,6 +110,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
# Call the init process
init_process_group_kwargs = {
'backend': args.distributed_backend,
'store': store,
'world_size': args.world_size,
'rank': args.rank,
'init_method': args.dist_url,
......@@ -116,6 +118,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
}
torch.distributed.init_process_group(**init_process_group_kwargs)
inprocess_restart.maybe_force_nccl_backend_init(device_id)
# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
......@@ -129,6 +132,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank,
pipeline_model_parallel_comm_backend=args.pipeline_model_parallel_comm_backend,
use_sharp=args.use_sharp,
context_parallel_size=args.context_parallel_size,
hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
expert_model_parallel_size=args.expert_model_parallel_size,
......@@ -142,6 +146,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
get_embedding_ranks=get_embedding_ranks,
get_position_embedding_ranks=get_position_embedding_ranks,
create_gloo_process_groups=args.enable_gloo_process_groups,
high_priority_stream_groups=args.high_priority_stream_groups,
)
if args.rank == 0:
print(
......
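_initialize_distributed now accepts a store and forwards it to init_process_group, letting the caller rendezvous through a pre-created store instead of an init_method URL. A single-process sketch of that pattern (host, port, and timeout are placeholders):

    from datetime import timedelta
    import torch.distributed as dist

    store = dist.TCPStore("127.0.0.1", 29501, world_size=1, is_master=True,
                          timeout=timedelta(seconds=30))
    dist.init_process_group(backend="gloo", store=store, rank=0, world_size=1)
    print(dist.get_rank(), dist.get_world_size())   # 0 1
    dist.destroy_process_group()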
......@@ -19,7 +19,11 @@ def get_batch_on_this_tp_rank(data_iterator):
def _broadcast(item):
if item is not None:
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(
item,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group(),
)
if mpu.get_tensor_model_parallel_rank() == 0:
......@@ -29,11 +33,15 @@ def get_batch_on_this_tp_rank(data_iterator):
data = None
batch = {
'tokens': data["tokens"].cuda(non_blocking = True),
'labels': data["labels"].cuda(non_blocking = True),
'loss_mask': data["loss_mask"].cuda(non_blocking = True),
'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
'position_ids': data["position_ids"].cuda(non_blocking = True)
'tokens': data["tokens"].cuda(non_blocking=True),
'labels': data["labels"].cuda(non_blocking=True),
'loss_mask': data["loss_mask"].cuda(non_blocking=True),
'attention_mask': (
None
if "attention_mask" not in data
else data["attention_mask"].cuda(non_blocking=True)
),
'position_ids': data["position_ids"].cuda(non_blocking=True),
}
if args.pipeline_model_parallel_size == 1:
......@@ -64,16 +72,34 @@ def get_batch_on_this_tp_rank(data_iterator):
else:
tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
tokens = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)
labels = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)
loss_mask = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
if args.create_attention_mask_in_dataloader:
attention_mask=torch.empty(
(args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
attention_mask = torch.empty(
(args.micro_batch_size, 1, args.seq_length, args.seq_length),
dtype=torch.bool,
device=torch.cuda.current_device(),
)
else:
attention_mask=None
position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
attention_mask = None
position_ids = torch.empty(
(args.micro_batch_size, args.seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)
if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
......@@ -117,4 +143,4 @@ def get_batch_on_this_tp_rank(data_iterator):
'position_ids': position_ids
}
return batch
\ No newline at end of file
return batch