Unverified Commit a65f46be authored by Varun Sundar Rabindranath's avatar Varun Sundar Rabindranath Committed by GitHub
Browse files

[Misc] DeepGemmExperts : Avoid JIT generation in the hot-path (#21955)


Signed-off-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
parent 57393715
...@@ -126,6 +126,7 @@ if TYPE_CHECKING: ...@@ -126,6 +126,7 @@ if TYPE_CHECKING:
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_TPU_USING_PATHWAYS: bool = False VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM: bool = False
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_XGRAMMAR_CACHE_MB: int = 0
...@@ -910,6 +911,14 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -910,6 +911,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM": "VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
# startup time by a couple of minutes.
# Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup.
"VLLM_SKIP_DEEP_GEMM_WARMUP":
lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))),
# Allow use of FlashInfer MoE kernels for fused moe ops. # Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP8": "VLLM_USE_FLASHINFER_MOE_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
......
...@@ -4,7 +4,9 @@ import functools ...@@ -4,7 +4,9 @@ import functools
from typing import Any, Optional from typing import Any, Optional
import torch import torch
from tqdm import tqdm
import vllm.envs as env
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
...@@ -17,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -17,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm from vllm.utils import has_deep_gemm, run_once
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -82,6 +84,65 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, ...@@ -82,6 +84,65 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
return True return True
@run_once
def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
num_topk: int):
"""
DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
input tensor shapes. In this function, we construct all possible input
tensor shapes so all the kernels are JIT'ed and cached.
Note that this warmup is expected to happen during the model profile
call and not during actual model inference.
"""
assert w1.size(0) == w2.size(0), (
"w1 and w2 must have the same number of experts")
block_m = deep_gemm_block_shape()[0]
num_experts = w1.size(0)
device = w1.device
# This is the maximum GroupedGemm M size that we expect to run
# the grouped_gemm with.
MAX_M = compute_aligned_M(env.VLLM_FUSED_MOE_CHUNK_SIZE,
num_topk,
num_experts,
block_m,
expert_tokens_meta=None)
# Distribute expert-ids evenly.
MAX_BLOCKS = MAX_M // block_m
expert_ids_block = torch.randint(low=0,
high=num_experts,
size=(MAX_BLOCKS, ),
device=device,
dtype=torch.int32)
expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
_, n, k = w.size()
a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
a1q_scales = torch.empty((MAX_M, k // block_m),
device=device,
dtype=torch.float32)
out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
pbar = tqdm(total=MAX_BLOCKS,
desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})")
num_tokens = MAX_M
while num_tokens > 0:
m_grouped_fp8_gemm_nt_contiguous(
(a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
out[:num_tokens], expert_ids[:num_tokens])
pbar.update(1)
num_tokens = num_tokens - block_m
_warmup(w1, w1_scale)
_warmup(w2, w2_scale)
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self): def __init__(self):
...@@ -156,6 +217,20 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -156,6 +217,20 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
): ):
assert self.block_shape is not None assert self.block_shape is not None
assert a1q_scale is not None assert a1q_scale is not None
assert w1_scale is not None
assert w2_scale is not None
if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
# DeepGemm JITs the grouped-gemm kernels. We don't want the JIT'ing
# to happen during actual model-inference. The
# `warmup_deepgemm_kernels` function is a `run_once` decorated
# function that executes during the model profile run. This warmup
# should create all the required JITs for the current model.
warmup_deepgemm_gg_contiguous_kernels(w1,
w2,
w1_scale,
w2_scale,
num_topk=topk_ids.size(1))
a1q = hidden_states a1q = hidden_states
_, N, K = w1.size() _, N, K = w1.size()
......
...@@ -8,6 +8,7 @@ from __future__ import annotations ...@@ -8,6 +8,7 @@ from __future__ import annotations
import functools import functools
import importlib import importlib
import os
from typing import Any, Callable, NoReturn from typing import Any, Callable, NoReturn
import torch import torch
...@@ -77,6 +78,12 @@ def _lazy_init() -> None: ...@@ -77,6 +78,12 @@ def _lazy_init() -> None:
if not has_deep_gemm(): if not has_deep_gemm():
return return
# Set up deep_gemm cache path
DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR'
if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
envs.VLLM_CACHE_ROOT, "deep_gemm")
_dg = importlib.import_module("deep_gemm") _dg = importlib.import_module("deep_gemm")
_fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt", _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment