Commit 1d497357 authored by dongcl

split routed experts and shared experts

parent 2ceeaafd
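
To see the shape of the change at a glance: the per-layer schedule plan keeps its fused mlp node but now also carries separate routed-expert and shared-expert nodes that can be scheduled independently. A minimal sketch, assuming simplified stand-in types (the real classes live in the scheduling module patched in the hunks below):

```python
# Hypothetical, simplified stand-ins for the schedule-plan types touched below;
# only the shape of the change is shown, not the real Megatron implementations.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class LayerSchedulePlanSketch:
    attn: Any
    dispatch: Any
    mlp: Any
    combine: Any
    shared_expert: Optional[Any] = None  # new: shared-expert compute node
    routed_expert: Optional[Any] = None  # new: routed-expert compute node


def build_layer_schedule_plan_sketch(attn, dispatch, mlp, routed_expert, shared_expert, combine):
    # Before this commit the plan only carried (attn, dispatch, mlp, combine);
    # now the routed and shared expert nodes travel alongside the fused mlp node
    # so their forward/backward/wgrad steps can be interleaved separately.
    return LayerSchedulePlanSketch(
        attn, dispatch, mlp, combine,
        shared_expert=shared_expert,
        routed_expert=routed_expert,
    )
```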
@@ -74,3 +74,9 @@ def a2a_overlap_adaptation(patches_manager):
    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
                                   MoELayer.backward_dw,
                                   create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_routed_expert_dw',
+                                   MoELayer.backward_routed_expert_dw,
+                                   create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_shared_expert_dw',
+                                   MoELayer.backward_shared_expert_dw,
+                                   create_dummy=True)
@@ -356,10 +356,54 @@ class MoeMlPNode(TransformerLayerNode):
            self.layer._submodule_mlp_dw()

+class MoeSharedExpertNode(TransformerLayerNode):
+    def forward_impl(self):
+        if self.layer.mlp.use_shared_expert:
+            pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
+        else:
+            pre_mlp_layernorm_output = None
+
+        token_dispatcher = self.layer.mlp.token_dispatcher
+        with token_dispatcher.per_batch_state_context(self.common_state):
+            shared_expert_output = self.layer._submodule_shared_expert_forward(
+                pre_mlp_layernorm_output
+            )
+        # pre_mlp_layernorm_output used
+        self.common_state.pre_mlp_layernorm_output = None
+        return shared_expert_output
+        # self.common_state.shared_expert_output = self.detach(shared_expert_output)
+
+    def dw(self):
+        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
+            self.layer._submodule_shared_expert_dw()
+
+
+class MoeRoutedExpertNode(TransformerLayerNode):
+    def forward_impl(self, global_input_tokens):
+        token_dispatcher = self.layer.mlp.token_dispatcher
+        with token_dispatcher.per_batch_state_context(self.common_state):
+            expert_output, mlp_bias = self.layer._submodule_routed_expert_forward(
+                self.common_state.tokens_per_expert, global_input_tokens
+            )
+        assert mlp_bias is None
+        return expert_output
+
+    def dw(self):
+        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
+            self.layer._submodule_routed_expert_dw()
+
+
class MoeCombineNode(TransformerLayerNode):
    def forward_impl(self, expert_output, shared_expert_output=None):
        # TODO(lhb): if dw uses grads of residual and probs, necessary synchronization should be added
        residual = self.common_state.residual
+        # shared_expert_output = None
+        # if self.layer.mlp.use_shared_expert:
+        #     shared_expert_output = self.common_state.shared_expert_output
        token_dispatcher = self.layer.mlp.token_dispatcher
        with token_dispatcher.per_batch_state_context(self.common_state):
            permutated_local_input_tokens = token_dispatcher.combine_all_to_all(
@@ -371,8 +415,10 @@ class MoeCombineNode(TransformerLayerNode):
        cur_stream = torch.cuda.current_stream()
        self.common_state.residual.record_stream(cur_stream)
        self.common_state.probs.record_stream(cur_stream)
+        # self.common_state.shared_expert_output.record_stream(cur_stream)
        self.common_state.residual = None
        self.common_state.probs = None
+        # self.common_state.shared_expert_output = None
        return output
@@ -443,13 +489,22 @@ def build_layer_schedule_plan(layer, event, chunk_state, comp_stream, com_stream
    common_state = TransformerLayerState()
    attn = MoeAttnNode(chunk_state, common_state, layer, comp_stream, event)
    attn.name = "attn"
    dispatch = MoeDispatchNode(chunk_state, common_state, layer, com_stream, event, True)
    dispatch.name = "dispatch"
    mlp = MoeMlPNode(chunk_state, common_state, layer, comp_stream, event, True)
    mlp.name = "mlp"
+    routed_expert = MoeRoutedExpertNode(chunk_state, common_state, layer, comp_stream, event, True)
+    routed_expert.name = "routed_expert"
+    shared_expert = MoeSharedExpertNode(chunk_state, common_state, layer, comp_stream, event, True)
+    shared_expert.name = "shared_expert"
    combine = MoeCombineNode(chunk_state, common_state, layer, com_stream, event, True)
    combine.name = "combine"
-    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
+    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine, shared_expert=shared_expert, routed_expert=routed_expert)


class TransformerLayerState(MoEAlltoAllPerBatchState):
@@ -462,11 +517,13 @@ class ModelChunkSate:

class TransformerLayerSchedulePlan:
-    def __init__(self, attn, dispatch, mlp, combine):
+    def __init__(self, attn, dispatch, mlp, combine, shared_expert=None, routed_expert=None):
        self.attn = attn
        self.dispatch = dispatch
        self.mlp = mlp
        self.combine = combine
+        self.shared_expert = shared_expert
+        self.routed_expert = routed_expert


class ModelChunkSchedulePlan(AbstractSchedulePlan):
@@ -577,7 +634,7 @@ def schedule_layer_1f1b(
    if b_layer is not None:
        with b_context:
-            b_grad = b_layer.combine.backward(b_grad)
+            routed_expert_output_grad, shared_expert_output_grad = b_layer.combine.backward(b_grad)

    if pre_backward_dw is not None:
        pre_backward_dw()
@@ -593,22 +650,36 @@ def schedule_layer_1f1b(
    if f_layer is not None:
        with f_context:
+            shared_expert_output = f_layer.shared_expert.forward()
            f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
+    # if f_layer is not None:
+    #     with f_context:
+    #         f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)

    if b_layer is not None:
        with b_context:
-            b_grad = b_layer.mlp.backward(b_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
+            # routed_expert_output_grad, shared_expert_output_grad = b_grad
+            b_grad = b_layer.routed_expert.backward(routed_expert_output_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
+            b_layer.shared_expert.backward(shared_expert_output_grad)
            b_grad = b_layer.dispatch.backward(b_grad)
-            b_layer.mlp.dw()
+            b_layer.routed_expert.dw()

    if f_layer is not None:
        with f_context:
-            f_input = f_layer.mlp.forward(f_input)
+            f_input = f_layer.routed_expert.forward(f_input)

+    # if b_layer is not None:
+    #     with b_context:
+    #         # b_grad = b_layer.dispatch.backward(b_grad)
+    #         b_layer.shared_expert.backward(shared_expert_output_grad)
+    #         b_layer.routed_expert.dw()

    def next_iter_pre_forward():
        if f_layer is not None:
            with f_context:
-                output = f_layer.combine.forward(f_input)
+                output = f_layer.combine.forward((f_input, shared_expert_output))
        return output

    def next_iter_pre_backward():
...
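
The hunk above is the core of the commit: in schedule_layer_1f1b the single mlp forward/backward is replaced by separate routed-expert and shared-expert steps, and combine now produces and consumes a pair of expert outputs/grads. Below is a hedged, print-only sketch of the resulting call order, using hypothetical stub nodes and collapsing the stream-sync events, contexts, nvtx ranges and the deferred next_iter_pre_forward hook into straight-line code:

```python
# Print-only sketch of the call order in the patched schedule_layer_1f1b.
# The node and plan classes here are hypothetical stubs, not the real types.
class _StubNode:
    def __init__(self, name: str):
        self.name = name

    def forward(self, *args, **kwargs):
        print(f"forward  {self.name}")
        return f"{self.name}_out"

    def backward(self, *args, **kwargs):
        print(f"backward {self.name}")
        return f"{self.name}_grad"

    def dw(self):
        print(f"wgrad    {self.name}")


class _StubLayerPlan:
    def __init__(self):
        for n in ("attn", "dispatch", "mlp", "combine", "shared_expert", "routed_expert"):
            setattr(self, n, _StubNode(n))


def schedule_layer_1f1b_sketch(f_layer, b_layer, f_input, b_grad):
    if b_layer is not None:
        # the patched combine.backward returns separate grads for the two expert paths
        grads = b_layer.combine.backward(b_grad)
        routed_expert_output_grad = shared_expert_output_grad = grads

    if f_layer is not None:
        # shared-expert compute is issued before the dispatch all-to-all it overlaps with
        shared_expert_output = f_layer.shared_expert.forward()
        f_input = f_layer.dispatch.forward(f_input)

    if b_layer is not None:
        b_grad = b_layer.routed_expert.backward(routed_expert_output_grad)
        b_layer.shared_expert.backward(shared_expert_output_grad)
        b_grad = b_layer.dispatch.backward(b_grad)
        b_layer.routed_expert.dw()

    if f_layer is not None:
        f_input = f_layer.routed_expert.forward(f_input)
        # in the real schedule this last step runs inside next_iter_pre_forward
        f_input = f_layer.combine.forward((f_input, shared_expert_output))

    return f_input, b_grad


if __name__ == "__main__":
    plan = _StubLayerPlan()
    schedule_layer_1f1b_sketch(plan, plan, "f_act", "b_grad")
```

Running the sketch prints the interleaving: combine backward of the previous micro-batch, then shared-expert forward and dispatch of the current one, then routed/shared expert backward, dispatch backward and routed-expert wgrad, and finally routed-expert forward feeding combine.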
@@ -3,3 +3,10 @@ class MoELayer():
        self.experts.backward_dw()
        if self.use_shared_expert and not self.shared_expert_overlap:
            self.shared_experts.backward_dw()
+
+    def backward_routed_expert_dw(self):
+        self.experts.backward_dw()
+
+    def backward_shared_expert_dw(self):
+        if self.use_shared_expert and not self.shared_expert_overlap:
+            self.shared_experts.backward_dw()
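
The MoELayer hunk above splits the weight-gradient step the same way. As a sanity check, a minimal stand-in (a hypothetical class, not the patched one) showing that the original backward_dw is just the two new, separately schedulable calls in sequence:

```python
# Hypothetical minimal stand-in for the MoELayer wgrad split shown above.
class MoELayerSketch:
    def __init__(self, experts, shared_experts, use_shared_expert=True, shared_expert_overlap=False):
        self.experts = experts
        self.shared_experts = shared_experts
        self.use_shared_expert = use_shared_expert
        self.shared_expert_overlap = shared_expert_overlap

    def backward_routed_expert_dw(self):
        # wgrad for the routed experts only
        self.experts.backward_dw()

    def backward_shared_expert_dw(self):
        # wgrad for the shared experts, unless shared-expert overlap already handles it
        if self.use_shared_expert and not self.shared_expert_overlap:
            self.shared_experts.backward_dw()

    def backward_dw(self):
        # original behaviour == routed wgrad followed by shared wgrad
        self.backward_routed_expert_dw()
        self.backward_shared_expert_dw()
```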
@@ -91,7 +91,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        self.collect_per_batch_state(state)
        self.apply_per_batch_state(origin_state)

-    def meta_prepare(
+    def dispatch_preprocess(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
    ):
        self.hidden_shape = hidden_states.shape
@@ -103,9 +103,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        tokens_per_expert = self.preprocess(self.routing_map)

-        return tokens_per_expert
-
-    def dispatch_preprocess(self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor):
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
        if self.shared_experts is not None:
            self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
@@ -206,8 +203,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
        """
        # Preprocess: Get the metadata for communication, permutation and computation operations.
        # Permutation 1: input to AlltoAll input
-        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
-        tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
+        tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, probs, routing_map)
        # Perform expert parallel AlltoAll communication
        tokens_per_expert, global_input_tokens = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens)
...
@@ -189,11 +189,8 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
-        tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
-            pre_mlp_layernorm_output, probs, routing_map
-        )
        tokens_per_expert, permutated_local_input_tokens = self.mlp.token_dispatcher.dispatch_preprocess(
-            pre_mlp_layernorm_output, routing_map, tokens_per_expert
+            pre_mlp_layernorm_output, probs, routing_map
        )

        outputs = [
@@ -205,15 +202,6 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        ]
        return tuple(outputs)

-    def _submodule_shared_expert_forward(self, pre_mlp_layernorm_output):
-        """
-        Performs a forward pass for shared experts.
-        """
-        shared_expert_output = None
-        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
-            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
-        return shared_expert_output
-
    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens):
        """
        Dispatches tokens to the appropriate experts based on the router output.
@@ -253,6 +241,27 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
        return expert_output, shared_expert_output, mlp_bias

+    def _submodule_shared_expert_forward(self, pre_mlp_layernorm_output):
+        """
+        Performs a forward pass for shared experts.
+        """
+        shared_expert_output = None
+        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
+            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
+        return shared_expert_output
+
+    def _submodule_routed_expert_forward(self, tokens_per_expert, global_input_tokens):
+        """
+        Performs a forward pass for the routed experts: dispatch postprocess,
+        expert computation, and combine preprocess.
+        """
+        (dispatched_input, tokens_per_expert) = (
+            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens)
+        )
+        expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert)
+        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
+        return expert_output, mlp_bias
+
    def _submodule_combine_forward(self, hidden_states):
        return [self.mlp.token_dispatcher.combine_all_to_all(hidden_states)]
@@ -295,3 +304,9 @@ class TransformerLayer(MegatronCoreTransformerLayer):
    def _submodule_mlp_dw(self):
        self.mlp.backward_dw()
+
+    def _submodule_routed_expert_dw(self):
+        self.mlp.backward_routed_expert_dw()
+
+    def _submodule_shared_expert_dw(self):
+        self.mlp.backward_shared_expert_dw()