Commit 0e35e124 authored by 王敏
Browse files

[fix]修复mori报错

parent d698d6f2
...@@ -4755,7 +4755,7 @@ class VllmConfig: ...@@ -4755,7 +4755,7 @@ class VllmConfig:
batch_size_capture_list = [] batch_size_capture_list = []
if self.model_config is not None and \ if self.model_config is not None and \
not self.model_config.enforce_eager: not self.model_config.enforce_eager:
if self.model_config.use_mla and self.compilation_config.full_cuda_graph and self.scheduler_config.max_num_seqs<=512: if self.model_config.use_mla and self.scheduler_config.max_num_seqs<=512:
cuda_graph_sizes = [self.scheduler_config.max_num_seqs] cuda_graph_sizes = [self.scheduler_config.max_num_seqs]
else: else:
cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
......
...@@ -87,6 +87,8 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -87,6 +87,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import DeepEPLLAll2AllManager from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.") logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "mori":
pass
else: else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}") raise ValueError(f"Unknown all2all backend: {all2all_backend}")
......
...@@ -369,7 +369,7 @@ class MoriMoE(FusedMoE): ...@@ -369,7 +369,7 @@ class MoriMoE(FusedMoE):
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token, num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size, expect_m=hidden_states.shape[0],
scales=dispatch_scales if self.use_int8_dispatch else None scales=dispatch_scales if self.use_int8_dispatch else None
) )
......
...@@ -293,7 +293,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -293,7 +293,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None, num_local_tokens: Optional[torch.Tensor] = None,
config_select_bs: Optional[int] = None, expect_m: Optional[int] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
scales: Optional[torch.Tensor] = None, scales: Optional[torch.Tensor] = None,
**_ **_
...@@ -320,7 +320,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -320,7 +320,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens, num_local_tokens=num_local_tokens,
config_select_bs=config_select_bs, expect_m=expect_m,
) )
def apply( def apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment