Commit f3ef5e1b authored by dongcl

patch for megatron commit 0595ef2b0c93f8d61f473c9f99f9ff73803ff919

parent bb6ab0fb
@@ -49,6 +49,9 @@ class PipelineFeature(AbstractFeature):
             _allreduce_embedding_grads_wrapper
         )
         from dcu_megatron.training.training import evaluate
+        from dcu_megatron.core.transformer.transformer_layer import get_transformer_layer_offset
+        from dcu_megatron.training.utils import get_batch_on_this_tp_rank
+        from dcu_megatron.training.training import build_train_valid_test_data_iterators_wrapper

         patch_manager.register_patch(
             'megatron.training.training.get_model', get_model)
@@ -69,6 +72,20 @@ class PipelineFeature(AbstractFeature):
         patch_manager.register_patch(
             'megatron.training.training.evaluate', evaluate)
+        patch_manager.register_patch(
+            'megatron.core.transformer.transformer_layer.get_transformer_layer_offset', get_transformer_layer_offset)
+
+        # support dualpipev, two data iterators
+        patch_manager.register_patch(
+            'megatron.training.training.build_train_valid_test_data_iterators',
+            build_train_valid_test_data_iterators_wrapper,
+            apply_wrapper=True)
+
+        # support dualpipev, broadcast loss_mask and labels
+        patch_manager.register_patch(
+            'megatron.training.utils.get_batch_on_this_tp_rank',
+            get_batch_on_this_tp_rank)

         if args.combined_1f1b:
             from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
@@ -86,10 +103,10 @@ class PipelineFeature(AbstractFeature):
             from dcu_megatron.core.transformer.moe.moe_layer import MoELayer
             patch_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
                                          MoEAlltoAllTokenDispatcher)
             patch_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
                                          TransformerLayer)
             patch_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
                                          GPTModel.build_schedule_plan,
...
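The registrations above all follow the same register-then-apply monkey-patching pattern: a dotted target path is mapped to a replacement object, and wrapper-style patches (apply_wrapper=True) receive the original attribute and return the wrapped one. The sketch below is a minimal, hypothetical illustration of that pattern; "MiniPatchManager" and its method names are assumptions, not the actual dcu_megatron patch_manager.

    import importlib

    class MiniPatchManager:
        def __init__(self):
            self._patches = []  # (dotted_target, replacement, apply_wrapper)

        def register_patch(self, target, replacement, apply_wrapper=False):
            # Only record the patch; nothing is modified until apply_patches() runs.
            self._patches.append((target, replacement, apply_wrapper))

        def apply_patches(self):
            for target, replacement, apply_wrapper in self._patches:
                module_path, attr = target.rsplit('.', 1)
                module = importlib.import_module(module_path)
                if apply_wrapper:
                    # Wrapper style: replacement takes the original and returns a new callable.
                    replacement = replacement(getattr(module, attr))
                setattr(module, attr, replacement)

    # Usage sketch:
    #   pm = MiniPatchManager()
    #   pm.register_patch('megatron.training.training.evaluate', my_evaluate)
    #   pm.apply_patches()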
@@ -163,7 +163,6 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
-        from ..core.transformer.transformer_layer import get_transformer_layer_offset
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

         # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
@@ -190,10 +189,6 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
-        # support dualpipev
-        MegatronAdaptation.register('megatron.core.transformer.transformer_layer.get_transformer_layer_offset',
-                                    get_transformer_layer_offset)

     def patch_core_extentions(self):
         import transformer_engine as te
@@ -257,10 +252,10 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
         from ..training.initialize import _compile_dependencies
-        from ..training.training import train, build_train_valid_test_data_iterators_wrapper
+        from ..training.training import train
         from ..training.initialize import _set_random_seed
-        from ..training.utils import get_batch_on_this_tp_rank

+        # add Llama3Tokenizer, QwenTokenizer, DeepSeekV2Tokenizer
         MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
                                     build_tokenizer)

         # specify init_method
@@ -278,15 +273,6 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.training.train',
                                     train)
-        # support dualpipev, two data iterators
-        MegatronAdaptation.register('megatron.training.training.build_train_valid_test_data_iterators',
-                                    build_train_valid_test_data_iterators_wrapper,
-                                    apply_wrapper=True)
-
-        # support dualpipev, broadcast loss_mask and labels
-        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
-                                    get_batch_on_this_tp_rank)

     def patch_miscellaneous(self):
         from ..training.arguments import parse_args
...
@@ -160,6 +160,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
         k_channels: Optional[int] = None,
         v_channels: Optional[int] = None,
         cp_comm_type: str = "p2p",
+        model_comm_pgs: ModelCommProcessGroups = None,
     ):
         self.config = config
         self.te_forward_mask_type = False
@@ -186,6 +187,26 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
                 f"num_attention_heads ({self.config.num_attention_heads}))"
             )

+        if model_comm_pgs is None:
+            # For backward compatibility, remove in v0.14 and raise error
+            # raise ValueError("TEDotProductAttention was called without ModelCommProcessGroups")
+            model_comm_pgs = ModelCommProcessGroups(
+                tp=get_tensor_model_parallel_group(check_initialized=False),
+                cp=get_context_parallel_group(check_initialized=False),
+                hcp=get_hierarchical_context_parallel_groups(check_initialized=False),
+            )
+        else:
+            assert hasattr(
+                model_comm_pgs, 'tp'
+            ), "TEDotProductAttention model_comm_pgs must have tp pg"
+            assert hasattr(
+                model_comm_pgs, 'cp'
+            ), "TEDotProductAttention model_comm_pgs must have cp pg"
+            if cp_comm_type == "a2a+p2p":
+                assert hasattr(
+                    model_comm_pgs, 'hcp'
+                ), "TEDotProductAttention model_comm_pgs must have hierarchical cp pg"
+
         if is_te_min_version("0.10.0"):
             extra_kwargs["attention_type"] = attention_type
             # older version don't need attention_type
@@ -201,9 +222,9 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
             ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
             if getattr(TEDotProductAttention, "cp_stream") is None:
                 TEDotProductAttention.cp_stream = torch.cuda.Stream()
-            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
-            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
-                check_initialized=False
+            extra_kwargs["cp_group"] = model_comm_pgs.cp
+            extra_kwargs["cp_global_ranks"] = torch.distributed.get_process_group_ranks(
+                model_comm_pgs.cp
             )
             extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
             if is_te_min_version("1.10.0"):
@@ -277,7 +298,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
             get_rng_state_tracker=(
                 get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
             ),
-            tp_group=get_tensor_model_parallel_group(check_initialized=False),
+            tp_group=model_comm_pgs.tp,
             layer_number=layer_number,
             **extra_kwargs,
         )
@@ -294,7 +315,6 @@ if is_te_min_version("1.9.0.dev0"):
        yet, the tp_group passed to TE will be None and must be set later
        via set_tensor_parallel_group().
        """
-
        def __init__(
            self,
            num_gemms: int,
@@ -308,6 +328,7 @@ if is_te_min_version("1.9.0.dev0"):
            skip_bias_add: bool,
            is_expert: bool = False,
            tp_comm_buffer_name: Optional[str] = None,
+            tp_group: Optional[torch.distributed.ProcessGroup] = None,
        ):
            args = get_args()
            self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
@@ -329,6 +350,7 @@ if is_te_min_version("1.9.0.dev0"):
                skip_bias_add=skip_bias_add,
                is_expert=is_expert,
                tp_comm_buffer_name=tp_comm_buffer_name,
+                tp_group=tp_group,
            )

        def backward_dw(self):
...
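The hunks above swap the parallel_state getters for an explicitly passed process-group container and derive the context-parallel global ranks with torch.distributed.get_process_group_ranks. The single-process sketch below only illustrates that API; the gloo backend, port, and group composition are assumptions for demonstration.

    import os
    import torch.distributed as dist

    def main():
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)

        cp_group = dist.new_group(ranks=[0])            # stand-in for the context-parallel group
        cp_global_ranks = dist.get_process_group_ranks(cp_group)
        print(cp_global_ranks)                          # [0]

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()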
@@ -288,7 +288,7 @@ class MoeAttnNode(TransformerLayerNode):
                pre_mlp_layernorm_output,
                tokens_per_expert,
                permutated_local_input_tokens,
-               probs,
+               permuted_probs,
            ) = self.layer._submodule_attention_router_compound_forward(
                hidden_states,
                attention_mask=attention_mask,
@@ -304,11 +304,10 @@ class MoeAttnNode(TransformerLayerNode):
            self.common_state.tokens_per_expert = tokens_per_expert

            # detached here
-           self.common_state.probs = self.detach(probs)
            self.common_state.residual = self.detach(hidden_states)
            self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)

-           return permutated_local_input_tokens
+           return permutated_local_input_tokens, permuted_probs

    def dw(self):
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
@@ -317,26 +316,26 @@ class MoeAttnNode(TransformerLayerNode):

class MoeDispatchNode(TransformerLayerNode):
-   def forward_impl(self, permutated_local_input_tokens):
+   def forward_impl(self, permutated_local_input_tokens, permuted_probs):
        token_dispatcher = self.layer.mlp.token_dispatcher
        with token_dispatcher.per_batch_state_context(self.common_state):
-           tokens_per_expert, global_input_tokens = token_dispatcher.dispatch_all_to_all(
-               self.common_state.tokens_per_expert, permutated_local_input_tokens
+           tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
+               self.common_state.tokens_per_expert, permutated_local_input_tokens, permuted_probs
            )
        # release tensor not used by backward
        # inputs.untyped_storage().resize_(0)
        self.common_state.tokens_per_expert = tokens_per_expert
-       return global_input_tokens
+       return global_input_tokens, global_probs

class MoeMlPNode(TransformerLayerNode):
-   def forward_impl(self, global_input_tokens):
+   def forward_impl(self, global_input_tokens, global_probs):
        pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
        token_dispatcher = self.layer.mlp.token_dispatcher
        with token_dispatcher.per_batch_state_context(self.common_state):
            expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
-               self.common_state.tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output
+               self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
            )
            assert mlp_bias is None
@@ -363,9 +362,7 @@ class MoeCombineNode(TransformerLayerNode):
            )
            cur_stream = torch.cuda.current_stream()
            self.common_state.residual.record_stream(cur_stream)
-           self.common_state.probs.record_stream(cur_stream)
            self.common_state.residual = None
-           self.common_state.probs = None
            return output
...
@@ -125,8 +125,9 @@ def gpt_model_forward(
            and inference_context.is_static_batching()
            and not self.training
        ):
+           current_batch_size = input_ids.shape[0]
            sequence_len_offset = torch.tensor(
-               [inference_context.sequence_len_offset] * inference_context.current_batch_size,
+               [inference_context.sequence_len_offset] * current_batch_size,
                dtype=torch.int32,
                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
            )
@@ -156,12 +157,6 @@ def gpt_model_forward(
        **(extra_block_kwargs or {}),
    )

-   # Process inference output.
-   if inference_context and not inference_context.is_static_batching():
-       hidden_states = inference_context.last_token_logits(
-           hidden_states.squeeze(1).unsqueeze(0)
-       ).unsqueeze(1)

    # logits and loss
    output_weight = None
    if self.share_embeddings_and_output_weights:
@@ -202,10 +197,17 @@ def gpt_model_forward(
    if (
        not self.training
        and inference_context is not None
-       and inference_context.is_static_batching()
        and inference_context.materialize_only_last_token_logits
    ):
-       hidden_states = hidden_states[-1:, :, :]
+       if inference_context.is_static_batching():
+           hidden_states = hidden_states[-1:, :, :]
+       else:
+           # Reshape [B, 1, H] to [1, B, H] -> extract each sample's true last-token hidden
+           # state ([B, H]) -> unsqueeze back to [1, B, H]
+           # (so that the output layer, which expects S x B x H, receives only the final token)
+           hidden_states = inference_context.last_token_logits(
+               hidden_states.squeeze(1).unsqueeze(0)
+           ).unsqueeze(1)

    logits, _ = self.output_layer(
        hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
    )
...
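The non-static-batching branch moved above is mostly shape bookkeeping. The toy below makes the reshape concrete; last_token_logits is Megatron's inference-context helper and is not re-implemented here, and the tensor sizes are hypothetical.

    import torch

    B, H = 4, 8
    hidden_states = torch.randn(B, 1, H)   # dynamic batching: one decode token per request

    # [B, 1, H] -> [B, H] -> [1, B, H]: the output layer expects sequence-major S x B x H,
    # so the per-request last-token states are presented as a length-1 "sequence".
    reshaped = hidden_states.squeeze(1).unsqueeze(0)
    assert reshaped.shape == (1, B, H)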
@@ -12,8 +12,11 @@ from megatron.core.transformer.moe.moe_utils import (
    permute,
    sort_chunks_by_idxs,
    unpermute,
+   pad_routing_map,
)
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
+from megatron.core.fp8_utils import get_fp8_align_size
+from megatron.core.fusions.fused_pad_routing_map import fused_pad_routing_map

from dcu_megatron.core.tensor_parallel import all_to_all
@@ -101,6 +104,12 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
        assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"

+       if self.config.moe_router_padding_for_fp8:
+           pad_multiple = get_fp8_align_size(self.config.fp8_recipe)
+           if experimental_config.ENABLE_EXPERIMENTAL and self.config.moe_permute_fusion:
+               self.routing_map = fused_pad_routing_map(self.routing_map, pad_multiple)
+           else:
+               self.routing_map = pad_routing_map(self.routing_map, pad_multiple)

        tokens_per_expert = self.preprocess(self.routing_map)
        return tokens_per_expert
@@ -117,18 +126,20 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        self.hidden_shape_before_permute = hidden_states.shape
        (
            permutated_local_input_tokens,
+           permuted_probs,
            self.reversed_local_input_permutation_mapping,
        ) = permute(
            hidden_states,
            routing_map,
+           self.probs,
            num_out_tokens=self.num_out_tokens,
            fused=self.config.moe_permute_fusion,
            drop_and_pad=self.drop_and_pad,
        )
-       return tokens_per_expert, permutated_local_input_tokens
+       return tokens_per_expert, permutated_local_input_tokens, permuted_probs

-   def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens):
+   def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
        # Perform expert parallel AlltoAll communication
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
            "before_ep_alltoall", tokens_per_expert
@@ -136,10 +147,13 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        global_input_tokens = all_to_all(
            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
        )
+       global_probs = all_to_all(
+           self.ep_group, permuted_probs, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
+       )

-       return tokens_per_expert, global_input_tokens
+       return tokens_per_expert, global_input_tokens, global_probs

-   def dispatch_postprocess(self, tokens_per_expert, global_input_tokens):
+   def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
        if self.shared_experts is not None:
            self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
@@ -152,6 +166,9 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
            global_input_tokens = gather_from_sequence_parallel_region(
                global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
            )
+           global_probs = gather_from_sequence_parallel_region(
+               global_probs, group=self.tp_group, output_split_sizes=output_split_sizes
+           )

        # Permutation 2: Sort tokens by local expert.
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
@@ -170,16 +187,28 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
                .contiguous()
                .flatten(start_dim=0, end_dim=2)
            )
+           global_probs = (
+               global_probs.view(
+                   self.tp_size * self.ep_size,
+                   self.num_local_experts,
+                   self.capacity,
+                   *global_probs.size()[1:],
+               )
+               .transpose(0, 1)
+               .contiguous()
+               .flatten(start_dim=0, end_dim=2)
+           )
        else:
-           global_input_tokens = sort_chunks_by_idxs(
+           global_input_tokens, global_probs = sort_chunks_by_idxs(
                global_input_tokens,
                self.num_global_tokens_per_local_expert.ravel(),
                self.sort_input_by_local_experts,
+               probs=global_probs,
                fused=self.config.moe_permute_fusion,
            )

        tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)

-       return global_input_tokens, tokens_per_expert
+       return global_input_tokens, tokens_per_expert, global_probs

    def token_permutation(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
@@ -207,15 +236,15 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        # Preprocess: Get the metadata for communication, permutation and computation operations.
        # Permutation 1: input to AlltoAll input
        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
-       tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
+       tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)

        # Perform expert parallel AlltoAll communication
-       tokens_per_expert, global_input_tokens = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens)
+       tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)

        # Permutation 2: Sort tokens by local expert.
-       global_input_tokens, tokens_per_expert = self.dispatch_postprocess(tokens_per_expert, global_input_tokens)
+       global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)

-       return global_input_tokens, tokens_per_expert
+       return global_input_tokens, tokens_per_expert, global_probs

    def combine_preprocess(self, hidden_states):
        # Unpermutation 2: Unsort tokens by local expert.
@@ -272,7 +301,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
            permutated_local_input_tokens,
            self.reversed_local_input_permutation_mapping,
            restore_shape=self.hidden_shape_before_permute,
-           probs=self.probs,
            routing_map=self.routing_map,
            fused=self.config.moe_permute_fusion,
            drop_and_pad=self.drop_and_pad,
...
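The dispatcher changes above thread the per-token routing probabilities (permuted_probs / global_probs) through the same permutation and all-to-all path as the token tensor, so the weights stay row-aligned with the tokens after dispatch instead of being re-applied during unpermute. The toy below shows that alignment requirement with plain torch indexing standing in for Megatron's permute()/all_to_all(); the permutation and sizes are hypothetical.

    import torch

    tokens = torch.arange(12.0).reshape(6, 2)   # 6 dispatched tokens, hidden size 2
    probs = torch.rand(6)                       # one routing weight per dispatched token
    perm = torch.tensor([3, 0, 5, 1, 4, 2])     # hypothetical expert-sorted order

    permuted_tokens = tokens[perm]
    permuted_probs = probs[perm]                # same permutation keeps rows aligned

    # After the (omitted) all-to-all, each expert can weight its inputs locally:
    weighted = permuted_tokens * permuted_probs.unsqueeze(-1)
    assert weighted.shape == permuted_tokens.shape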
@@ -8,6 +8,8 @@ from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import (
    deprecate_inference_params,
    make_viewless_tensor,
+   nvtx_range_pop,
+   nvtx_range_push,
)
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
@@ -15,7 +17,7 @@ from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispa
from megatron.core.transformer.transformer_config import TransformerConfig

-def get_transformer_layer_offset(config: TransformerConfig):
+def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[int] = None):
    """Get the index offset of current pipeline stage, given the level of pipelining."""
    args = get_args()
    pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
@@ -67,9 +69,10 @@ def get_transformer_layer_offset(config: TransformerConfig):
                - num_layers_in_last_pipeline_stage
            )

-       if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
-           vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
-           vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+       if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
+           assert (
+               vp_stage is not None
+           ), "vp_stage must be provided if virtual pipeline model parallel size is set"

            # Calculate number of layers in each virtual model chunk
            # If the num_layers_in_first_pipeline_stage and
@@ -100,10 +103,10 @@ def get_transformer_layer_offset(config: TransformerConfig):
            # Calculate the layer offset with interleaved uneven pipeline parallelism
            if pipeline_rank == 0:
-               offset = vp_rank * total_virtual_chunks
+               offset = vp_stage * total_virtual_chunks
            else:
                offset = (
-                   vp_rank * total_virtual_chunks
+                   vp_stage * total_virtual_chunks
                    + num_layers_per_virtual_model_chunk_in_first_pipeline_stage
                    + (pipeline_rank - 1)
                    * (
@@ -151,20 +154,23 @@ def get_transformer_layer_offset(config: TransformerConfig):
            if args.schedule_method == 'dualpipev':
                num_layers_per_pipeline_rank = num_layers_per_pipeline_rank // 2

-           if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
-               vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
-               vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+           if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
+               assert (
+                   vp_stage is not None
+               ), "vp_stage must be provided if virtual pipeline model parallel size is set"

                num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
                total_virtual_chunks = num_layers // vp_size
-               offset = vp_rank * total_virtual_chunks + (
+               offset = vp_stage * total_virtual_chunks + (
                    pipeline_rank * num_layers_per_virtual_rank
                )

                # Reduce the offset of embedding layer from the total layer number
                if (
                    config.account_for_embedding_in_pipeline_split
-                   and not parallel_state.is_pipeline_first_stage()
+                   and not parallel_state.is_pipeline_first_stage(
+                       ignore_virtual=False, vp_stage=vp_stage
+                   )
                ):
                    offset -= 1
            else:
@@ -176,7 +182,9 @@ def get_transformer_layer_offset(config: TransformerConfig):
                # Reduce the offset of embedding layer from the total layer number
                if (
                    config.account_for_embedding_in_pipeline_split
-                   and not parallel_state.is_pipeline_first_stage()
+                   and not parallel_state.is_pipeline_first_stage(
+                       ignore_virtual=False, vp_stage=vp_stage
+                   )
                ):
                    offset -= 1
            else:
@@ -188,9 +196,9 @@ class TransformerLayer(MegatronCoreTransformerLayer):
    def forward(
        self,
        hidden_states: Tensor,
+       attention_mask: Optional[Tensor] = None,
        context: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
-       attention_mask: Optional[Tensor] = None,
        rotary_pos_emb: Optional[Tensor] = None,
        rotary_pos_cos: Optional[Tensor] = None,
        rotary_pos_sin: Optional[Tensor] = None,
@@ -208,9 +216,9 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        ):
            return super().forward(
                hidden_states=hidden_states,
+               attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
-               attention_mask=attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
@@ -226,7 +234,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
            pre_mlp_layernorm_output,
            tokens_per_expert,
            permutated_local_input_tokens,
-           _,
+           permuted_probs,
        ) = self._submodule_attention_router_compound_forward(
            hidden_states,
            attention_mask,
@@ -240,14 +248,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
            inference_params=inference_params,
        )

-       (tokens_per_expert, global_input_tokens) = self._submodule_dispatch_forward(
+       (tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
            tokens_per_expert,
            permutated_local_input_tokens,
+           permuted_probs,
        )

        (expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
            tokens_per_expert,
            global_input_tokens,
+           global_probs,
            pre_mlp_layernorm_output
        )
@@ -292,6 +302,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
            input_layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
+       nvtx_range_push(suffix="self_attention")
        attention_output_with_bias = self.self_attention(
            input_layernorm_output,
            attention_mask=attention_mask,
@@ -303,6 +314,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
        )
+       nvtx_range_pop(suffix="self_attention")

        if self.recompute_input_layernorm:
            # discard the output of the input layernorm and register the recompute
@@ -313,10 +325,12 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        # TODO: could we move `bias_dropout_add_exec_handler` itself
        # inside the module provided in the `bias_dropout_add_spec` module?
+       nvtx_range_push(suffix="self_attn_bda")
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
                attention_output_with_bias, residual, self.hidden_dropout
            )
+       nvtx_range_pop(suffix="self_attn_bda")

        return hidden_states
@@ -363,7 +377,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
            pre_mlp_layernorm_output, probs, routing_map
        )
-       tokens_per_expert, permutated_local_input_tokens = self.mlp.token_dispatcher.dispatch_preprocess(
+       tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.mlp.token_dispatcher.dispatch_preprocess(
            pre_mlp_layernorm_output, routing_map, tokens_per_expert
        )
@@ -372,18 +386,18 @@ class TransformerLayer(MegatronCoreTransformerLayer):
            pre_mlp_layernorm_output,
            tokens_per_expert,
            permutated_local_input_tokens,
-           probs,
+           permuted_probs,
        ]
        return tuple(outputs)

-   def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens):
+   def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
        """
        Dispatches tokens to the appropriate experts based on the router output.
        """
-       tokens_per_expert, global_input_tokens = self.mlp.token_dispatcher.dispatch_all_to_all(
-           tokens_per_expert, permutated_local_input_tokens
+       tokens_per_expert, global_input_tokens, global_probs = self.mlp.token_dispatcher.dispatch_all_to_all(
+           tokens_per_expert, permutated_local_input_tokens, permuted_probs
        )
-       return [tokens_per_expert, global_input_tokens]
+       return [tokens_per_expert, global_input_tokens, global_probs]

    def _submodule_dense_forward(self, hidden_states):
        residual = hidden_states
@@ -399,18 +413,20 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        return output

-   def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, pre_mlp_layernorm_output):
+   def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
        """
        Performs a forward pass for the MLP submodule, including both expert-based
        and optional shared-expert computations.
        """
        shared_expert_output = None
-       (dispatched_input, tokens_per_expert) = (
-           self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens)
+       (dispatched_input, tokens_per_expert, permuted_probs) = (
+           self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
        )
-       expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert)
+       expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)

        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
+           # if shared_expert_overlap is True, the expert calculation happens in
+           # the token_dispatcher to overlap communications and computations
            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
        return expert_output, shared_expert_output, mlp_bias
@@ -438,10 +454,19 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        # TODO: could we move `bias_dropout_add_exec_handler` itself
        # inside the module provided in the `bias_dropout_add_spec` module?
+       nvtx_range_push(suffix="mlp_bda")
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                mlp_output_with_bias, residual, self.hidden_dropout
            )
+       nvtx_range_pop(suffix="mlp_bda")

+       # Jit compiled function creates 'view' tensor. This tensor
+       # potentially gets saved in the MPU checkpoint function context,
+       # which rejects view tensors. While making a viewless tensor here
+       # won't result in memory savings (like the data loader, or
+       # p2p_communication), it serves to document the origin of this
+       # 'view' tensor.
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )
...
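The get_transformer_layer_offset change above replaces the parallel_state virtual-pipeline rank lookup with an explicit vp_stage argument. The arithmetic sketch below covers only the evenly-divided branch, with hypothetical layer/stage counts, just to show what the offset works out to.

    def layer_offset(num_layers, pp_size, vp_size, pipeline_rank, vp_stage):
        num_layers_per_pipeline_rank = num_layers // pp_size
        if vp_size is not None:
            assert vp_stage is not None, "vp_stage must be provided when vp_size is set"
            num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
            total_virtual_chunks = num_layers // vp_size
            return vp_stage * total_virtual_chunks + pipeline_rank * num_layers_per_virtual_rank
        return pipeline_rank * num_layers_per_pipeline_rank

    # e.g. 32 layers, pp=4, vp=2: the chunk at (pipeline_rank=1, vp_stage=1)
    # starts at layer 1 * 16 + 1 * 4 = 20
    assert layer_offset(32, 4, 2, pipeline_rank=1, vp_stage=1) == 20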
@@ -6,6 +6,7 @@ from functools import wraps
from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnType
+from megatron.core.utils import deprecate_inference_params
from megatron.core.models.common.embeddings import apply_rotary_pos_emb
from megatron.legacy.model.module import MegatronModule
@@ -86,26 +87,21 @@ def parallel_attention_init_wrapper(fn):
    return wrapper

-class ParallelAttentionPatch(MegatronModule):
-   """Parallel self-attention layer abstract class.
-
-   Self-attention layer takes input with size [s, b, h]
-   and returns output of the same size.
-   """

    def forward(self, hidden_states, attention_mask,
-               encoder_output=None, inference_params=None,
-               rotary_pos_emb=None):
+               encoder_output=None, inference_context=None,
+               rotary_pos_emb=None, *, inference_params=None):
        # hidden_states: [sq, b, h]
+       inference_context = deprecate_inference_params(inference_context, inference_params)

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
-       if inference_params:
-           if self.layer_number not in inference_params.key_value_memory_dict:
-               inf_max_seq_len = inference_params.max_sequence_length
-               inf_max_batch_size = inference_params.max_batch_size
+       if inference_context:
+           if self.layer_number not in inference_context.key_value_memory_dict:
+               inf_max_seq_len = inference_context.max_sequence_length
+               inf_max_batch_size = inference_context.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
@@ -113,12 +109,12 @@ class ParallelAttentionPatch(MegatronModule):
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
-               inference_params.key_value_memory_dict[self.layer_number] = (
+               inference_context.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
            else:
                inference_key_memory, inference_value_memory = \
-                   inference_params.key_value_memory_dict[self.layer_number]
+                   inference_context.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
@@ -188,13 +184,14 @@ class ParallelAttentionPatch(MegatronModule):
        else:
            rotary_pos_emb = ((rotary_pos_emb,) * 2)

-       if inference_params:
-           batch_start = inference_params.batch_size_offset
+       if inference_context:
+           batch_start = inference_context.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
-           sequence_start = inference_params.sequence_len_offset
+           sequence_start = inference_context.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
-           assert sequence_end <= inference_key_memory.size(0)
+           assert sequence_end <= inference_key_memory.size(0), ("Current sequence length is "
+               "longer than expected maximum sequence length! Increase inference_max_seq_length.")
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
...
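The forward() change above renames the public keyword from inference_params to inference_context while still accepting the old name as a keyword-only fallback, funneled through Megatron's deprecate_inference_params helper. The stand-in below is only a sketch of that general deprecation shape; _deprecate_kwarg is a hypothetical name, not the Megatron helper.

    import warnings

    def _deprecate_kwarg(new_value, old_value, old_name="inference_params",
                         new_name="inference_context"):
        # Prefer the new keyword; fall back to the old one with a warning.
        if old_value is not None and new_value is None:
            warnings.warn(
                f"'{old_name}' is deprecated, pass '{new_name}' instead.",
                DeprecationWarning,
            )
            return old_value
        return new_value

    def forward(hidden_states, inference_context=None, *, inference_params=None):
        inference_context = _deprecate_kwarg(inference_context, inference_params)
        return hidden_states, inference_context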
@@ -3,6 +3,7 @@ import argparse
from typing import Union

from megatron.training.arguments import add_megatron_arguments
+from megatron.core.msc_utils import MultiStorageClientFeature
from dcu_megatron.adaptor.features_manager import ADAPTOR_FEATURES
@@ -65,6 +66,12 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    # args.rank = int(os.getenv('RANK', '0'))
    # args.world_size = int(os.getenv("WORLD_SIZE", '1'))

+   # Args to disable MSC
+   if not args.enable_msc:
+       MultiStorageClientFeature.disable()
+       assert MultiStorageClientFeature.is_enabled() is False
+       print('WARNING: The MSC feature is disabled.')

    return args
...
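The new block gates Multi-Storage Client support behind args.enable_msc by calling disable() and asserting is_enabled() flips to False. The stand-in class below only illustrates that toggle contract; the real MultiStorageClientFeature lives in megatron.core.msc_utils and is not re-implemented here.

    class _FeatureToggle:
        _enabled = True

        @classmethod
        def disable(cls):
            cls._enabled = False

        @classmethod
        def is_enabled(cls):
            return cls._enabled

    _FeatureToggle.disable()
    assert _FeatureToggle.is_enabled() is False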
@@ -8,6 +8,7 @@ from datetime import timedelta
from megatron.training import get_args
from megatron.core import mpu, tensor_parallel
+from megatron.training import inprocess_restart

def _compile_dependencies():
@@ -76,7 +77,7 @@ def _compile_dependencies():
    )

-def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
+def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, store):
    """Initialize torch.distributed and core model parallel."""
    args = get_args()
@@ -109,6 +110,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
        # Call the init process
        init_process_group_kwargs = {
            'backend': args.distributed_backend,
+           'store': store,
            'world_size': args.world_size,
            'rank': args.rank,
            'init_method': args.dist_url,
@@ -116,6 +118,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
        }

        torch.distributed.init_process_group(**init_process_group_kwargs)
+       inprocess_restart.maybe_force_nccl_backend_init(device_id)

        # Set the tensor model-parallel, pipeline model-parallel, and
        # data-parallel communicators.
@@ -129,6 +132,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
            args.virtual_pipeline_model_parallel_size,
            args.pipeline_model_parallel_split_rank,
            pipeline_model_parallel_comm_backend=args.pipeline_model_parallel_comm_backend,
+           use_sharp=args.use_sharp,
            context_parallel_size=args.context_parallel_size,
            hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
            expert_model_parallel_size=args.expert_model_parallel_size,
@@ -142,6 +146,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
            get_embedding_ranks=get_embedding_ranks,
            get_position_embedding_ranks=get_position_embedding_ranks,
            create_gloo_process_groups=args.enable_gloo_process_groups,
+           high_priority_stream_groups=args.high_priority_stream_groups,
        )
        if args.rank == 0:
            print(
...
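The new store parameter lets the caller hand _initialize_distributed a pre-built key-value store instead of relying solely on the init_method URL; torch.distributed.init_process_group accepts such a store directly. The single-process example below is illustrative only; the gloo backend, host, and port are assumptions.

    from datetime import timedelta
    import torch.distributed as dist

    # Build an explicit rendezvous store and pass it to init_process_group.
    store = dist.TCPStore("127.0.0.1", 29501, world_size=1, is_master=True,
                          timeout=timedelta(seconds=30))
    dist.init_process_group(backend="gloo", store=store, rank=0, world_size=1)
    print(dist.is_initialized())   # True
    dist.destroy_process_group()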
@@ -19,7 +19,11 @@ def get_batch_on_this_tp_rank(data_iterator):

    def _broadcast(item):
        if item is not None:
-           torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+           torch.distributed.broadcast(
+               item,
+               mpu.get_tensor_model_parallel_src_rank(),
+               group=mpu.get_tensor_model_parallel_group(),
+           )

    if mpu.get_tensor_model_parallel_rank() == 0:
@@ -29,11 +33,15 @@ def get_batch_on_this_tp_rank(data_iterator):
            data = None

        batch = {
-           'tokens': data["tokens"].cuda(non_blocking = True),
-           'labels': data["labels"].cuda(non_blocking = True),
-           'loss_mask': data["loss_mask"].cuda(non_blocking = True),
-           'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
-           'position_ids': data["position_ids"].cuda(non_blocking = True)
+           'tokens': data["tokens"].cuda(non_blocking=True),
+           'labels': data["labels"].cuda(non_blocking=True),
+           'loss_mask': data["loss_mask"].cuda(non_blocking=True),
+           'attention_mask': (
+               None
+               if "attention_mask" not in data
+               else data["attention_mask"].cuda(non_blocking=True)
+           ),
+           'position_ids': data["position_ids"].cuda(non_blocking=True),
        }

        if args.pipeline_model_parallel_size == 1:
@@ -64,16 +72,34 @@ def get_batch_on_this_tp_rank(data_iterator):

    else:

-       tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
-       labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
-       loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
+       tokens = torch.empty(
+           (args.micro_batch_size, args.seq_length),
+           dtype=torch.int64,
+           device=torch.cuda.current_device(),
+       )
+       labels = torch.empty(
+           (args.micro_batch_size, args.seq_length),
+           dtype=torch.int64,
+           device=torch.cuda.current_device(),
+       )
+       loss_mask = torch.empty(
+           (args.micro_batch_size, args.seq_length),
+           dtype=torch.float32,
+           device=torch.cuda.current_device(),
+       )
        if args.create_attention_mask_in_dataloader:
-           attention_mask=torch.empty(
-               (args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
+           attention_mask = torch.empty(
+               (args.micro_batch_size, 1, args.seq_length, args.seq_length),
+               dtype=torch.bool,
+               device=torch.cuda.current_device(),
            )
        else:
-           attention_mask=None
-       position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
+           attention_mask = None
+       position_ids = torch.empty(
+           (args.micro_batch_size, args.seq_length),
+           dtype=torch.int64,
+           device=torch.cuda.current_device(),
+       )

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
@@ -117,4 +143,4 @@ def get_batch_on_this_tp_rank(data_iterator):
        'position_ids': position_ids
    }

    return batch
\ No newline at end of file
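The function above follows a fixed pattern: the source tensor-parallel rank reads the real batch, every other TP rank allocates empty tensors of matching shape and dtype, and a broadcast from the source rank fills them (the patch extends this to labels and loss_mask for dualpipev). The sketch below shows that pattern in isolation; the helper name and shape bookkeeping are hypothetical stand-ins for the mpu-based code.

    import torch
    import torch.distributed as dist

    def broadcast_batch(batch, src_rank, group, is_src, shapes_and_dtypes):
        """batch is a dict of tensors on the source rank, None elsewhere."""
        if not is_src:
            # Non-source ranks allocate receive buffers with the agreed shapes/dtypes.
            batch = {
                name: torch.empty(shape, dtype=dtype)
                for name, (shape, dtype) in shapes_and_dtypes.items()
            }
        for tensor in batch.values():
            dist.broadcast(tensor, src_rank, group=group)
        return batch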