Unverified commit 91be5f9b, authored by bnellnm, committed by GitHub
Browse files

[MoE Refactor] Rename "naive" all2all backend (#36294)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent 4ee847e4
...@@ -103,7 +103,7 @@ To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE k ...@@ -103,7 +103,7 @@ To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE k
## Modular Kernel "families" ## Modular Kernel "families"
The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts. The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts.
| backend | `FusedMoEPrepareAndFinalizeModular` subclasses | `FusedMoEExpertsModular` subclasses | | backend | `FusedMoEPrepareAndFinalizeModular` subclasses | `FusedMoEExpertsModular` subclasses |
| ------- | ---------------------------------------------- | ----------------------------------- | | ------- | ---------------------------------------------- | ----------------------------------- |
......
...@@ -23,7 +23,6 @@ vLLM provides multiple communication backends for EP. Use `--all2all-backend` to ...@@ -23,7 +23,6 @@ vLLM provides multiple communication backends for EP. Use `--all2all-backend` to
| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios | | `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios |
| `flashinfer_nvlink_one_sided` | MNNVL systems | FlashInfer's one-sided A2A strategy for multi-node NVLink | High-throughput workloads | | `flashinfer_nvlink_one_sided` | MNNVL systems | FlashInfer's one-sided A2A strategy for multi-node NVLink | High-throughput workloads |
| `flashinfer_nvlink_two_sided` | MNNVL systems | FlashInfer's two-sided A2A strategy for multi-node NVLink | Systems with NVLink across nodes | | `flashinfer_nvlink_two_sided` | MNNVL systems | FlashInfer's two-sided A2A strategy for multi-node NVLink | Systems with NVLink across nodes |
| `naive` | Testing/debugging | Simple broadcast-based implementation | Debugging, not recommended for production |
## Single Node Deployment ## Single Node Deployment
......
...@@ -162,7 +162,6 @@ class ParallelConfig: ...@@ -162,7 +162,6 @@ class ParallelConfig:
all2all_backend: All2AllBackend = "allgather_reducescatter" all2all_backend: All2AllBackend = "allgather_reducescatter"
"""All2All backend for MoE expert parallel communication. Available options: """All2All backend for MoE expert parallel communication. Available options:
- "naive": Naive all2all implementation using broadcasts\n
- "allgather_reducescatter": All2all based on allgather and reducescatter\n - "allgather_reducescatter": All2all based on allgather and reducescatter\n
- "deepep_high_throughput": Use deepep high-throughput kernels\n - "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n - "deepep_low_latency": Use deepep low-latency kernels\n
...@@ -344,10 +343,11 @@ class ParallelConfig: ...@@ -344,10 +343,11 @@ class ParallelConfig:
f"but found: {self._api_process_rank}" f"but found: {self._api_process_rank}"
) )
if self.all2all_backend == "pplx": if self.all2all_backend in ["pplx", "naive"]:
logger.warning( logger.warning(
"The 'pplx' all2all backend has been removed. " "The '%s' all2all backend has been removed. "
"Falling back to 'allgather_reducescatter'." "Falling back to 'allgather_reducescatter'.",
self.all2all_backend,
) )
self.all2all_backend = "allgather_reducescatter" self.all2all_backend = "allgather_reducescatter"
...@@ -534,7 +534,6 @@ class ParallelConfig: ...@@ -534,7 +534,6 @@ class ParallelConfig:
self.all2all_backend self.all2all_backend
in ( in (
"allgather_reducescatter", "allgather_reducescatter",
"naive",
"deepep_high_throughput", "deepep_high_throughput",
"deepep_low_latency", "deepep_low_latency",
"mori", "mori",
...@@ -764,7 +763,7 @@ class ParallelConfig: ...@@ -764,7 +763,7 @@ class ParallelConfig:
) )
if ( if (
self.all2all_backend in ("allgather_reducescatter", "naive") self.all2all_backend in ("allgather_reducescatter")
and self.eplb_config.use_async and self.eplb_config.use_async
): ):
logger.warning( logger.warning(
......
...@@ -229,7 +229,7 @@ def maybe_make_prepare_finalize( ...@@ -229,7 +229,7 @@ def maybe_make_prepare_finalize(
num_dispatchers=all2all_manager.world_size, num_dispatchers=all2all_manager.world_size,
) )
elif moe.use_naive_all2all_kernels and allow_new_interface: elif moe.use_ag_rs_all2all_kernels and allow_new_interface:
prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep( prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep(
use_monolithic=use_monolithic, use_monolithic=use_monolithic,
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel, is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
......
...@@ -975,9 +975,10 @@ class FusedMoEParallelConfig: ...@@ -975,9 +975,10 @@ class FusedMoEParallelConfig:
return self.use_deepep_ll_kernels return self.use_deepep_ll_kernels
@property @property
def use_naive_all2all_kernels(self): def use_ag_rs_all2all_kernels(self):
return self.use_all2all_kernels and ( return (
self.all2all_backend in ["naive", "allgather_reducescatter"] self.use_all2all_kernels
and self.all2all_backend == "allgather_reducescatter"
) )
@property @property
...@@ -1143,7 +1144,7 @@ class FusedMoEParallelConfig: ...@@ -1143,7 +1144,7 @@ class FusedMoEParallelConfig:
ep_rank=0, ep_rank=0,
sp_size=1, sp_size=1,
use_ep=False, use_ep=False,
all2all_backend="naive", all2all_backend="allgather_reducescatter",
enable_eplb=False, enable_eplb=False,
) )
...@@ -1256,8 +1257,8 @@ class FusedMoEConfig: ...@@ -1256,8 +1257,8 @@ class FusedMoEConfig:
return self.moe_parallel_config.use_fi_nvl_one_sided_kernels return self.moe_parallel_config.use_fi_nvl_one_sided_kernels
@property @property
def use_naive_all2all_kernels(self): def use_ag_rs_all2all_kernels(self):
return self.moe_parallel_config.use_naive_all2all_kernels return self.moe_parallel_config.use_ag_rs_all2all_kernels
@property @property
def use_nixl_ep_kernels(self): def use_nixl_ep_kernels(self):
......
...@@ -79,7 +79,7 @@ class TrtLlmFp8ExpertsBase: ...@@ -79,7 +79,7 @@ class TrtLlmFp8ExpertsBase:
"""Monolithic kernel so only use with naive DP/EP and TP.""" """Monolithic kernel so only use with naive DP/EP and TP."""
return ( return (
not moe_parallel_config.use_all2all_kernels not moe_parallel_config.use_all2all_kernels
or moe_parallel_config.use_naive_all2all_kernels or moe_parallel_config.use_ag_rs_all2all_kernels
) and not moe_parallel_config.enable_eplb ) and not moe_parallel_config.enable_eplb
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment