Unverified commit 9558f439, authored by Dao007forever and committed by GitHub
Browse files

[Bugfix] Size FlashInfer NVLink MNNVL workspace to EP group (#40893)


Signed-off-by: Dao Le <Dao007forever@gmail.com>
parent 8cd174fa
...@@ -492,15 +492,18 @@ class FlashInferNVLinkTwoSidedManager(All2AllManagerBase): ...@@ -492,15 +492,18 @@ class FlashInferNVLinkTwoSidedManager(All2AllManagerBase):
CustomCommunicator, CustomCommunicator,
) )
dp_config = MnnvlConfig( # MNNVL workspace is allocated per rank in the comm_backend's group; the
comm_backend=CustomCommunicator(get_dp_group().cpu_group), # flashinfer kernel asserts workspace.size(0) == moe_ep_size, so the backend
# must span the EP group (= DP*PCP*TP), not the DP group.
ep_config = MnnvlConfig(
comm_backend=CustomCommunicator(self.cpu_group),
fabric_page_size=1 << 29, # 512MB fabric_page_size=1 << 29, # 512MB
allocation_granularity=0, # Auto-detect allocation_granularity=0, # Auto-detect
) )
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config) self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, ep_config)
self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace( self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
self.mapping, dp_config self.mapping, ep_config
) )
self.world_size = world_size self.world_size = world_size
...@@ -605,8 +608,11 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase): ...@@ -605,8 +608,11 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
CustomCommunicator, CustomCommunicator,
) )
dp_config = MnnvlConfig( # MNNVL workspace is allocated per rank in the comm_backend's group; the
comm_backend=CustomCommunicator(get_dp_group().cpu_group), # flashinfer kernel asserts workspace.size(0) == moe_ep_size, so the backend
# must span the EP group (= DP*PCP*TP), not the DP group.
ep_config = MnnvlConfig(
comm_backend=CustomCommunicator(self.cpu_group),
) )
total_dispatch_payload_size_per_token = ( total_dispatch_payload_size_per_token = (
hidden_size // 2 # nvfp4 hidden states hidden_size // 2 # nvfp4 hidden states
...@@ -628,7 +634,7 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase): ...@@ -628,7 +634,7 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
top_k=top_k, top_k=top_k,
num_experts=num_experts, num_experts=num_experts,
workspace_size_per_rank=self.workspace_size, workspace_size_per_rank=self.workspace_size,
mnnvl_config=dp_config, mnnvl_config=ep_config,
) )
self.gpus_per_node = gpus_per_node self.gpus_per_node = gpus_per_node
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment