Unverified commit a40aecc5, authored by fzyzcjy and committed by GitHub
Browse files

Fix num_qps_per_rank computation when providing custom DeepEP configuration (#6468)

parent d6e1d28c
......@@ -67,9 +67,9 @@ class DeepEPBuffer:
if deepep_mode.enable_normal():
hidden_bytes = hidden_size * param_bytes
for config in (
_DeepEPConfig.get_instance().normal_dispatch_config
DeepEPConfig.get_instance().normal_dispatch_config
or Buffer.get_dispatch_config(group.size()),
_DeepEPConfig.get_instance().normal_combine_config
DeepEPConfig.get_instance().normal_combine_config
or Buffer.get_combine_config(group.size()),
):
num_nvl_bytes = max(
......@@ -97,7 +97,12 @@ class DeepEPBuffer:
num_nvl_bytes,
num_rdma_bytes,
low_latency_mode=deepep_mode.enable_low_latency(),
num_qps_per_rank=(max(num_experts // group.size(), Buffer.num_sms // 2)),
num_qps_per_rank=(
max(
num_experts // group.size(),
DeepEPConfig.get_instance().num_sms // 2,
)
),
)
return cls._buffer
......@@ -122,7 +127,7 @@ class DeepEPBuffer:
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
class _DeepEPConfig:
class DeepEPConfig:
_instance = None
def __init__(self):
......@@ -131,16 +136,23 @@ class _DeepEPConfig:
config_parsed = load_json_config(config_str)
if torch.distributed.get_rank() == 0:
logger.info(f"Use DeepEP Config: {config_parsed}")
self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"])
self.normal_combine_config = Config(**config_parsed["normal_combine"])
config_dispatch = config_parsed["normal_dispatch"]
config_combine = config_parsed["normal_combine"]
self.normal_dispatch_config = Config(**config_dispatch)
self.normal_combine_config = Config(**config_combine)
assert config_dispatch["num_sms"] == config_combine["num_sms"]
self.num_sms = config_dispatch["num_sms"]
else:
self.normal_dispatch_config = None
self.normal_combine_config = None
self.num_sms = Buffer.num_sms
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = _DeepEPConfig()
cls._instance = DeepEPConfig()
return cls._instance
......@@ -326,7 +338,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish=self.async_finish,
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
config=_DeepEPConfig.get_instance().normal_dispatch_config,
config=DeepEPConfig.get_instance().normal_dispatch_config,
)
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
......@@ -433,7 +445,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish=self.async_finish,
previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None,
config=_DeepEPConfig.get_instance().normal_combine_config,
config=DeepEPConfig.get_instance().normal_combine_config,
)
return combined_x, event
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment