Unverified Commit a40aecc5 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix num_qps_per_rank computation when providing custom DeepEP configuration (#6468)

parent d6e1d28c
...@@ -67,9 +67,9 @@ class DeepEPBuffer: ...@@ -67,9 +67,9 @@ class DeepEPBuffer:
if deepep_mode.enable_normal(): if deepep_mode.enable_normal():
hidden_bytes = hidden_size * param_bytes hidden_bytes = hidden_size * param_bytes
for config in ( for config in (
_DeepEPConfig.get_instance().normal_dispatch_config DeepEPConfig.get_instance().normal_dispatch_config
or Buffer.get_dispatch_config(group.size()), or Buffer.get_dispatch_config(group.size()),
_DeepEPConfig.get_instance().normal_combine_config DeepEPConfig.get_instance().normal_combine_config
or Buffer.get_combine_config(group.size()), or Buffer.get_combine_config(group.size()),
): ):
num_nvl_bytes = max( num_nvl_bytes = max(
...@@ -97,7 +97,12 @@ class DeepEPBuffer: ...@@ -97,7 +97,12 @@ class DeepEPBuffer:
num_nvl_bytes, num_nvl_bytes,
num_rdma_bytes, num_rdma_bytes,
low_latency_mode=deepep_mode.enable_low_latency(), low_latency_mode=deepep_mode.enable_low_latency(),
num_qps_per_rank=(max(num_experts // group.size(), Buffer.num_sms // 2)), num_qps_per_rank=(
max(
num_experts // group.size(),
DeepEPConfig.get_instance().num_sms // 2,
)
),
) )
return cls._buffer return cls._buffer
...@@ -122,7 +127,7 @@ class DeepEPBuffer: ...@@ -122,7 +127,7 @@ class DeepEPBuffer:
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
class _DeepEPConfig: class DeepEPConfig:
_instance = None _instance = None
def __init__(self): def __init__(self):
...@@ -131,16 +136,23 @@ class _DeepEPConfig: ...@@ -131,16 +136,23 @@ class _DeepEPConfig:
config_parsed = load_json_config(config_str) config_parsed = load_json_config(config_str)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
logger.info(f"Use DeepEP Config: {config_parsed}") logger.info(f"Use DeepEP Config: {config_parsed}")
self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"]) config_dispatch = config_parsed["normal_dispatch"]
self.normal_combine_config = Config(**config_parsed["normal_combine"]) config_combine = config_parsed["normal_combine"]
self.normal_dispatch_config = Config(**config_dispatch)
self.normal_combine_config = Config(**config_combine)
assert config_dispatch["num_sms"] == config_combine["num_sms"]
self.num_sms = config_dispatch["num_sms"]
else: else:
self.normal_dispatch_config = None self.normal_dispatch_config = None
self.normal_combine_config = None self.normal_combine_config = None
self.num_sms = Buffer.num_sms
@classmethod @classmethod
def get_instance(cls): def get_instance(cls):
if cls._instance is None: if cls._instance is None:
cls._instance = _DeepEPConfig() cls._instance = DeepEPConfig()
return cls._instance return cls._instance
...@@ -326,7 +338,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -326,7 +338,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish=self.async_finish, async_finish=self.async_finish,
allocate_on_comm_stream=(previous_event is not None) and self.async_finish, allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1, expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
config=_DeepEPConfig.get_instance().normal_dispatch_config, config=DeepEPConfig.get_instance().normal_dispatch_config,
) )
get_global_expert_distribution_recorder().on_deepep_dispatch_normal( get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
...@@ -433,7 +445,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -433,7 +445,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish=self.async_finish, async_finish=self.async_finish,
previous_event=previous_event, previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None, allocate_on_comm_stream=previous_event is not None,
config=_DeepEPConfig.get_instance().normal_combine_config, config=DeepEPConfig.get_instance().normal_combine_config,
) )
return combined_x, event return combined_x, event
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment