Unverified Commit fa98d777 authored by Varun Sundar Rabindranath's avatar Varun Sundar Rabindranath Committed by GitHub
Browse files

[Kernel] DeepEP dispatch-combine kernel integration (#18434)


Signed-off-by: default avatarVarun <vsundarr@redhat.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
parent 01eee405
......@@ -18,8 +18,8 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert prod(
v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly?
assert prod(v) <= x.numel(
), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" # CUDAGRAPH unfriendly?
return x.flatten()[:prod(v)].view(*v)
......
......@@ -3,7 +3,7 @@
import functools
import importlib.util
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
......@@ -452,6 +452,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm:
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
......@@ -460,8 +463,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once(
"DeepGemm not supported on the current platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)
......@@ -765,18 +770,39 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize):
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
experts: Optional[Union[BatchedTritonExperts,
TritonOrDeepGemmExperts]] = None
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
use_batched_experts = max_num_tokens_per_rank is not None
if use_batched_experts:
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank,
world_size=prepare_finalize.world_size,
dp_size=prepare_finalize.dp_size,
use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
assert experts is not None
return experts
def apply(
......@@ -797,6 +823,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
......@@ -808,6 +835,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.rocm_aiter_moe_enabled:
......@@ -855,7 +883,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
......
......@@ -154,6 +154,21 @@ class CudaPlatformBase(Platform):
logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.")
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and vllm_config.compilation_config.use_cudagraph):
logger.info(
"Data Parallel: Forcing enforce eager to be True since DP "
"with DeepEP high-throughput kernels are not CUDA Graph "
"compatible. The DeepEP low-latency kernels are CUDA Graph "
"compatible. Set the all_to_all backend to deepep_low_latency "
"to use those kernels instead.")
vllm_config.compilation_config.use_cudagraph = False
vllm_config.model_config.enforce_eager = True
# TODO (varun): Turning this ON gives incorrect results for the
# Deepseek-V2-lite model.
vllm_config.compilation_config.use_inductor = False
@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment