Unverified commit fa98d777, authored by Varun Sundar Rabindranath, committed by GitHub
Browse files

[Kernel] DeepEP dispatch-combine kernel integration (#18434)


Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent 01eee405
...@@ -18,8 +18,8 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: ...@@ -18,8 +18,8 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
Shrink the given tensor and apply the given view to it. This is Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches. used to resize the intermediate fused_moe caches.
""" """
assert prod( assert prod(v) <= x.numel(
v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly? ), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" # CUDAGRAPH unfriendly?
return x.flatten()[:prod(v)].view(*v) return x.flatten()[:prod(v)].view(*v)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import functools import functools
import importlib.util import importlib.util
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -452,6 +452,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -452,6 +452,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if envs.VLLM_USE_DEEP_GEMM: if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm: if not has_deep_gemm:
logger.warning_once("Failed to import DeepGemm kernels.") logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels")
elif (current_platform.is_cuda() elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)): and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
...@@ -460,8 +463,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -460,8 +463,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once( logger.warning_once(
"DeepGemm not supported on the current platform.") "DeepGemm not supported on the current platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore self.fused_experts = functools.partial( # type: ignore
fused_experts, fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm) allow_deep_gemm=self.allow_deep_gemm)
...@@ -765,18 +770,39 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -765,18 +770,39 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w2_input_scale del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize): def select_gemm_impl(self, prepare_finalize):
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts) TritonOrDeepGemmExperts)
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.") "Marlin and ROCm AITER are not supported with all2all yet.")
experts: Optional[Union[BatchedTritonExperts,
TritonOrDeepGemmExperts]] = None
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
use_batched_experts = max_num_tokens_per_rank is not None
if use_batched_experts:
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank,
world_size=prepare_finalize.world_size,
dp_size=prepare_finalize.dp_size,
use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
experts = TritonOrDeepGemmExperts( experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True, use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm, allow_deep_gemm=self.allow_deep_gemm,
) )
assert experts is not None
return experts return experts
def apply( def apply(
...@@ -797,6 +823,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -797,6 +823,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -808,6 +835,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -808,6 +835,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
...@@ -855,7 +883,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -855,7 +883,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map, expert_map=expert_map,
......
...@@ -154,6 +154,21 @@ class CudaPlatformBase(Platform): ...@@ -154,6 +154,21 @@ class CudaPlatformBase(Platform):
logger.info( logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.") "Forcing kv cache block size to 64 for FlashMLA backend.")
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and vllm_config.compilation_config.use_cudagraph):
logger.info(
"Data Parallel: Forcing enforce eager to be True since DP "
"with DeepEP high-throughput kernels are not CUDA Graph "
"compatible. The DeepEP low-latency kernels are CUDA Graph "
"compatible. Set the all_to_all backend to deepep_low_latency "
"to use those kernels instead.")
vllm_config.compilation_config.use_cudagraph = False
vllm_config.model_config.enforce_eager = True
# TODO (varun): Turning this ON gives incorrect results for the
# Deepseek-V2-lite model.
vllm_config.compilation_config.use_inductor = False
@classmethod @classmethod
def get_current_memory_usage(cls, def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None device: Optional[torch.types.Device] = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment