Unverified commit 2e6bc468, authored by Lucas Wilkinson and committed by GitHub
Browse files

[Startup] Make DeepGEMM warmup scale with max-num-batched-tokens (#24693)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent fcba05c4
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
from tqdm import tqdm from tqdm import tqdm
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.parallel_state import get_dp_group
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M, deep_gemm_block_shape) compute_aligned_M, deep_gemm_block_shape)
...@@ -131,11 +132,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, ...@@ -131,11 +132,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set() GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
w2: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor, num_topk: int, max_tokens: int):
w2_scale: torch.Tensor,
num_topk: int):
if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE): and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
return return
...@@ -147,9 +146,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, ...@@ -147,9 +146,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
num_experts = w1.size(0) num_experts = w1.size(0)
device = w1.device device = w1.device
# Assumes all ranks have the same max_num_batched_tokens
max_tokens_across_dp = get_dp_group().world_size * max_tokens
max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
# This is the maximum GroupedGemm M size that we expect to run # This is the maximum GroupedGemm M size that we expect to run
# the grouped_gemm with. # the grouped_gemm with.
MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE, MAX_M = compute_aligned_M(max_tokens,
num_topk, num_topk,
num_experts, num_experts,
block_m, block_m,
...@@ -201,7 +204,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): ...@@ -201,7 +204,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
_deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module,
max_tokens: int):
dg_modules = [ dg_modules = [
m for m in model.modules() m for m in model.modules()
if _fused_moe_grouped_gemm_may_use_deep_gemm(m) if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
...@@ -211,9 +215,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): ...@@ -211,9 +215,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
w13, w13_scale, w2, w2_scale, num_topk = ( w13, w13_scale, w2, w2_scale, num_topk = (
_extract_data_from_fused_moe_module(dgm)) _extract_data_from_fused_moe_module(dgm))
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
w13, w2, w13_scale, w2_scale, num_topk) w13, w2, w13_scale, w2_scale, num_topk, max_tokens)
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
deepgemm_fp8_gemm_nt_warmup(model, max_tokens) deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model) deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment