Commit 1d497357 authored by dongcl

split routed experts and shared experts

parent 2ceeaafd
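In short, the per-layer schedule no longer treats the MoE MLP as a single node: the routed experts and the shared experts become separate schedule nodes with their own forward, backward, and weight-gradient (dw) steps, and the combine node now receives both outputs. Below is a minimal, self-contained sketch of that wiring; the names mirror the diff, but the stub node class and string "tensors" are purely illustrative and not the real MindSpeed/Megatron types.

# --- illustrative sketch, not part of the commit ---
from dataclasses import dataclass

@dataclass
class StubNode:
    # Stand-in for the real schedule nodes (MoeAttnNode, MoeDispatchNode, ...),
    # which wrap TransformerLayer submodules and run on dedicated CUDA streams.
    name: str

    def forward(self, *inputs):
        return f"{self.name}({', '.join(map(str, inputs))})"

# Old per-layer plan: attn -> dispatch -> mlp -> combine
# New per-layer plan: attn -> dispatch -> routed_expert (+ shared_expert) -> combine
attn, dispatch = StubNode("attn"), StubNode("dispatch")
routed_expert, shared_expert, combine = (
    StubNode("routed_expert"), StubNode("shared_expert"), StubNode("combine"))

hidden = attn.forward("hidden_states")
shared_out = shared_expert.forward(hidden)       # shared experts see the undispatched activations
dispatched = dispatch.forward(hidden)            # all-to-all dispatch of routed tokens
routed_out = routed_expert.forward(dispatched)   # expert MLPs on the dispatched tokens
print(combine.forward(routed_out, shared_out))   # combine now consumes both outputs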
@@ -74,3 +74,9 @@ def a2a_overlap_adaptation(patches_manager):
    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
                                   MoELayer.backward_dw,
                                   create_dummy=True)
    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_routed_expert_dw',
                                   MoELayer.backward_routed_expert_dw,
                                   create_dummy=True)
    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_shared_expert_dw',
                                   MoELayer.backward_shared_expert_dw,
                                   create_dummy=True)
@@ -356,10 +356,54 @@ class MoeMlPNode(TransformerLayerNode):
            self.layer._submodule_mlp_dw()


class MoeSharedExpertNode(TransformerLayerNode):
    def forward_impl(self):
        if self.layer.mlp.use_shared_expert:
            pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
        else:
            pre_mlp_layernorm_output = None

        token_dispatcher = self.layer.mlp.token_dispatcher
        with token_dispatcher.per_batch_state_context(self.common_state):
            shared_expert_output = self.layer._submodule_shared_expert_forward(
                pre_mlp_layernorm_output
            )
        # pre_mlp_layernorm_output has been consumed; release the reference
        self.common_state.pre_mlp_layernorm_output = None
        return shared_expert_output
        # self.common_state.shared_expert_output = self.detach(shared_expert_output)

    def dw(self):
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
            self.layer._submodule_shared_expert_dw()


class MoeRoutedExpertNode(TransformerLayerNode):
    def forward_impl(self, global_input_tokens):
        token_dispatcher = self.layer.mlp.token_dispatcher
        with token_dispatcher.per_batch_state_context(self.common_state):
            expert_output, mlp_bias = self.layer._submodule_routed_expert_forward(
                self.common_state.tokens_per_expert, global_input_tokens
            )
        assert mlp_bias is None
        return expert_output

    def dw(self):
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
            self.layer._submodule_routed_expert_dw()


class MoeCombineNode(TransformerLayerNode):
    def forward_impl(self, expert_output, shared_expert_output=None):
        # TODO(lhb): if dw uses the grads of residual and probs, the necessary synchronization should be added
        residual = self.common_state.residual
        # shared_expert_output = None
        # if self.layer.mlp.use_shared_expert:
        #     shared_expert_output = self.common_state.shared_expert_output
        token_dispatcher = self.layer.mlp.token_dispatcher
        with token_dispatcher.per_batch_state_context(self.common_state):
            permutated_local_input_tokens = token_dispatcher.combine_all_to_all(
@@ -371,8 +415,10 @@ class MoeCombineNode(TransformerLayerNode):
        cur_stream = torch.cuda.current_stream()
        self.common_state.residual.record_stream(cur_stream)
        self.common_state.probs.record_stream(cur_stream)
        # self.common_state.shared_expert_output.record_stream(cur_stream)
        self.common_state.residual = None
        self.common_state.probs = None
        # self.common_state.shared_expert_output = None
        return output
@@ -443,13 +489,22 @@ def build_layer_schedule_plan(layer, event, chunk_state, comp_stream, com_stream
    common_state = TransformerLayerState()
    attn = MoeAttnNode(chunk_state, common_state, layer, comp_stream, event)
    attn.name = "attn"
    dispatch = MoeDispatchNode(chunk_state, common_state, layer, com_stream, event, True)
    dispatch.name = "dispatch"
    mlp = MoeMlPNode(chunk_state, common_state, layer, comp_stream, event, True)
    mlp.name = "mlp"
    routed_expert = MoeRoutedExpertNode(chunk_state, common_state, layer, comp_stream, event, True)
    routed_expert.name = "routed_expert"
    shared_expert = MoeSharedExpertNode(chunk_state, common_state, layer, comp_stream, event, True)
    shared_expert.name = "shared_expert"
    combine = MoeCombineNode(chunk_state, common_state, layer, com_stream, event, True)
    combine.name = "combine"
    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine, shared_expert=shared_expert, routed_expert=routed_expert)
class TransformerLayerState(MoEAlltoAllPerBatchState):

@@ -462,11 +517,13 @@ class ModelChunkSate:

class TransformerLayerSchedulePlan:
    def __init__(self, attn, dispatch, mlp, combine):
    def __init__(self, attn, dispatch, mlp, combine, shared_expert=None, routed_expert=None):
        self.attn = attn
        self.dispatch = dispatch
        self.mlp = mlp
        self.combine = combine
        self.shared_expert = shared_expert
        self.routed_expert = routed_expert


class ModelChunkSchedulePlan(AbstractSchedulePlan):
@@ -577,7 +634,7 @@
    if b_layer is not None:
        with b_context:
            b_grad = b_layer.combine.backward(b_grad)
            routed_expert_output_grad, shared_expert_output_grad = b_layer.combine.backward(b_grad)

    if pre_backward_dw is not None:
        pre_backward_dw()
@@ -593,22 +650,36 @@
    if f_layer is not None:
        with f_context:
            shared_expert_output = f_layer.shared_expert.forward()
            f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
    # if f_layer is not None:
    #     with f_context:
    #         f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)

    if b_layer is not None:
        with b_context:
            b_grad = b_layer.mlp.backward(b_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
            # routed_expert_output_grad, shared_expert_output_grad = b_grad
            b_grad = b_layer.routed_expert.backward(routed_expert_output_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
            b_layer.shared_expert.backward(shared_expert_output_grad)
            b_grad = b_layer.dispatch.backward(b_grad)
            b_layer.mlp.dw()
            b_layer.routed_expert.dw()

    if f_layer is not None:
        with f_context:
            f_input = f_layer.mlp.forward(f_input)
            f_input = f_layer.routed_expert.forward(f_input)
    # if b_layer is not None:
    #     with b_context:
    #         # b_grad = b_layer.dispatch.backward(b_grad)
    #         b_layer.shared_expert.backward(shared_expert_output_grad)
    #         b_layer.routed_expert.dw()

    def next_iter_pre_forward():
        if f_layer is not None:
            with f_context:
                output = f_layer.combine.forward(f_input)
                output = f_layer.combine.forward((f_input, shared_expert_output))
            return output

    def next_iter_pre_backward():
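For readability, here is the forward/backward interleaving that schedule_layer_1f1b now produces for one forward layer f and one backward layer b, written out as a plain ordered trace. This is only a summary of the code above (stream/event synchronization and the pre_backward_dw hook are omitted), not additional logic.

# --- illustrative trace, not part of the commit ---
steps = [
    "b.combine.backward(b_grad) -> routed_expert_output_grad, shared_expert_output_grad",
    "f.shared_expert.forward()                            # shared experts on the compute stream",
    "f.dispatch.forward(f_input)                          # forward all-to-all",
    "b.routed_expert.backward(routed_expert_output_grad)",
    "b.shared_expert.backward(shared_expert_output_grad)",
    "b.dispatch.backward(b_grad)                          # backward all-to-all",
    "b.routed_expert.dw()                                 # routed-expert wgrad",
    "f.routed_expert.forward(f_input)",
    "f.combine.forward((f_input, shared_expert_output))   # deferred to next_iter_pre_forward",
]
for i, step in enumerate(steps, 1):
    print(f"{i:2d}. {step}")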
@@ -3,3 +3,10 @@ class MoELayer():
        self.experts.backward_dw()
        if self.use_shared_expert and not self.shared_expert_overlap:
            self.shared_experts.backward_dw()

    def backward_routed_expert_dw(self):
        self.experts.backward_dw()

    def backward_shared_expert_dw(self):
        if self.use_shared_expert and not self.shared_expert_overlap:
            self.shared_experts.backward_dw()
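backward_dw is kept for the unsplit path, while the two new methods let the scheduler place the routed-expert and shared-expert weight-gradient passes independently. A minimal runnable sketch of the idea follows; the _Stub* classes are illustrative stand-ins, not the Megatron modules.

# --- illustrative sketch, not part of the commit ---
class _StubExperts:
    # Records when its wgrad pass runs.
    def __init__(self, name, log):
        self.name, self.log = name, log

    def backward_dw(self):
        self.log.append(self.name)


class _StubMoELayer:
    def __init__(self, log):
        self.use_shared_expert = True
        self.shared_expert_overlap = False
        self.experts = _StubExperts("routed wgrad", log)
        self.shared_experts = _StubExperts("shared wgrad", log)

    # Mirrors the split introduced by this commit.
    def backward_routed_expert_dw(self):
        self.experts.backward_dw()

    def backward_shared_expert_dw(self):
        if self.use_shared_expert and not self.shared_expert_overlap:
            self.shared_experts.backward_dw()


log = []
layer = _StubMoELayer(log)
layer.backward_routed_expert_dw()   # can now follow the routed-expert backward directly
layer.backward_shared_expert_dw()   # shared-expert wgrad can be scheduled elsewhere
print(log)                          # ['routed wgrad', 'shared wgrad']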
@@ -91,7 +91,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        self.collect_per_batch_state(state)
        self.apply_per_batch_state(origin_state)

    def meta_prepare(
    def dispatch_preprocess(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
    ):
        self.hidden_shape = hidden_states.shape
@@ -103,9 +103,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        tokens_per_expert = self.preprocess(self.routing_map)
        return tokens_per_expert

    def dispatch_preprocess(self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor):
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

        if self.shared_experts is not None:
            self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
@@ -206,8 +203,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        """
        # Preprocess: Get the metadata for communication, permutation and computation operations.
        # Permutation 1: input to AlltoAll input
        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
        tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
        tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, probs, routing_map)

        # Perform expert parallel AlltoAll communication
        tokens_per_expert, global_input_tokens = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens)
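The former meta_prepare / dispatch_preprocess pair is thus merged into a single dispatch_preprocess(hidden_states, probs, routing_map) that returns both tokens_per_expert and the permuted tokens. A tiny stand-alone sketch of that data flow; the stub class and list-based "tensors" are illustrative only, not the real MoEAlltoAllTokenDispatcher.

# --- illustrative sketch, not part of the commit ---
class _StubDispatcher:
    def dispatch_preprocess(self, hidden_states, probs, routing_map):
        # probs is accepted to mirror the merged signature; this stub does not weight tokens.
        # Formerly meta_prepare(): derive per-expert token counts from the routing map.
        tokens_per_expert = [sum(col) for col in zip(*routing_map)]
        # Formerly the old dispatch_preprocess(): permute tokens into expert order.
        permuted = [tok for e in range(len(tokens_per_expert))
                    for tok, row in zip(hidden_states, routing_map) if row[e]]
        return tokens_per_expert, permuted


dispatcher = _StubDispatcher()
tokens = ["t0", "t1", "t2"]
routing_map = [[1, 0], [0, 1], [1, 0]]           # token -> expert assignment (2 experts)
probs = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
tokens_per_expert, permuted = dispatcher.dispatch_preprocess(tokens, probs, routing_map)
print(tokens_per_expert)  # [2, 1]
print(permuted)           # ['t0', 't2', 't1']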
@@ -189,11 +189,8 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
        tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
            pre_mlp_layernorm_output, probs, routing_map
        )
        tokens_per_expert, permutated_local_input_tokens = self.mlp.token_dispatcher.dispatch_preprocess(
            pre_mlp_layernorm_output, routing_map, tokens_per_expert
            pre_mlp_layernorm_output, probs, routing_map
        )

        outputs = [
@@ -205,15 +202,6 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        ]
        return tuple(outputs)

    def _submodule_shared_expert_forward(self, pre_mlp_layernorm_output):
        """
        Performs a forward pass for shared experts.
        """
        shared_expert_output = None
        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
        return shared_expert_output

    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens):
        """
        Dispatches tokens to the appropriate experts based on the router output.
@@ -253,6 +241,27 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
        return expert_output, shared_expert_output, mlp_bias

    def _submodule_shared_expert_forward(self, pre_mlp_layernorm_output):
        """
        Performs a forward pass for the shared experts.
        """
        shared_expert_output = None
        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
        return shared_expert_output

    def _submodule_routed_expert_forward(self, tokens_per_expert, global_input_tokens):
        """
        Performs a forward pass for the routed experts: post-processes the dispatched
        tokens, runs the expert MLPs, and pre-processes the expert output for combine.
        """
        (dispatched_input, tokens_per_expert) = (
            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens)
        )
        expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert)
        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
        return expert_output, mlp_bias

    def _submodule_combine_forward(self, hidden_states):
        return [self.mlp.token_dispatcher.combine_all_to_all(hidden_states)]
@@ -295,3 +304,9 @@ class TransformerLayer(MegatronCoreTransformerLayer):

    def _submodule_mlp_dw(self):
        self.mlp.backward_dw()

    def _submodule_routed_expert_dw(self):
        self.mlp.backward_routed_expert_dw()

    def _submodule_shared_expert_dw(self):
        self.mlp.backward_shared_expert_dw()
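At the TransformerLayer level, the old combined MLP forward, which returned (expert_output, shared_expert_output, mlp_bias), is now split: the shared output comes from _submodule_shared_expert_forward and the routed path returns only (expert_output, mlp_bias). A small illustrative sketch of that return-value split (plain functions and strings, not the real methods):

# --- illustrative sketch, not part of the commit ---
def shared_expert_forward(pre_mlp_layernorm_output):
    # Shared experts run on the full (undispatched) layernorm output.
    return f"shared_experts({pre_mlp_layernorm_output})"


def routed_expert_forward(tokens_per_expert, global_input_tokens):
    # Routed experts run on the dispatched tokens; no shared output is produced here.
    expert_output = f"experts({global_input_tokens}, tokens_per_expert={tokens_per_expert})"
    mlp_bias = None
    return expert_output, mlp_bias


shared_out = shared_expert_forward("ln_out")
routed_out, bias = routed_expert_forward([2, 1], "global_input_tokens")
print(shared_out)
print(routed_out, bias)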