Commit ab2a8334 authored by dongcl

modify MoEAlltoAllPerBatchState, add tokens_per_expert attr

parent 5890bb4c
@@ -6,6 +6,7 @@ from typing import Optional
 import torch
 from torch import Tensor
+from megatron.core import parallel_state
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
@@ -19,6 +20,7 @@ from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
     AbstractSchedulePlan,
     ScheduleNode,
     get_com_stream,
+    get_comp_stream,
     make_viewless,
 )
@@ -620,7 +622,6 @@ def schedule_chunk_1f1b(
     f_context = f_context if f_context is not None else contextlib.nullcontext()
     b_context = b_context if b_context is not None else contextlib.nullcontext()
-
     if f_schedule_plan:
         # pp output send/receive sync
         if pre_forward is not None:
@@ -709,7 +710,7 @@ def schedule_chunk_1f1b(
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:
             f_schedule_plan.wait_current_stream()
-            post_forward(f_input)
+            post_forward(None if parallel_state.is_pipeline_last_stage(ignore_virtual=False) else f_input)
 
     # pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch
     if b_schedule_plan is not None and post_backward is not None:
@@ -744,7 +745,7 @@ def build_model_chunk_schedule_plan(
     loss_mask: Optional[Tensor] = None
 ):
-    comp_stream = torch.cuda.current_stream()
+    comp_stream = get_comp_stream()
     com_stream = get_com_stream()
     model_chunk_schedule_plan = ModelChunkSchedulePlan()
     event = model_chunk_schedule_plan.event
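The hunks above route compute and communication onto dedicated streams (get_comp_stream / get_com_stream from combined_1f1b). For orientation, a generic, illustrative sketch of that stream-plus-event overlap pattern, not the project's implementation:

import torch

# Illustrative only: generic compute/communication stream overlap.
# get_comp_stream()/get_com_stream() in dcu_megatron.core.pipeline_parallel.combined_1f1b
# provide the project's own streams; their internals are not shown here.
comp_stream = torch.cuda.Stream()
com_stream = torch.cuda.Stream()

def overlapped_step(x: torch.Tensor) -> torch.Tensor:
    done = torch.cuda.Event()
    with torch.cuda.stream(comp_stream):
        y = x @ x.t()                 # compute work on the compute stream
        done.record(comp_stream)
    with torch.cuda.stream(com_stream):
        com_stream.wait_event(done)   # communication starts only once the compute result is ready
        # a collective such as torch.distributed.all_to_all_single(...) would run here
    torch.cuda.current_stream().wait_stream(com_stream)
    return y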
@@ -472,7 +472,10 @@ def forward_backward_step(
             else torch.tensor(1.0)
         )
         # Set the loss scale
-        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+        if config.calculate_per_token_loss:
+            MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
+        else:
+            MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
 
     if not unwrap_output_tensor:
         output_tensor, num_tokens = [output_tensor], num_tokens
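A minimal sketch of the scaling rule this hunk introduces, assuming a config exposing calculate_per_token_loss; the helper name pick_aux_loss_scale is hypothetical:

import torch

# Hypothetical helper mirroring the branch added above; the real call site uses
# MoEAuxLossAutoScaler.set_loss_scale(...) directly.
def pick_aux_loss_scale(loss_scale: torch.Tensor,
                        num_microbatches: int,
                        calculate_per_token_loss: bool) -> torch.Tensor:
    # With per-token loss the scale is passed through unchanged; otherwise it is
    # divided by the micro-batch count, which was the behaviour before this change.
    if calculate_per_token_loss:
        return loss_scale
    return loss_scale / num_microbatches

# e.g. MoEAuxLossAutoScaler.set_loss_scale(pick_aux_loss_scale(loss_scale, 8, False))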
@@ -25,13 +25,13 @@ class MoEAlltoAllPerBatchState:
         self.input_splits = None
         self.num_out_tokens = None
         self.capacity = None
-        self.preprocess_event = None
         self.hidden_shape = None
         self.probs = None
         self.routing_map = None
         self.reversed_local_input_permutation_mapping = None
         self.cuda_sync_point = None
         self.hidden_shape_before_permute = None
+        self.tokens_per_expert = None
 
 
 class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
@@ -44,7 +44,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         state.input_splits = getattr(self, "input_splits", None)
         state.num_out_tokens = getattr(self, "num_out_tokens", None)
         state.capacity = getattr(self, "capacity", None)
-        state.preprocess_event = getattr(self, "preprocess_event", None)
         state.hidden_shape = getattr(self, "hidden_shape", None)
         state.probs = getattr(self, "probs", None)
         state.routing_map = getattr(self, "routing_map", None)
@@ -53,6 +52,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         )
         state.hidden_shape_before_permute = getattr(self, "hidden_shape_before_permute", None)
         state.cuda_sync_point = getattr(self, "cuda_sync_point", None)
+        state.tokens_per_expert = getattr(self, "tokens_per_expert", None)
 
     def apply_per_batch_state(self, state: MoEAlltoAllPerBatchState):
         self.num_global_tokens_per_local_expert = state.num_global_tokens_per_local_expert
@@ -61,7 +61,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         self.input_splits = state.input_splits
         self.num_out_tokens = state.num_out_tokens
         self.capacity = state.capacity
-        self.preprocess_event = state.preprocess_event
         self.hidden_shape = state.hidden_shape
         self.probs = state.probs
         self.routing_map = state.routing_map
@@ -70,6 +69,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         )
         self.hidden_shape_before_permute = state.hidden_shape_before_permute
         self.cuda_sync_point = state.cuda_sync_point
+        self.tokens_per_expert = state.tokens_per_expert
 
     @contextmanager
     def per_batch_state_context(self, state: MoEAlltoAllPerBatchState):
@@ -144,6 +144,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
             output_split_sizes = None
         else:
             output_split_sizes = self.output_splits_tp.tolist()
+
         global_input_tokens = gather_from_sequence_parallel_region(
             global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
         )
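A trimmed, self-contained sketch of the capture/apply round-trip that the per-batch state supports, now including tokens_per_expert; class names and the reduced field list are illustrative stand-ins, not the real dispatcher:

# Trimmed, illustrative stand-ins for the per-batch state round-trip; the real
# classes carry many more fields (splits, probs, routing_map, sync points, ...).
class PerBatchState:
    def __init__(self):
        self.input_splits = None
        self.output_splits = None
        self.tokens_per_expert = None   # attribute added by this commit

class DispatcherSketch:
    def capture_per_batch_state(self, state: PerBatchState) -> None:
        # Snapshot whatever routing metadata the dispatcher currently holds.
        state.input_splits = getattr(self, "input_splits", None)
        state.output_splits = getattr(self, "output_splits", None)
        state.tokens_per_expert = getattr(self, "tokens_per_expert", None)

    def apply_per_batch_state(self, state: PerBatchState) -> None:
        # Restore a previously captured snapshot onto this dispatcher.
        self.input_splits = state.input_splits
        self.output_splits = state.output_splits
        self.tokens_per_expert = state.tokens_per_expert

# Round-trip: capture from one dispatcher instance, apply to another.
state = PerBatchState()
d1 = DispatcherSketch()
d1.tokens_per_expert = [3, 1, 0, 2]
d1.capture_per_batch_state(state)
d2 = DispatcherSketch()
d2.apply_per_batch_state(state)
assert d2.tokens_per_expert == [3, 1, 0, 2]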
@@ -182,7 +182,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return output
 
-    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_prob, hidden_states):
+    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_prob, pre_mlp_layernorm_output):
         """
         Performs a forward pass for the MLP submodule, including both expert-based
         and optional shared-expert computations.
@@ -194,7 +194,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
         expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
         if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
-            shared_expert_output = self.mlp.shared_experts(hidden_states)
+            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
         return expert_output, shared_expert_output, mlp_bias
 
     def _submodule_combine_forward(self, hidden_states):
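For orientation, a schematic sketch of where the renamed argument sits in a layer's MLP block: both the routed experts and the shared expert consume the pre-MLP layernorm output, consistent with the rename above. All module names except shared_experts are illustrative:

import torch
import torch.nn as nn

class MoEBlockSketch(nn.Module):
    """Illustrative only; the real layer also handles dispatch/combine, bias and residuals."""
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.pre_mlp_layernorm = nn.LayerNorm(hidden)
        self.routed_expert = nn.Linear(hidden, hidden)   # stand-in for self.mlp.experts
        self.shared_experts = nn.Linear(hidden, hidden)  # stand-in for self.mlp.shared_experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
        expert_output = self.routed_expert(pre_mlp_layernorm_output)
        shared_expert_output = self.shared_experts(pre_mlp_layernorm_output)
        return expert_output + shared_expert_output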