Unverified Commit 54411f6a authored by JieXin Liang, committed by GitHub

fix: disable dsv3_router_gemm in dsv3_nextn (#7793)

parent 625018d2
@@ -210,8 +210,10 @@ class MoEGate(nn.Module):
         self,
         config,
         prefix: str = "",
+        is_nextn: bool = False,
     ):
         super().__init__()
+        self.is_nextn = is_nextn
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
@@ -233,8 +235,10 @@ class MoEGate(nn.Module):
                 True,  # is_vnni
             )
 
+        # NOTE: For some unknown reason, router_gemm seems to degrade the accept length.
         if (
             _is_cuda
+            and not self.is_nextn
             and hidden_states.shape[0] < 4
             and hidden_states.shape[1] == 7168
             and self.weight.shape[0] == 256
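For context, this guard in MoEGate.forward chooses between the fused router GEMM kernel and the plain linear routing path. Below is a minimal sketch of that selection logic, not the verbatim source: the sgl_kernel import path and the dsv3_router_gemm(hidden_states, weight) call shape are assumptions, and the real forward also handles CPU and quantized paths.

import torch
import torch.nn.functional as F

def route_logits(hidden_states: torch.Tensor, weight: torch.Tensor,
                 is_nextn: bool, is_cuda: bool = True) -> torch.Tensor:
    # Mirrors the gating condition from the hunk above.
    use_fused_router_gemm = (
        is_cuda
        and not is_nextn                     # new guard: nextn (MTP) layers opt out
        and hidden_states.shape[0] < 4       # tiny decode batches only
        and hidden_states.shape[1] == 7168   # DeepSeek-V3 hidden size
        and weight.shape[0] == 256           # DeepSeek-V3 routed-expert count
    )
    if use_fused_router_gemm:
        from sgl_kernel import dsv3_router_gemm  # assumed import path and signature
        return dsv3_router_gemm(hidden_states, weight)
    return F.linear(hidden_states, weight, None)  # generic routing fallback

With is_nextn=True, the nextn draft layer always takes the F.linear fallback, which, per the NOTE above, avoids the observed drop in accept length.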
@@ -258,6 +262,7 @@ class DeepseekV2MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
+        is_nextn: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -284,7 +289,9 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
-        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
+        self.gate = MoEGate(
+            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+        )
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
@@ -1776,6 +1783,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
+                is_nextn=is_nextn,
             )
         else:
             if enable_moe_dense_fully_dp():
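Read top-down, the change threads one flag through three constructors so the gate can tell whether it lives in a nextn (MTP draft) layer. A stripped-down sketch of the propagation, with all unrelated constructor arguments omitted for illustration:

import torch.nn as nn

class MoEGate(nn.Module):
    def __init__(self, config, prefix: str = "", is_nextn: bool = False):
        super().__init__()
        self.is_nextn = is_nextn  # consulted in forward() to bypass dsv3_router_gemm

class DeepseekV2MoE(nn.Module):
    def __init__(self, config, prefix: str = "", is_nextn: bool = False):
        super().__init__()
        self.gate = MoEGate(config=config, prefix=prefix, is_nextn=is_nextn)

class DeepseekV2DecoderLayer(nn.Module):
    def __init__(self, config, prefix: str = "", is_nextn: bool = False):
        super().__init__()
        self.mlp = DeepseekV2MoE(config=config, prefix=prefix, is_nextn=is_nextn)

# A nextn wrapper (e.g. a DeepSeek-V3 NextN causal-LM class; name assumed here)
# would construct its decoder layer with is_nextn=True, while regular decoder
# layers keep the default False and continue to use the fused kernel.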