Commit 0e35e124 authored by 王敏's avatar 王敏
Browse files

[fix]修复mori报错

parent d698d6f2
......@@ -4755,7 +4755,7 @@ class VllmConfig:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
if self.model_config.use_mla and self.compilation_config.full_cuda_graph and self.scheduler_config.max_num_seqs<=512:
if self.model_config.use_mla and self.scheduler_config.max_num_seqs<=512:
cuda_graph_sizes = [self.scheduler_config.max_num_seqs]
else:
cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
......
......@@ -87,6 +87,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "mori":
pass
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
......
......@@ -369,7 +369,7 @@ class MoriMoE(FusedMoE):
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size,
expect_m=hidden_states.shape[0],
scales=dispatch_scales if self.use_int8_dispatch else None
)
......
......@@ -293,7 +293,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
config_select_bs: Optional[int] = None,
expect_m: Optional[int] = None,
routed_scaling_factor: Optional[float] = None,
scales: Optional[torch.Tensor] = None,
**_
......@@ -320,7 +320,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens,
config_select_bs=config_select_bs,
expect_m=expect_m,
)
def apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment