Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
def pytest_addoption(parser):
"""Add custom command line options."""
parser.addoption("--config-list-file",
default="configs/models-small.txt",
help="File containing list of config files to test")
parser.addoption("--tp-size",
default=1,
type=int,
help="Tensor parallel size")
def pytest_generate_tests(metafunc):
"""Generate test parameters from config files."""
if "config_filename" in metafunc.fixturenames:
config_list_file = metafunc.config.getoption("--config-list-file")
tp_size = metafunc.config.getoption("--tp-size")
# Handle both relative and absolute paths
config_list_path = Path(config_list_file)
if not config_list_path.is_absolute():
# If relative, try relative to test directory first
test_dir_path = Path(__file__).parent / config_list_file
if test_dir_path.exists():
config_list_path = test_dir_path
else:
# Try relative to current working directory
config_list_path = Path.cwd() / config_list_file
print(f"Looking for config list at: {config_list_path}")
config_files = []
if config_list_path.exists():
# Determine config directory (same directory as the list file)
config_dir = config_list_path.parent
with open(config_list_path) as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
config_path = config_dir / line
print(f"Checking config file: {config_path}")
if config_path.exists():
config_files.append(config_path)
print(f" ✓ Found: {config_path}")
else:
print(f" ✗ Missing: {config_path}")
else:
print(f"Config list file not found: {config_list_path}")
# Generate test parameters
if config_files:
metafunc.parametrize(["config_filename", "tp_size"],
[(config_file, int(tp_size))
for config_file in config_files],
ids=[
f"{config_file.stem}-tp{tp_size}"
for config_file in config_files
])
else:
print("No config files found, test will be skipped")
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Isolated GSM8K evaluation script for vLLM serve endpoint.
"""
import argparse
import ast
import asyncio
import json
import os
import time
from collections.abc import Generator
from typing import Optional, Union
import aiohttp
import numpy as np
import regex as re
import requests
from tqdm.asyncio import tqdm
INVALID = -9999999
def download_and_cache_file(url: str, filename: Optional[str] = None) -> str:
"""Download and cache a file from a URL."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])
if os.path.exists(filename):
return filename
print(f"Downloading from {url} to {filename}")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(filename, "wb") as f:
for chunk in response.iter_content(chunk_size=1024):
f.write(chunk)
return filename
def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
"""Load GSM8K train and test data"""
train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl"
test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
train_file = download_and_cache_file(train_url)
test_file = download_and_cache_file(test_url)
train_data = list(read_jsonl(train_file))
test_data = list(read_jsonl(test_file))
return train_data, test_data
def read_jsonl(filename: str) -> Generator[dict, None, None]:
"""Read a JSONL file."""
with open(filename) as fin:
for line in fin:
if not line.startswith("#"):
yield json.loads(line)
def get_answer_value(answer_str: str) -> int:
"""Extract the numerical answer from the response."""
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
async def call_vllm_api(session: aiohttp.ClientSession,
prompt: str,
temperature: float,
max_tokens: int,
stop: Optional[list[str]] = None,
url: Optional[str] = None,
seed: Optional[int] = None) -> str:
"""Call vLLM's OpenAI-compatible completions endpoint."""
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"stop": stop,
}
if seed is not None:
data["seed"] = seed
try:
async with session.post(f"{url}/v1/completions",
json=data) as response:
response.raise_for_status()
result = await response.json()
return result["choices"][0]["text"]
except Exception as e:
print(f"Error calling vLLM API: {e}")
return ""
def evaluate_gsm8k(num_questions: int = 1319,
num_shots: int = 5,
max_tokens: int = 256,
host: str = "http://127.0.0.1",
port: int = 8000,
temperature: float = 0.0,
seed: Optional[int] = 42) -> dict[str, Union[float, int]]:
"""
Evaluate GSM8K accuracy using vLLM serve endpoint.
Returns dict with accuracy, invalid_rate, latency, etc.
"""
base_url = f"{host}:{port}"
# Load GSM8K train and test data
train_data, test_data = load_gsm8k_data()
# Limit to available test questions
num_questions = min(num_questions, len(test_data))
# Build few-shot examples from train split (like lm-eval does)
few_shot_examples = ""
for i in range(num_shots):
few_shot_examples += (f"Question: {train_data[i]['question']}\n"
f"Answer: {train_data[i]['answer']}\n\n")
# Prepare test questions and labels from test split
questions = []
labels = []
for i in range(num_questions):
questions.append(f"Question: {test_data[i]['question']}\nAnswer:")
labels.append(get_answer_value(test_data[i]["answer"]))
assert all(label != INVALID for label in labels), "Some labels are invalid"
# Run evaluation
async def run_async_evaluation():
states: list[str] = [""] * num_questions
async def get_answer(session: aiohttp.ClientSession, i: int) -> str:
prompt = few_shot_examples + questions[i]
answer = await call_vllm_api(
session=session,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
stop=["Question", "Assistant:", "<|separator|>"],
url=base_url,
seed=seed,
)
states[i] = answer
return answer
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
total=600)) as session:
tasks = [get_answer(session, i) for i in range(num_questions)]
await tqdm.gather(*tasks, desc="Evaluating")
return states
print(f"Running GSM8K evaluation: {num_questions} questions, "
f"{num_shots}-shot")
tic = time.perf_counter()
states = asyncio.run(run_async_evaluation())
latency = time.perf_counter() - tic
# Compute metrics
preds = [get_answer_value(state) for state in states]
accuracy = np.mean(np.array(preds) == np.array(labels))
invalid_rate = np.mean(np.array(preds) == INVALID)
result = {
"accuracy": accuracy,
"invalid_rate": invalid_rate,
"latency": latency,
"questions_per_second": num_questions / latency,
"num_questions": num_questions,
"num_shots": num_shots,
"max_tokens": max_tokens,
"timestamp": time.time(),
}
return result
def main() -> None:
parser = argparse.ArgumentParser(
description="GSM8K evaluation for vLLM serve")
parser.add_argument("--num-shots",
type=int,
default=5,
help="Number of few-shot examples")
parser.add_argument("--num-questions",
type=int,
default=1319,
help="Number of questions to evaluate")
parser.add_argument("--max-tokens",
type=int,
default=256,
help="Max tokens for generation")
parser.add_argument("--host",
type=str,
default="http://127.0.0.1",
help="Host URL")
parser.add_argument("--port", type=int, default=8000, help="Port number")
parser.add_argument("--temperature",
type=float,
default=0.0,
help="Temperature for generation")
parser.add_argument("--seed",
type=int,
default=42,
help="Random seed for reproducibility")
parser.add_argument("--save-results",
type=str,
help="Save results to JSON file")
args = parser.parse_args()
result = evaluate_gsm8k(
num_questions=args.num_questions,
num_shots=args.num_shots,
max_tokens=args.max_tokens,
host=args.host,
port=args.port,
temperature=args.temperature,
seed=args.seed,
)
# Print results to terminal
print("\nResults:")
print(f"Accuracy: {result['accuracy']:.3f}")
print(f"Invalid responses: {result['invalid_rate']:.3f}")
print(f"Total latency: {result['latency']:.3f} s")
print(f"Questions per second: {result['questions_per_second']:.3f}")
# Optional file saving
if args.save_results:
with open(args.save_results, "w") as f:
json.dump(result, f, indent=2)
print(f"Results saved to {args.save_results}")
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GSM8K evaluation using vLLM server and isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control.
Usage:
pytest -s -v test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt \
--tp-size=1
"""
import yaml
from tests.utils import RemoteOpenAIServer
from .gsm8k_eval import evaluate_gsm8k
RTOL = 0.08 # Relative tolerance for accuracy comparison
def launch_gsm8k_eval(eval_config, server_url, tp_size):
"""Launch GSM8K evaluation using our isolated script."""
# Extract host and port from server URL
if "://" in server_url:
server_url = server_url.split("://")[1]
host_port = server_url.split("/")[0] # Remove path if present
if ":" in host_port:
host, port = host_port.split(":")
port = int(port)
else:
host = host_port
port = 8000
# Add http:// prefix if not present
if not host.startswith("http"):
host = f"http://{host}"
# Run GSM8K evaluation
results = evaluate_gsm8k(
num_questions=eval_config["num_questions"],
num_shots=eval_config["num_fewshot"],
host=host,
port=port,
)
return results
def test_gsm8k_correctness_param(config_filename, tp_size):
"""Test GSM8K correctness for a given model configuration."""
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
# Server arguments
server_args = [
"--max-model-len",
str(eval_config.get("max_model_len", 4096)),
"--enforce-eager",
"--trust-remote-code",
"--tensor-parallel-size",
str(tp_size),
]
# Launch server and run evaluation
with RemoteOpenAIServer(eval_config["model_name"],
server_args,
max_wait_seconds=480) as remote_server:
server_url = remote_server.url_for("v1")
results = launch_gsm8k_eval(eval_config, server_url, tp_size)
# Check accuracy against threshold
measured_accuracy = results["accuracy"]
expected_accuracy = eval_config["accuracy_threshold"]
print(f"GSM8K Results for {eval_config['model_name']}:")
print(f" Accuracy: {measured_accuracy:.3f}")
print(f" Expected: {expected_accuracy:.3f}")
print(f" Questions: {results['num_questions']}")
print(f" Invalid rate: {results['invalid_rate']:.3f}")
print(f" Latency: {results['latency']:.1f}s")
print(f" QPS: {results['questions_per_second']:.1f}")
# Verify accuracy is within tolerance
assert measured_accuracy >= expected_accuracy - RTOL, (
f"Accuracy too low: {measured_accuracy:.3f} < "
f"{expected_accuracy:.3f} - {RTOL:.3f}")
print(f"✅ GSM8K test passed for {eval_config['model_name']}")
......@@ -80,6 +80,9 @@ def test_env(
m.setenv(STR_BACKEND_ENV_VAR, name)
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
if name == "FLASHINFER" and not use_v1:
pytest.skip("FlashInfer backend is only available on V1 engine")
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
......
......@@ -702,6 +702,94 @@ def test_swap_blocks_mla(
f"{dst} in dst_cache.")
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
block_size, num_blocks,
max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0,
max_seq_len + 1, (batch_size, ),
device=device)
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)
for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm
dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)
expected_batches = []
for b in range(batch_size):
s = seq_len_tensor[b]
if s == 0:
continue
tot = tot_blocks_tensor[b]
blocks = block_table[b, :tot].tolist()
gathered_rows = []
for i in range(tot - 1):
block_data = src_cache[blocks[i]]
if kv_cache_dtype == "fp8":
dequantized_block = torch.empty_like(block_data, dtype=dtype)
ops.convert_fp8(dequantized_block, block_data, scale.item())
gathered_rows.append(dequantized_block)
else:
gathered_rows.append(block_data)
remaining = s - (tot - 1) * block_size
last_block_data = src_cache[blocks[-1], :remaining, :]
if kv_cache_dtype == "fp8":
dequantized_last_block = torch.empty_like(last_block_data,
dtype=dtype)
ops.convert_fp8(dequantized_last_block, last_block_data,
scale.item())
gathered_rows.append(dequantized_last_block)
else:
gathered_rows.append(last_block_data)
batch_expected = torch.cat(gathered_rows, dim=0)
expected_batches.append(batch_expected)
expected = torch.cat(expected_batches, dim=0)
opcheck(
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, kv_cache_dtype,
scale, None)
torch.testing.assert_close(dst, expected)
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
......@@ -713,9 +801,9 @@ def test_swap_blocks_mla(
["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
......@@ -765,12 +853,12 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
expected = torch.cat(expected_batches, dim=0)
opcheck(
torch.ops._C_cache_ops.gather_cache,
torch.ops._C_cache_ops.cp_gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected)
......
......@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=(
(num_query_heads//num_kv_heads) > 4)
)
use_tensor_cores=True)
wrapper.plan(
kv_indptr,
kv_indices,
......@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
use_tensor_cores = (num_query_heads // num_kv_heads) > 4
use_tensor_cores = True
kv_cache_dtype = torch.float8_e4m3fn
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
......
......@@ -6,28 +6,19 @@ import flashinfer
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm.platforms import current_platform
from vllm.utils import round_up
if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
allow_module_level=True)
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
MAX_Q_LEN = 1024
MAX_KV_LEN = 4096
BATCH_SIZES = [4, 12]
NUM_HEADS = [(16, 16), (40, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16]
KV_LAYOUTS = ["HND"]
DTYPES = [torch.bfloat16]
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 50.0]
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn):
......@@ -39,42 +30,61 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
return x_scl_sat.to(dtype), scale.float().reciprocal()
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
DTYPE = [torch.bfloat16]
QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)]
NUM_HEADS = [(64, 8), (40, 8)]
HEAD_SIZE = [128]
KV_LAYOUT = ["HND"] # currently only HND is supported
BLOCK_SIZE = [16]
SOFT_CAP = [None, 50.0]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
block_size: int,
kv_layout: str,
dtype: torch.dtype,
kv_cache_dtype: Optional[torch.dtype],
block_size: int,
soft_cap: Optional[float],
) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
torch.set_default_device("cuda")
current_platform.seed_everything(0)
kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = MAX_KV_LEN
max_kv_len = torch.max(kv_lens).item()
num_seqs = len(kv_lens)
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
_, max_kv_len = max_seq_lens
scale = head_size**-0.5
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
sm_scale = float(1.0 / (head_size**0.5))
kv_cache_shape = None
if kv_layout == "NHD":
......@@ -83,23 +93,40 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
kv_scale = 1.0
if kv_cache_dtype is current_platform.fp8_dtype():
key_value_cache, kv_scale = to_float8(key_value_cache,
current_platform.fp8_dtype())
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
(batch_size, max_num_blocks_per_seq),
dtype=torch.int32)
k_scale = v_scale = kv_scale
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
......@@ -112,103 +139,120 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Decode
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
workspace_buffer, kv_layout, use_tensor_cores=True)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=scale,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype)
wrapper.run(query,
key_value_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output)
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Decode
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
output_trtllm = torch.empty(query.shape, dtype=dtype)
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query.contiguous(),
kv_cache=key_value_cache,
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=kv_lens_tensor,
max_seq_len=max_kv_len,
bmm1_scale=k_scale * scale,
bmm2_scale=v_scale,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 2e-2
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
block_size: int,
kv_layout: str,
dtype: torch.dtype,
kv_cache_dtype: Optional[torch.dtype],
block_size: int,
soft_cap: Optional[float],
) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if dtype != kv_cache_dtype:
pytest.skip(f"Not supported dtype({dtype}) with "
"kv_cache_dtype({kv_cache_dtype})")
torch.set_default_device("cuda")
current_platform.seed_everything(0)
q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32)
q_lens[-1] = MAX_Q_LEN
max_q_len = torch.max(q_lens).item()
q_indptr = torch.cat([
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = MAX_KV_LEN
if q_quant_dtype != kv_quant_dtype:
pytest.skip("Skipped mixed QKV dtypes for prefill")
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
num_seqs = len(seq_lens)
max_q_len, max_kv_len = max_seq_lens
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
scale = head_size**-0.5
query = torch.randn(torch.sum(q_lens).item(),
num_query_heads,
head_size,
dtype=dtype)
sm_scale = float(1.0 / (head_size**0.5))
kv_cache_shape = None
if kv_layout == "NHD":
......@@ -217,22 +261,49 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
kv_scale = 1.0
if kv_cache_dtype is current_platform.fp8_dtype():
key_value_cache, kv_scale = to_float8(key_value_cache,
current_platform.fp8_dtype())
q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32)
q_lens[-1] = max_q_len
q_indptr = torch.cat([
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
query = torch.randn(torch.sum(q_lens).item(),
num_qo_heads,
head_size,
dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
(batch_size, max_num_blocks_per_seq),
dtype=torch.int32)
k_scale = v_scale = kv_scale
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
......@@ -246,48 +317,81 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Prefill
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout)
wrapper.plan(q_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
causal=True,
sm_scale=scale,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype)
wrapper.run(query,
key_value_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output)
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Prefill
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
# TRTLLM Decode
output_trtllm = torch.empty(query.shape, dtype=dtype)
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=query.contiguous(),
kv_cache=key_value_cache,
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_q_len=max_q_len,
max_kv_len=max_seq_len,
bmm1_scale=k_scale * scale,
bmm2_scale=v_scale,
batch_size=num_seqs,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 4e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"
......@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
from vllm.triton_utils import triton
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
def cal_diff(x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
assert cos_diff < 1e-5
if (use_fp8):
assert cos_diff < 1e-4
else:
assert cos_diff < 1e-5
FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
if not is_flashmla_supported()[0] else "FlashMLA is supported"
......@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192])
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576])
......@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn])
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, dtype):
varlen, torch_dtype):
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")
use_fp8 = torch_dtype == torch.float8_e4m3fn
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
......@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv)
init_dtype = q.dtype
if use_fp8:
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
q = q.to(fp8_dtype)
blocked_k = blocked_k.to(fp8_dtype)
blocked_v = blocked_v.to(fp8_dtype)
else:
descale_q = None
descale_k = None
def flash_mla():
return flash_mla_with_kvcache(
q,
......@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
)
def scaled_dot_product_attention(query, key, value, is_causal=False):
......@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
return attn_weight @ value, lse
def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_v
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
ref_O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
out_i, lse_i = scaled_dot_product_attention(
q_[i].transpose(0, 1),
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
is_causal=causal,
)
out[i] = ref_O.transpose(0, 1)
lse[i] = LSE
out[i] = out_i.transpose(0, 1)
lse[i] = lse_i
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out")
cal_diff(out_flash, out_torch, "out", use_fp8)
cal_diff(lse_flash, lse_torch, "lse")
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
bytes = (total_seqlens * h_kv * d +
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
f"{bytes / 10 ** 6 / t:.0f} GB/s")
......@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
dg_available = has_deep_gemm()
......@@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
reason="Not E8M0 scale MOE")
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
......
......@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
......@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
topk_ids[0][1] = 1
workspace13_shape = (m * topk, max(2 * n, k))
workspace2_shape = (m * topk, n)
output_shape = (m * topk, k)
workspace2_shape = (m * topk, max(n, k))
output_shape = (m, k)
workspace13 = torch.empty(prod(workspace13_shape),
device="cuda",
......@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
......@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False, topk_weights)
workspace13.random_()
output_random_workspace = torch.empty(output_shape,
......
......@@ -20,9 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
is_deep_gemm_supported)
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
......@@ -370,9 +370,10 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]):
......@@ -427,9 +428,10 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
......
......@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
if has_deep_ep():
......@@ -411,6 +412,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
def test_deep_ep_moe(
dtype: torch.dtype,
......@@ -459,6 +461,7 @@ USE_FP8_DISPATCH = [True, False]
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
num_experts: int, topk: int,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import pytest
import torch
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
input_to_float8)
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
NUM_EXPERTS = [16]
TOP_KS = [1]
MNK_FACTORS = [
(256, 8192, 5120),
(256, 4096, 5120),
(127, 8192, 5120),
(127, 4096, 5120),
(10, 8192, 5120),
(10, 4096, 5120),
(1, 8192, 5120),
(1, 4096, 5120),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def quant_fp8_per_tensor_batches(a):
num_batches = a.size(0)
a_quant = []
a_scales = []
for i in range(num_batches):
a_fp8, a_global_sf = input_to_float8(a[i])
a_global_sf = 1.0 / a_global_sf
a_quant.append(a_fp8)
a_scales.append(a_global_sf)
result_a_quant = torch.stack(a_quant)
result_a_scales = torch.stack(a_scales)
return result_a_quant, result_a_scales
@dataclass
class TestData:
hidden_states: torch.Tensor
w13_quantized: torch.Tensor
w2_quantized: torch.Tensor
a1_scale: torch.Tensor
a2_scale: torch.Tensor
w13_weight_scale: torch.Tensor
w2_weight_scale: torch.Tensor
layer: torch.nn.Module
@staticmethod
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
reorder: bool) -> "TestData":
hidden_states = torch.randn(
(m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
# Scale to fp8
_, a1_scale = input_to_float8(hidden_states)
a1_scale = 1.0 / a1_scale
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
dtype=torch.float32)
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
layer = torch.nn.Module()
layer.w13_weight = w13_quantized.clone()
layer.w2_weight = w2_quantized.clone()
layer.w13_input_scale = a1_scale
layer.w2_input_scale = a2_scale
layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale
register_moe_scaling_factors(layer)
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if reorder:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
layer.local_num_experts = e
return TestData(
hidden_states=hidden_states,
w13_quantized=w13_quantized,
w2_quantized=w2_quantized,
a1_scale=a1_scale,
a2_scale=a2_scale,
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
layer=layer,
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
output = fused_experts(
td.hidden_states,
td.w13_quantized,
td.w2_quantized,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e,
expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True,
)
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
layer=td.layer,
hidden_states=td.hidden_states,
router_logits=score,
routing_bias=None,
global_num_experts=e,
top_k=topk,
num_expert_group=None,
topk_group=None,
apply_router_weight_on_input=True)
torch.testing.assert_close(output,
flashinfer_output,
atol=5.5e-2,
rtol=1e-2)
@pytest.mark.skip(
"Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_cutlass_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
output = fused_experts(
td.hidden_states,
td.w13_quantized,
td.w2_quantized,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e,
expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True,
)
td.layer.dp_size = 1
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
td.hidden_states,
td.layer,
topk_weights,
topk_ids,
activation="silu",
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
)
torch.testing.assert_close(output,
flashinfer_cutlass_output,
atol=5.5e-2,
rtol=1e-2)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE grouped topk kernel
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk,
grouped_topk)
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("num_expert_group", [8])
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
n_hidden: int, n_expert: int, topk: int,
renormalize: bool, num_expert_group: int,
topk_group: int, scoring_func: str,
routed_scaling_factor: float, dtype: torch.dtype):
current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden),
dtype=dtype,
device="cuda")
gating_output = torch.randn((n_token, n_expert),
dtype=dtype,
device="cuda")
e_score_correction_bias = torch.randn((n_expert, ),
dtype=torch.float32,
device="cuda")
with monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
baseline_topk_weights, baseline_topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
test_topk_weights, test_topk_ids = fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
if renormalize:
torch.testing.assert_close(baseline_topk_weights,
test_topk_weights,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_topk_ids,
test_topk_ids,
atol=0,
rtol=0)
......@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from ...utils import multi_gpu_test
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
reference_moe_impl,
run_modular_kernel)
......@@ -162,6 +163,7 @@ def is_nyi_config(config: Config) -> bool:
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype,
......
......@@ -429,11 +429,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
torch.cuda.empty_cache()
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
torch.cuda.synchronize()
torch.cuda.empty_cache()
# Run forward passes for both MoE blocks
......
......@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
atol=0,
rtol=0)
# check mindice
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if align_block_size is not None:
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx],
......
......@@ -4,15 +4,27 @@
import importlib
import importlib.metadata
from dataclasses import dataclass
from typing import Optional
import pytest
import torch
from packaging import version
from vllm.platforms import current_platform
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)
if TRTLLM_GEN_MXFP4_AVAILABLE:
from flashinfer import (fp4_quantize, mxfp8_quantize,
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
@dataclass
class ModelCase:
......@@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
output = llm.generate_greedy("Today I am in the French Alps and",
max_tokens=20)
assert output
\ No newline at end of file
assert output
def swiglu(x,
alpha: float = 1.702,
beta: float = 1.0,
limit: Optional[float] = None):
# Note we add an extra bias of 1 to the linear layer
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
if limit is not None:
x_glu = x_glu.clamp(max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
return out_glu * (x_linear + beta)
fp4_lookup_table = [
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
]
def mxfp4_dequantize(x, scale):
assert x.dtype == torch.uint8
x = x.view(torch.uint8).to(torch.int32)
x_unpacked = torch.zeros(*x.shape[:-1],
x.shape[-1] * 2,
dtype=torch.int32,
device=x.device)
x_unpacked[..., 0::2].copy_(x & 0xF)
x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
x_float = torch.zeros(x_unpacked.shape,
dtype=torch.float32,
device=x.device)
for i, val in enumerate(fp4_lookup_table):
x_float[x_unpacked == i] = val
scale = scale.view(torch.uint8).to(torch.int32)
scale = (scale << 23).view(torch.float32)
scale = scale.reshape(*x.shape[:-1], -1)
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
return x_float * scale
def mxfp8_dequantize(x, scale):
assert x.dtype == torch.float8_e4m3fn
x_float = x.to(torch.float32)
scale = scale.view(torch.uint8).to(torch.int32)
scale = (scale << 23).view(torch.float32)
scale = scale.reshape(*x.shape[:-1], -1)
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
return x_float * scale
def reference_moe(
roouting_logits,
topk,
num_experts,
hidden_states,
w13,
bias13,
w2,
bias2,
alpha,
beta,
limit,
act_type,
):
# renormalize routing
experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
expert_indices = experts.indices
t = hidden_states.clone()
# MLP #1
mlp1_weight = w13[expert_indices, ...]
mlp1_bias = bias13[expert_indices, ...]
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
if act_type == 'mxfp8':
t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16),
is_sf_swizzled_layout=False)
t = mxfp8_dequantize(t_quantized, t_scale)
# MLP #2
mlp2_weight = w2[expert_indices, ...]
mlp2_bias = bias2[expert_indices, ...]
t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias
# Weighted sum of experts
t = torch.einsum("bec,be->bc", t, expert_weights)
assert t.shape == hidden_states.shape
return t.to(torch.bfloat16)
def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
# Number of tokens in the input tensor.
num_tokens = x.shape[0]
# Factor to account for the imbalance of the experts.
# factor equals to the
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
# - 1.0 means perfect expert distribution.
# - > 1.0 means some experts have more
# tokens than the perfect distribution.
# - < 1.0 does not make sense.
imbalance_factor = 1.3
# Calculate the number of tokens per expert
# assuming perfect distribution.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
def tg_mxfp4_moe(
router_logits,
topk,
num_experts,
intermediate_size,
hidden_size,
hidden_states,
hidden_states_scale,
w13_weight,
w13_weight_scale,
w13_bias,
w2_weight,
w2_weight_scale,
w2_bias,
act_type,
alpha,
beta,
limit,
) -> torch.Tensor:
sf_block_size = 32
assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
and w13_weight.shape[1] == intermediate_size * 2
and w13_weight.shape[2] == hidden_size // 2)
assert (w13_weight_scale.dim() == 3
and w13_weight_scale.shape[0] == num_experts
and w13_weight_scale.shape[1] == intermediate_size * 2
and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
and w2_weight.shape[1] == hidden_size
and w2_weight.shape[2] == intermediate_size // 2)
assert (w2_weight_scale.dim() == 3
and w2_weight_scale.shape[1] == hidden_size
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
and w13_bias.shape[1] == intermediate_size * 2)
assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
and w2_bias.shape[1] == hidden_size)
# Swap w1 and w3 as the defenition of
# swiglu is different in the trtllm-gen
w13_weight_scale_ = w13_weight_scale.clone()
w13_weight_ = w13_weight.clone()
w13_bias_ = w13_bias.clone()
w13_weight[:, :intermediate_size, :].copy_(
w13_weight_[:, intermediate_size:, :])
w13_weight[:, intermediate_size:, :].copy_(
w13_weight_[:, :intermediate_size, :])
w13_weight_scale[:, :intermediate_size, :].copy_(
w13_weight_scale_[:, intermediate_size:, :])
w13_weight_scale[:, intermediate_size:, :].copy_(
w13_weight_scale_[:, :intermediate_size, :])
w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])
# Interleave the weights and scaling factors for activation
w13_weight_interleaved = []
w13_weight_scale_interleaved = []
w13_bias_interleaved = []
for i in range(num_experts):
w13_weight_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_weight[i].clone()))
w13_weight_scale_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()))
w13_bias_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1,
1)))
w13_weight = torch.stack(w13_weight_interleaved).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2)
w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
num_experts, 2 * intermediate_size, hidden_size // 32)
w13_bias = torch.stack(w13_bias_interleaved).reshape(
num_experts, 2 * intermediate_size)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_shuffled = []
gemm1_scales_shuffled = []
gemm2_weights_shuffled = []
gemm2_scales_shuffled = []
gemm1_bias_shuffled = []
gemm2_bias_shuffled = []
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
for i in range(num_experts):
gemm1_weights_shuffled.append(
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m))
gemm1_scales_shuffled.append(
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm2_weights_shuffled.append(
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m))
gemm2_scales_shuffled.append(
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm1_bias_shuffled.append(
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
gemm2_bias_shuffled.append(
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
w13_weight = torch.stack(gemm1_weights_shuffled)
w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
num_experts, 2 * intermediate_size,
hidden_size // sf_block_size).view(torch.float8_e4m3fn)
w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
w2_weight = torch.stack(gemm2_weights_shuffled)
w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
num_experts, hidden_size,
intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
tg_result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
gemm1_weights=w13_weight,
gemm1_weights_scale=w13_weight_scale,
gemm1_bias=w13_bias,
gemm1_alpha=alpha,
gemm1_beta=beta,
gemm1_clamp_limit=limit,
gemm2_weights=w2_weight,
gemm2_weights_scale=w2_weight_scale,
gemm2_bias=w2_bias,
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=num_experts,
top_k=topk,
n_group=None,
topk_group=None,
intermediate_size=intermediate_size,
local_expert_offset=0,
local_num_experts=num_experts,
routed_scaling_factor=None,
tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
routing_method_type=1, # renormalize
do_finalize=True)[0]
return tg_result
def check_accuracy(a, b, atol, rtol, percent):
"""Allow a mismatch percentage of 1 - percent."""
if torch.any(torch.isnan(a)):
raise Exception("NaN in reference output")
if torch.any(torch.isnan(b)):
raise Exception("NaN in actual output")
if torch.any(torch.isinf(a)):
raise Exception("Inf in reference output")
if torch.any(torch.isinf(b)):
raise Exception("Inf in actual output")
assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"
left = torch.abs(a - b)
right = atol + rtol * torch.abs(b)
count = torch.sum(left > right)
mismatch_percent = count / a.numel()
if mismatch_percent > 1 - percent:
raise Exception(
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
f"(threshold: {1-percent:.4f})")
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
@pytest.mark.skipif(
not TRTLLM_GEN_MXFP4_AVAILABLE,
reason="nvidia gpu and compute capability sm100 is required for this test")
def test_trtllm_gen_mxfp4_fused_moe(
topk: int,
num_experts: int,
num_tokens: int,
intermediate_size: int,
hidden_size: int,
alpha: float,
beta: float,
limit: Optional[float],
act_type: str,
):
seed = 42
torch.manual_seed(seed)
hidden_states = torch.randn(num_tokens,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16)
w13 = (torch.randn(num_experts,
intermediate_size * 2,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16))
w2 = (torch.randn(num_experts,
hidden_size,
intermediate_size,
device="cuda:0",
dtype=torch.bfloat16))
bias13 = torch.randn(num_experts, intermediate_size * 2,
device="cuda:0") * 10
bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
router_logits = torch.rand(num_tokens, num_experts,
dtype=torch.float32).cuda()
w13, w13_scale = fp4_quantize(w13,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False)
w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
num_experts, intermediate_size * 2, hidden_size // 32)
w2, w2_scale = fp4_quantize(w2,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False)
w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 32)
if act_type == 'mxfp8':
hidden_states, hidden_states_scale = mxfp8_quantize(
hidden_states, is_sf_swizzled_layout=False)
hidden_states_scale = hidden_states_scale.view(
torch.float8_e4m3fn).reshape(-1)
else:
hidden_states_scale = None
# reference result
ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
bias13_ref = bias13
bias2_ref = bias2
if act_type == 'mxfp8':
hidden_states_ref = mxfp8_dequantize(
hidden_states, hidden_states_scale).to(torch.float32)
else:
hidden_states_ref = hidden_states.to(torch.float32)
# Process tokens in chunks of 32 to reduce memory usage
chunk_size = 32
num_chunks = (num_tokens + chunk_size - 1) // chunk_size
for i in range(num_chunks):
start_idx = i * chunk_size
end_idx = min(start_idx + chunk_size, num_tokens)
chunk_result = reference_moe(
router_logits[start_idx:end_idx].to(torch.float32),
topk,
num_experts,
hidden_states_ref[start_idx:end_idx],
w13_ref,
bias13_ref,
w2_ref,
bias2_ref,
alpha,
beta,
limit,
act_type,
)
ref_result[start_idx:end_idx].copy_(chunk_result)
# trtllm-gen result
if alpha is not None:
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
if limit is not None:
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
if beta is not None:
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
tg_result = tg_mxfp4_moe(router_logits,
topk,
num_experts,
intermediate_size,
hidden_size,
hidden_states,
hidden_states_scale,
w13,
w13_scale,
bias13,
w2,
w2_scale,
bias2,
act_type,
alpha=alpha,
beta=beta,
limit=limit)
# relatively loose check since the mxfp4 quantization is less accurate
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
......@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from vllm.platforms import current_platform
from vllm.utils import cdiv
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
try:
......@@ -76,6 +77,7 @@ def pplx_cutlass_moe(
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
intermediate_dim = w2.shape[2]
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
......@@ -124,8 +126,27 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)
ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, per_act_token, per_out_ch)
out_dtype, per_act_token, per_out_ch,
ab_strides1, ab_strides2, c_strides1,
c_strides2)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
......@@ -227,6 +248,7 @@ def _pplx_moe(
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@multi_gpu_test(num_gpus=2)
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
......
......@@ -37,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.platforms import current_platform
from vllm.utils import round_up
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
requires_pplx = pytest.mark.skipif(
......@@ -452,6 +453,7 @@ def _pplx_prepare_finalize(
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize_slow(
mnk: tuple[int, int, int],
e: int,
......@@ -740,6 +742,7 @@ def _pplx_moe(
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe_slow(
mnk: tuple[int, int, int],
e: int,
......@@ -880,6 +883,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize(
world_dp_size: tuple[int, int],
use_internode: bool,
......@@ -893,6 +897,7 @@ def test_pplx_prepare_finalize(
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe(
world_dp_size: tuple[int, int],
use_internode: bool,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment