Commit a99300bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev

parents cc3e01c7 5438967f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
def pytest_addoption(parser):
"""Add custom command line options."""
parser.addoption("--config-list-file",
default="configs/models-small.txt",
help="File containing list of config files to test")
parser.addoption("--tp-size",
default=1,
type=int,
help="Tensor parallel size")
def pytest_generate_tests(metafunc):
"""Generate test parameters from config files."""
if "config_filename" in metafunc.fixturenames:
config_list_file = metafunc.config.getoption("--config-list-file")
tp_size = metafunc.config.getoption("--tp-size")
# Handle both relative and absolute paths
config_list_path = Path(config_list_file)
if not config_list_path.is_absolute():
# If relative, try relative to test directory first
test_dir_path = Path(__file__).parent / config_list_file
if test_dir_path.exists():
config_list_path = test_dir_path
else:
# Try relative to current working directory
config_list_path = Path.cwd() / config_list_file
print(f"Looking for config list at: {config_list_path}")
config_files = []
if config_list_path.exists():
# Determine config directory (same directory as the list file)
config_dir = config_list_path.parent
with open(config_list_path) as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
config_path = config_dir / line
print(f"Checking config file: {config_path}")
if config_path.exists():
config_files.append(config_path)
print(f" ✓ Found: {config_path}")
else:
print(f" ✗ Missing: {config_path}")
else:
print(f"Config list file not found: {config_list_path}")
# Generate test parameters
if config_files:
metafunc.parametrize(["config_filename", "tp_size"],
[(config_file, int(tp_size))
for config_file in config_files],
ids=[
f"{config_file.stem}-tp{tp_size}"
for config_file in config_files
])
else:
print("No config files found, test will be skipped")
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Isolated GSM8K evaluation script for vLLM serve endpoint.
"""
import argparse
import ast
import asyncio
import json
import os
import time
from collections.abc import Generator
from typing import Optional, Union
import aiohttp
import numpy as np
import regex as re
import requests
from tqdm.asyncio import tqdm
INVALID = -9999999
def download_and_cache_file(url: str, filename: Optional[str] = None) -> str:
"""Download and cache a file from a URL."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])
if os.path.exists(filename):
return filename
print(f"Downloading from {url} to {filename}")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(filename, "wb") as f:
for chunk in response.iter_content(chunk_size=1024):
f.write(chunk)
return filename
def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
"""Load GSM8K train and test data"""
train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl"
test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
train_file = download_and_cache_file(train_url)
test_file = download_and_cache_file(test_url)
train_data = list(read_jsonl(train_file))
test_data = list(read_jsonl(test_file))
return train_data, test_data
def read_jsonl(filename: str) -> Generator[dict, None, None]:
"""Read a JSONL file."""
with open(filename) as fin:
for line in fin:
if not line.startswith("#"):
yield json.loads(line)
def get_answer_value(answer_str: str) -> int:
"""Extract the numerical answer from the response."""
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
async def call_vllm_api(session: aiohttp.ClientSession,
prompt: str,
temperature: float,
max_tokens: int,
stop: Optional[list[str]] = None,
url: Optional[str] = None,
seed: Optional[int] = None) -> str:
"""Call vLLM's OpenAI-compatible completions endpoint."""
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"stop": stop,
}
if seed is not None:
data["seed"] = seed
try:
async with session.post(f"{url}/v1/completions",
json=data) as response:
response.raise_for_status()
result = await response.json()
return result["choices"][0]["text"]
except Exception as e:
print(f"Error calling vLLM API: {e}")
return ""
def evaluate_gsm8k(num_questions: int = 1319,
num_shots: int = 5,
max_tokens: int = 256,
host: str = "http://127.0.0.1",
port: int = 8000,
temperature: float = 0.0,
seed: Optional[int] = 42) -> dict[str, Union[float, int]]:
"""
Evaluate GSM8K accuracy using vLLM serve endpoint.
Returns dict with accuracy, invalid_rate, latency, etc.
"""
base_url = f"{host}:{port}"
# Load GSM8K train and test data
train_data, test_data = load_gsm8k_data()
# Limit to available test questions
num_questions = min(num_questions, len(test_data))
# Build few-shot examples from train split (like lm-eval does)
few_shot_examples = ""
for i in range(num_shots):
few_shot_examples += (f"Question: {train_data[i]['question']}\n"
f"Answer: {train_data[i]['answer']}\n\n")
# Prepare test questions and labels from test split
questions = []
labels = []
for i in range(num_questions):
questions.append(f"Question: {test_data[i]['question']}\nAnswer:")
labels.append(get_answer_value(test_data[i]["answer"]))
assert all(label != INVALID for label in labels), "Some labels are invalid"
# Run evaluation
async def run_async_evaluation():
states: list[str] = [""] * num_questions
async def get_answer(session: aiohttp.ClientSession, i: int) -> str:
prompt = few_shot_examples + questions[i]
answer = await call_vllm_api(
session=session,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
stop=["Question", "Assistant:", "<|separator|>"],
url=base_url,
seed=seed,
)
states[i] = answer
return answer
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
total=600)) as session:
tasks = [get_answer(session, i) for i in range(num_questions)]
await tqdm.gather(*tasks, desc="Evaluating")
return states
print(f"Running GSM8K evaluation: {num_questions} questions, "
f"{num_shots}-shot")
tic = time.perf_counter()
states = asyncio.run(run_async_evaluation())
latency = time.perf_counter() - tic
# Compute metrics
preds = [get_answer_value(state) for state in states]
accuracy = np.mean(np.array(preds) == np.array(labels))
invalid_rate = np.mean(np.array(preds) == INVALID)
result = {
"accuracy": accuracy,
"invalid_rate": invalid_rate,
"latency": latency,
"questions_per_second": num_questions / latency,
"num_questions": num_questions,
"num_shots": num_shots,
"max_tokens": max_tokens,
"timestamp": time.time(),
}
return result
def main() -> None:
parser = argparse.ArgumentParser(
description="GSM8K evaluation for vLLM serve")
parser.add_argument("--num-shots",
type=int,
default=5,
help="Number of few-shot examples")
parser.add_argument("--num-questions",
type=int,
default=1319,
help="Number of questions to evaluate")
parser.add_argument("--max-tokens",
type=int,
default=256,
help="Max tokens for generation")
parser.add_argument("--host",
type=str,
default="http://127.0.0.1",
help="Host URL")
parser.add_argument("--port", type=int, default=8000, help="Port number")
parser.add_argument("--temperature",
type=float,
default=0.0,
help="Temperature for generation")
parser.add_argument("--seed",
type=int,
default=42,
help="Random seed for reproducibility")
parser.add_argument("--save-results",
type=str,
help="Save results to JSON file")
args = parser.parse_args()
result = evaluate_gsm8k(
num_questions=args.num_questions,
num_shots=args.num_shots,
max_tokens=args.max_tokens,
host=args.host,
port=args.port,
temperature=args.temperature,
seed=args.seed,
)
# Print results to terminal
print("\nResults:")
print(f"Accuracy: {result['accuracy']:.3f}")
print(f"Invalid responses: {result['invalid_rate']:.3f}")
print(f"Total latency: {result['latency']:.3f} s")
print(f"Questions per second: {result['questions_per_second']:.3f}")
# Optional file saving
if args.save_results:
with open(args.save_results, "w") as f:
json.dump(result, f, indent=2)
print(f"Results saved to {args.save_results}")
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GSM8K evaluation using vLLM server and isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control.
Usage:
pytest -s -v test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt \
--tp-size=1
"""
import yaml
from tests.utils import RemoteOpenAIServer
from .gsm8k_eval import evaluate_gsm8k
RTOL = 0.08 # Relative tolerance for accuracy comparison
def launch_gsm8k_eval(eval_config, server_url, tp_size):
"""Launch GSM8K evaluation using our isolated script."""
# Extract host and port from server URL
if "://" in server_url:
server_url = server_url.split("://")[1]
host_port = server_url.split("/")[0] # Remove path if present
if ":" in host_port:
host, port = host_port.split(":")
port = int(port)
else:
host = host_port
port = 8000
# Add http:// prefix if not present
if not host.startswith("http"):
host = f"http://{host}"
# Run GSM8K evaluation
results = evaluate_gsm8k(
num_questions=eval_config["num_questions"],
num_shots=eval_config["num_fewshot"],
host=host,
port=port,
)
return results
def test_gsm8k_correctness_param(config_filename, tp_size):
"""Test GSM8K correctness for a given model configuration."""
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
# Server arguments
server_args = [
"--max-model-len",
str(eval_config.get("max_model_len", 4096)),
"--enforce-eager",
"--trust-remote-code",
"--tensor-parallel-size",
str(tp_size),
]
# Launch server and run evaluation
with RemoteOpenAIServer(eval_config["model_name"],
server_args,
max_wait_seconds=480) as remote_server:
server_url = remote_server.url_for("v1")
results = launch_gsm8k_eval(eval_config, server_url, tp_size)
# Check accuracy against threshold
measured_accuracy = results["accuracy"]
expected_accuracy = eval_config["accuracy_threshold"]
print(f"GSM8K Results for {eval_config['model_name']}:")
print(f" Accuracy: {measured_accuracy:.3f}")
print(f" Expected: {expected_accuracy:.3f}")
print(f" Questions: {results['num_questions']}")
print(f" Invalid rate: {results['invalid_rate']:.3f}")
print(f" Latency: {results['latency']:.1f}s")
print(f" QPS: {results['questions_per_second']:.1f}")
# Verify accuracy is within tolerance
assert measured_accuracy >= expected_accuracy - RTOL, (
f"Accuracy too low: {measured_accuracy:.3f} < "
f"{expected_accuracy:.3f} - {RTOL:.3f}")
print(f"✅ GSM8K test passed for {eval_config['model_name']}")
...@@ -705,6 +705,94 @@ def test_swap_blocks_mla( ...@@ -705,6 +705,94 @@ def test_swap_blocks_mla(
f"{dst} in dst_cache.") f"{dst} in dst_cache.")
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
block_size, num_blocks,
max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0,
max_seq_len + 1, (batch_size, ),
device=device)
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)
for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm
dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)
expected_batches = []
for b in range(batch_size):
s = seq_len_tensor[b]
if s == 0:
continue
tot = tot_blocks_tensor[b]
blocks = block_table[b, :tot].tolist()
gathered_rows = []
for i in range(tot - 1):
block_data = src_cache[blocks[i]]
if kv_cache_dtype == "fp8":
dequantized_block = torch.empty_like(block_data, dtype=dtype)
ops.convert_fp8(dequantized_block, block_data, scale.item())
gathered_rows.append(dequantized_block)
else:
gathered_rows.append(block_data)
remaining = s - (tot - 1) * block_size
last_block_data = src_cache[blocks[-1], :remaining, :]
if kv_cache_dtype == "fp8":
dequantized_last_block = torch.empty_like(last_block_data,
dtype=dtype)
ops.convert_fp8(dequantized_last_block, last_block_data,
scale.item())
gathered_rows.append(dequantized_last_block)
else:
gathered_rows.append(last_block_data)
batch_expected = torch.cat(gathered_rows, dim=0)
expected_batches.append(batch_expected)
expected = torch.cat(expected_batches, dim=0)
opcheck(
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, kv_cache_dtype,
scale, None)
torch.testing.assert_close(dst, expected)
@pytest.mark.parametrize("kv_lora_rank", [512]) @pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64]) @pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("block_size", [16])
...@@ -716,9 +804,9 @@ def test_swap_blocks_mla( ...@@ -716,9 +804,9 @@ def test_swap_blocks_mla(
["auto"]) # You can also test "fp8" if needed. ["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype, num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, device): kv_cache_dtype, device):
entry_size = kv_lora_rank + qk_rope_head_dim entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device) kv_cache_dtype, device)
...@@ -768,12 +856,12 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, ...@@ -768,12 +856,12 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
expected = torch.cat(expected_batches, dim=0) expected = torch.cat(expected_batches, dim=0)
opcheck( opcheck(
torch.ops._C_cache_ops.gather_cache, torch.ops._C_cache_ops.cp_gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, None), (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS, test_utils=DEFAULT_OPCHECK_TEST_UTILS,
) )
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected) torch.testing.assert_close(dst, expected)
......
...@@ -6,28 +6,19 @@ import flashinfer ...@@ -6,28 +6,19 @@ import flashinfer
import pytest import pytest
import torch import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up
if not current_platform.is_device_capability(100): if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
allow_module_level=True) allow_module_level=True)
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = current_platform.fp8_dtype()
# KV Cache Layout for TRT-LLM FP4_DTYPE = torch.uint8
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
MAX_Q_LEN = 1024
MAX_KV_LEN = 4096
BATCH_SIZES = [4, 12]
NUM_HEADS = [(16, 16), (40, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16]
KV_LAYOUTS = ["HND"]
DTYPES = [torch.bfloat16]
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 50.0]
def to_float8(x, dtype=torch.float8_e4m3fn): def to_float8(x, dtype=torch.float8_e4m3fn):
...@@ -39,42 +30,61 @@ def to_float8(x, dtype=torch.float8_e4m3fn): ...@@ -39,42 +30,61 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
return x_scl_sat.to(dtype), scale.float().reciprocal() return x_scl_sat.to(dtype), scale.float().reciprocal()
@pytest.mark.parametrize("batch_size", BATCH_SIZES) DTYPE = [torch.bfloat16]
QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)]
NUM_HEADS = [(64, 8), (40, 8)]
HEAD_SIZE = [128]
KV_LAYOUT = ["HND"] # currently only HND is supported
BLOCK_SIZE = [16]
SOFT_CAP = [None, 50.0]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) @pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", SOFT_CAP)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode @torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline( def test_flashinfer_trtllm_decode_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
batch_size: int, batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
block_size: int,
kv_layout: str, kv_layout: str,
dtype: torch.dtype, block_size: int,
kv_cache_dtype: Optional[torch.dtype],
soft_cap: Optional[float], soft_cap: Optional[float],
) -> None: ) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) current_platform.seed_everything(0)
kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
kv_lens[-1] = MAX_KV_LEN q_quant_dtype = q_quant_dtype or dtype
max_kv_len = torch.max(kv_lens).item() kv_quant_dtype = kv_quant_dtype or dtype
num_seqs = len(kv_lens) o_quant_dtype = o_quant_dtype or dtype
num_query_heads = num_heads[0] _, max_kv_len = max_seq_lens
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
scale = head_size**-0.5 num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) sm_scale = float(1.0 / (head_size**0.5))
kv_cache_shape = None kv_cache_shape = None
if kv_layout == "NHD": if kv_layout == "NHD":
...@@ -83,23 +93,40 @@ def test_flashinfer_trtllm_decode_with_baseline( ...@@ -83,23 +93,40 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else: else:
raise ValueError(f"Invalid kv_layout: {kv_layout}") raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
kv_scale = 1.0
if kv_cache_dtype is current_platform.fp8_dtype():
key_value_cache, kv_scale = to_float8(key_value_cache,
current_platform.fp8_dtype())
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0, block_tables = torch.randint(0,
NUM_BLOCKS, NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq), (batch_size, max_num_blocks_per_seq),
dtype=torch.int32) dtype=torch.int32)
k_scale = v_scale = kv_scale
kv_indptr = [0] kv_indptr = [0]
kv_indices = [] kv_indices = []
kv_last_page_lens = [] kv_last_page_lens = []
for i in range(num_seqs): for i in range(batch_size):
seq_len = kv_lens[i] seq_len = seq_lens[i]
assert seq_len > 0 assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks]) kv_indices.extend(block_tables[i, :num_blocks])
...@@ -112,103 +139,120 @@ def test_flashinfer_trtllm_decode_with_baseline( ...@@ -112,103 +139,120 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Decode
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, workspace_buffer, kv_layout, use_tensor_cores=True)
kv_layout,
use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
wrapper.plan(kv_indptr, wrapper.plan(kv_indptr,
kv_indices, kv_indices,
kv_last_page_lens, kv_last_page_lens,
num_query_heads, num_qo_heads,
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
"NONE", "NONE",
sm_scale=scale, sm_scale=sm_scale,
q_data_type=dtype, q_data_type=dtype,
kv_data_type=kv_cache_dtype, kv_data_type=dtype,
logits_soft_cap=soft_cap) logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype) output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(query, wrapper.run(ref_query, ref_kv_cache, out=output)
key_value_cache, o_scale = 1.0
k_scale=k_scale, o_sf_scale = None
v_scale=v_scale, if o_quant_dtype == FP8_DTYPE:
out=output) _, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Decode # TRTLLM Decode
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) if o_quant_dtype == FP4_DTYPE:
output_trtllm = torch.empty(query.shape, dtype=dtype) output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache( flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query.contiguous(), query=query,
kv_cache=key_value_cache, kv_cache=kv_cache,
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=block_tables, block_tables=block_tables,
seq_lens=kv_lens_tensor, seq_lens=seq_lens,
max_seq_len=max_kv_len, max_seq_len=max_seq_len,
bmm1_scale=k_scale * scale, bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale, bmm2_scale=v_scale / o_scale,
o_sf_scale=o_sf_scale,
out=output_trtllm, out=output_trtllm,
) )
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 2e-2
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}" f"{torch.max(torch.abs(output - output_trtllm))}"
@pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) @pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", [None]) @pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline( def test_flashinfer_trtllm_prefill_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
batch_size: int, batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
block_size: int,
kv_layout: str, kv_layout: str,
dtype: torch.dtype, block_size: int,
kv_cache_dtype: Optional[torch.dtype],
soft_cap: Optional[float], soft_cap: Optional[float],
) -> None: ) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if dtype != kv_cache_dtype:
pytest.skip(f"Not supported dtype({dtype}) with "
"kv_cache_dtype({kv_cache_dtype})")
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) current_platform.seed_everything(0)
q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32) q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_lens[-1] = MAX_Q_LEN q_quant_dtype = q_quant_dtype or dtype
max_q_len = torch.max(q_lens).item() kv_quant_dtype = kv_quant_dtype or dtype
q_indptr = torch.cat([ o_quant_dtype = o_quant_dtype or dtype
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) if q_quant_dtype != kv_quant_dtype:
kv_lens[-1] = MAX_KV_LEN pytest.skip("Skipped mixed QKV dtypes for prefill")
seq_lens = kv_lens + q_lens max_q_len, max_kv_len = max_seq_lens
max_seq_len = torch.max(seq_lens).item()
num_seqs = len(seq_lens)
num_query_heads = num_heads[0] num_qo_heads, num_kv_heads = num_heads
num_kv_heads = num_heads[1] assert num_qo_heads % num_kv_heads == 0
assert num_query_heads % num_kv_heads == 0
scale = head_size**-0.5 sm_scale = float(1.0 / (head_size**0.5))
query = torch.randn(torch.sum(q_lens).item(),
num_query_heads,
head_size,
dtype=dtype)
kv_cache_shape = None kv_cache_shape = None
if kv_layout == "NHD": if kv_layout == "NHD":
...@@ -217,22 +261,49 @@ def test_flashinfer_trtllm_prefill_with_baseline( ...@@ -217,22 +261,49 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else: else:
raise ValueError(f"Invalid kv_layout: {kv_layout}") raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
kv_scale = 1.0 q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32)
if kv_cache_dtype is current_platform.fp8_dtype(): q_lens[-1] = max_q_len
key_value_cache, kv_scale = to_float8(key_value_cache, q_indptr = torch.cat([
current_platform.fp8_dtype()) torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
query = torch.randn(torch.sum(q_lens).item(),
num_qo_heads,
head_size,
dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0, block_tables = torch.randint(0,
NUM_BLOCKS, NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq), (batch_size, max_num_blocks_per_seq),
dtype=torch.int32) dtype=torch.int32)
k_scale = v_scale = kv_scale
kv_indptr = [0] kv_indptr = [0]
kv_indices = [] kv_indices = []
kv_last_page_lens = [] kv_last_page_lens = []
for i in range(num_seqs): for i in range(batch_size):
seq_len = seq_lens[i] seq_len = seq_lens[i]
assert seq_len > 0 assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size num_blocks = (seq_len + block_size - 1) // block_size
...@@ -246,48 +317,81 @@ def test_flashinfer_trtllm_prefill_with_baseline( ...@@ -246,48 +317,81 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Prefill
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout) workspace_buffer, kv_layout)
wrapper.plan(q_indptr, wrapper.plan(q_indptr,
kv_indptr, kv_indptr,
kv_indices, kv_indices,
kv_last_page_lens, kv_last_page_lens,
num_query_heads, num_qo_heads,
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
causal=True, causal=True,
sm_scale=scale, sm_scale=sm_scale,
q_data_type=dtype, q_data_type=dtype,
kv_data_type=kv_cache_dtype, kv_data_type=dtype,
logits_soft_cap=soft_cap) logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype) output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(query, wrapper.run(ref_query, ref_kv_cache, out=output)
key_value_cache, o_scale = 1.0
k_scale=k_scale, o_sf_scale = None
v_scale=v_scale, if o_quant_dtype == FP8_DTYPE:
out=output) _, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Prefill
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
# TRTLLM Decode
output_trtllm = torch.empty(query.shape, dtype=dtype)
flashinfer.prefill.trtllm_batch_context_with_kv_cache( flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=query.contiguous(), query=query,
kv_cache=key_value_cache, kv_cache=kv_cache,
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=block_tables, block_tables=block_tables,
seq_lens=seq_lens, seq_lens=seq_lens,
max_q_len=max_q_len, max_q_len=max_q_len,
max_kv_len=max_seq_len, max_kv_len=max_seq_len,
bmm1_scale=k_scale * scale, bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale, bmm2_scale=v_scale / o_scale,
batch_size=num_seqs, batch_size=batch_size,
cum_seq_lens_q=q_indptr, cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr, cum_seq_lens_kv=kv_indptr,
o_sf_scale=o_sf_scale,
out=output_trtllm, out=output_trtllm,
) )
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 4e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}" f"{torch.max(torch.abs(output - output_trtllm))}"
...@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, ...@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
from vllm.triton_utils import triton from vllm.triton_utils import triton
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: def cal_diff(x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False) -> None:
x, y = x.double(), y.double() x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max( cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12) (x * x + y * y).sum().item(), 1e-12)
assert cos_diff < 1e-4 if (use_fp8):
assert cos_diff < 1e-4
else:
assert cos_diff < 1e-4 #1e-5
FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
if not is_flashmla_supported()[0] else "FlashMLA is supported" if not is_flashmla_supported()[0] else "FlashMLA is supported"
...@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ ...@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
reason=FLASH_MLA_UNSUPPORTED_REASON) reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2]) @pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192]) @pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128]) @pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1]) @pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576]) @pytest.mark.parametrize("d", [576])
...@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ ...@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True]) @pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn])
@torch.inference_mode() @torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, dtype): varlen, torch_dtype):
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.set_default_dtype(dtype) if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.manual_seed(0) torch.manual_seed(0)
random.seed(0) random.seed(0)
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}") f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")
use_fp8 = torch_dtype == torch.float8_e4m3fn
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
if varlen: if varlen:
for i in range(b): for i in range(b):
...@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, ...@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata, num_splits = get_mla_metadata( tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv) cache_seqlens, s_q * h_q // h_kv, h_kv)
init_dtype = q.dtype
if use_fp8:
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
q = q.to(fp8_dtype)
blocked_k = blocked_k.to(fp8_dtype)
blocked_v = blocked_v.to(fp8_dtype)
else:
descale_q = None
descale_k = None
def flash_mla(): def flash_mla():
return flash_mla_with_kvcache( return flash_mla_with_kvcache(
q, q,
...@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, ...@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
causal=causal, causal=causal,
descale_q=descale_q,
descale_k=descale_k,
) )
def scaled_dot_product_attention(query, key, value, is_causal=False): def scaled_dot_product_attention(query, key, value, is_causal=False):
...@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, ...@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
return attn_weight @ value, lse return attn_weight @ value, lse
def ref_mla(): def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_v
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b): for i in range(b):
begin = i * max_seqlen_pad begin = i * max_seqlen_pad
end = begin + cache_seqlens[i] end = begin + cache_seqlens[i]
ref_O, LSE = scaled_dot_product_attention( out_i, lse_i = scaled_dot_product_attention(
q[i].transpose(0, 1), q_[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
is_causal=causal, is_causal=causal,
) )
out[i] = ref_O.transpose(0, 1) out[i] = out_i.transpose(0, 1)
lse[i] = LSE lse[i] = lse_i
return out, lse return out, lse
out_flash, lse_flash = flash_mla() out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla() out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out") cal_diff(out_flash, out_torch, "out", use_fp8)
cal_diff(lse_flash, lse_torch, "lse") cal_diff(lse_flash, lse_torch, "lse")
t = triton.testing.do_bench(flash_mla) t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + bytes = (total_seqlens * h_kv * d +
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
f"{bytes / 10 ** 6 / t:.0f} GB/s")
...@@ -80,6 +80,9 @@ def test_env( ...@@ -80,6 +80,9 @@ def test_env(
m.setenv(STR_BACKEND_ENV_VAR, name) m.setenv(STR_BACKEND_ENV_VAR, name)
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
if name == "FLASHINFER" and not use_v1:
pytest.skip("FlashInfer backend is only available on V1 engine")
if device == "cpu": if device == "cpu":
with patch("vllm.attention.selector.current_platform", with patch("vllm.attention.selector.current_platform",
CpuPlatform()): CpuPlatform()):
......
...@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\ wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=( use_tensor_cores=True)
(num_query_heads//num_kv_heads) > 4)
)
wrapper.plan( wrapper.plan(
kv_indptr, kv_indptr,
kv_indices, kv_indices,
...@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( ...@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
assert num_query_heads % num_kv_heads == 0 assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens) max_kv_len = max(kv_lens)
scale = head_size**-0.5 scale = head_size**-0.5
use_tensor_cores = (num_query_heads // num_kv_heads) > 4 use_tensor_cores = True
kv_cache_dtype = torch.float8_e4m3fn kv_cache_dtype = torch.float8_e4m3fn
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
......
...@@ -20,9 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -20,9 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
is_deep_gemm_supported)
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights from .utils import make_test_weights
...@@ -370,9 +370,10 @@ NUM_EXPERTS = [32] ...@@ -370,9 +370,10 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), @pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM") reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]): topk: int, world_dp_size: tuple[int, int]):
...@@ -427,9 +428,10 @@ USE_FP8_DISPATCH = [False] ...@@ -427,9 +428,10 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]]) @pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), @pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM") reason="Skipping test for Blackwell DeepGEMM")
def test_ll_deepep_deepgemm_moe( def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],
......
...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_ep from vllm.utils import has_deep_ep
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
if has_deep_ep(): if has_deep_ep():
...@@ -411,6 +412,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] ...@@ -411,6 +412,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("per_act_token_quant", [False, True])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep @requires_deep_ep
def test_deep_ep_moe( def test_deep_ep_moe(
dtype: torch.dtype, dtype: torch.dtype,
...@@ -459,6 +461,7 @@ USE_FP8_DISPATCH = [True, False] ...@@ -459,6 +461,7 @@ USE_FP8_DISPATCH = [True, False]
@pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@multi_gpu_test(num_gpus=2)
@requires_deep_ep @requires_deep_ep
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
num_experts: int, topk: int, num_experts: int, topk: int,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import pytest
import torch
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
input_to_float8)
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
NUM_EXPERTS = [16]
TOP_KS = [1]
MNK_FACTORS = [
(256, 8192, 5120),
(256, 4096, 5120),
(127, 8192, 5120),
(127, 4096, 5120),
(10, 8192, 5120),
(10, 4096, 5120),
(1, 8192, 5120),
(1, 4096, 5120),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def quant_fp8_per_tensor_batches(a):
num_batches = a.size(0)
a_quant = []
a_scales = []
for i in range(num_batches):
a_fp8, a_global_sf = input_to_float8(a[i])
a_global_sf = 1.0 / a_global_sf
a_quant.append(a_fp8)
a_scales.append(a_global_sf)
result_a_quant = torch.stack(a_quant)
result_a_scales = torch.stack(a_scales)
return result_a_quant, result_a_scales
@dataclass
class TestData:
hidden_states: torch.Tensor
w13_quantized: torch.Tensor
w2_quantized: torch.Tensor
a1_scale: torch.Tensor
a2_scale: torch.Tensor
w13_weight_scale: torch.Tensor
w2_weight_scale: torch.Tensor
layer: torch.nn.Module
@staticmethod
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
reorder: bool) -> "TestData":
hidden_states = torch.randn(
(m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
# Scale to fp8
_, a1_scale = input_to_float8(hidden_states)
a1_scale = 1.0 / a1_scale
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
dtype=torch.float32)
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
layer = torch.nn.Module()
layer.w13_weight = w13_quantized.clone()
layer.w2_weight = w2_quantized.clone()
layer.w13_input_scale = a1_scale
layer.w2_input_scale = a2_scale
layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale
register_moe_scaling_factors(layer)
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if reorder:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
layer.local_num_experts = e
return TestData(
hidden_states=hidden_states,
w13_quantized=w13_quantized,
w2_quantized=w2_quantized,
a1_scale=a1_scale,
a2_scale=a2_scale,
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
layer=layer,
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
output = fused_experts(
td.hidden_states,
td.w13_quantized,
td.w2_quantized,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e,
expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True,
)
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
layer=td.layer,
hidden_states=td.hidden_states,
router_logits=score,
routing_bias=None,
global_num_experts=e,
top_k=topk,
num_expert_group=None,
topk_group=None,
apply_router_weight_on_input=True)
torch.testing.assert_close(output,
flashinfer_output,
atol=5.5e-2,
rtol=1e-2)
@pytest.mark.skip(
"Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_cutlass_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
output = fused_experts(
td.hidden_states,
td.w13_quantized,
td.w2_quantized,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e,
expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True,
)
td.layer.dp_size = 1
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
td.hidden_states,
td.layer,
topk_weights,
topk_ids,
activation="silu",
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
)
torch.testing.assert_close(output,
flashinfer_cutlass_output,
atol=5.5e-2,
rtol=1e-2)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE grouped topk kernel
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk,
grouped_topk)
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("num_expert_group", [8])
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
n_hidden: int, n_expert: int, topk: int,
renormalize: bool, num_expert_group: int,
topk_group: int, scoring_func: str,
routed_scaling_factor: float, dtype: torch.dtype):
current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden),
dtype=dtype,
device="cuda")
gating_output = torch.randn((n_token, n_expert),
dtype=dtype,
device="cuda")
e_score_correction_bias = torch.randn((n_expert, ),
dtype=torch.float32,
device="cuda")
with monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
baseline_topk_weights, baseline_topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
test_topk_weights, test_topk_ids = fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
if renormalize:
torch.testing.assert_close(baseline_topk_weights,
test_topk_weights,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_topk_ids,
test_topk_ids,
atol=0,
rtol=0)
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from ...utils import multi_gpu_test
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
reference_moe_impl, reference_moe_impl,
run_modular_kernel) run_modular_kernel)
...@@ -162,6 +163,7 @@ def is_nyi_config(config: Config) -> bool: ...@@ -162,6 +163,7 @@ def is_nyi_config(config: Config) -> bool:
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements @meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu( def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype, k: int, n: int, e: int, dtype: torch.dtype,
......
...@@ -438,11 +438,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, ...@@ -438,11 +438,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
0:-128], 0:-128],
requires_grad=False) requires_grad=False)
torch.cuda.empty_cache()
vllm_moe.experts.w2_weight = Parameter(F.pad( vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
0:-128], 0:-128],
requires_grad=False) requires_grad=False)
torch.cuda.synchronize()
torch.cuda.empty_cache() torch.cuda.empty_cache()
# Run forward passes for both MoE blocks # Run forward passes for both MoE blocks
......
...@@ -4,15 +4,27 @@ ...@@ -4,15 +4,27 @@
import importlib import importlib
import importlib.metadata import importlib.metadata
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import pytest import pytest
import torch import torch
from packaging import version from packaging import version
from vllm.platforms import current_platform
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse( "quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)
if TRTLLM_GEN_MXFP4_AVAILABLE:
from flashinfer import (fp4_quantize, mxfp8_quantize,
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
@dataclass @dataclass
class ModelCase: class ModelCase:
...@@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): ...@@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
output = llm.generate_greedy("Today I am in the French Alps and", output = llm.generate_greedy("Today I am in the French Alps and",
max_tokens=20) max_tokens=20)
assert output assert output
\ No newline at end of file
def swiglu(x,
alpha: float = 1.702,
beta: float = 1.0,
limit: Optional[float] = None):
# Note we add an extra bias of 1 to the linear layer
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
if limit is not None:
x_glu = x_glu.clamp(max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
return out_glu * (x_linear + beta)
fp4_lookup_table = [
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
]
def mxfp4_dequantize(x, scale):
assert x.dtype == torch.uint8
x = x.view(torch.uint8).to(torch.int32)
x_unpacked = torch.zeros(*x.shape[:-1],
x.shape[-1] * 2,
dtype=torch.int32,
device=x.device)
x_unpacked[..., 0::2].copy_(x & 0xF)
x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
x_float = torch.zeros(x_unpacked.shape,
dtype=torch.float32,
device=x.device)
for i, val in enumerate(fp4_lookup_table):
x_float[x_unpacked == i] = val
scale = scale.view(torch.uint8).to(torch.int32)
scale = (scale << 23).view(torch.float32)
scale = scale.reshape(*x.shape[:-1], -1)
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
return x_float * scale
def mxfp8_dequantize(x, scale):
assert x.dtype == torch.float8_e4m3fn
x_float = x.to(torch.float32)
scale = scale.view(torch.uint8).to(torch.int32)
scale = (scale << 23).view(torch.float32)
scale = scale.reshape(*x.shape[:-1], -1)
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
return x_float * scale
def reference_moe(
roouting_logits,
topk,
num_experts,
hidden_states,
w13,
bias13,
w2,
bias2,
alpha,
beta,
limit,
act_type,
):
# renormalize routing
experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
expert_indices = experts.indices
t = hidden_states.clone()
# MLP #1
mlp1_weight = w13[expert_indices, ...]
mlp1_bias = bias13[expert_indices, ...]
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
if act_type == 'mxfp8':
t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16),
is_sf_swizzled_layout=False)
t = mxfp8_dequantize(t_quantized, t_scale)
# MLP #2
mlp2_weight = w2[expert_indices, ...]
mlp2_bias = bias2[expert_indices, ...]
t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias
# Weighted sum of experts
t = torch.einsum("bec,be->bc", t, expert_weights)
assert t.shape == hidden_states.shape
return t.to(torch.bfloat16)
def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
# Number of tokens in the input tensor.
num_tokens = x.shape[0]
# Factor to account for the imbalance of the experts.
# factor equals to the
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
# - 1.0 means perfect expert distribution.
# - > 1.0 means some experts have more
# tokens than the perfect distribution.
# - < 1.0 does not make sense.
imbalance_factor = 1.3
# Calculate the number of tokens per expert
# assuming perfect distribution.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
def tg_mxfp4_moe(
router_logits,
topk,
num_experts,
intermediate_size,
hidden_size,
hidden_states,
hidden_states_scale,
w13_weight,
w13_weight_scale,
w13_bias,
w2_weight,
w2_weight_scale,
w2_bias,
act_type,
alpha,
beta,
limit,
) -> torch.Tensor:
sf_block_size = 32
assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
and w13_weight.shape[1] == intermediate_size * 2
and w13_weight.shape[2] == hidden_size // 2)
assert (w13_weight_scale.dim() == 3
and w13_weight_scale.shape[0] == num_experts
and w13_weight_scale.shape[1] == intermediate_size * 2
and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
and w2_weight.shape[1] == hidden_size
and w2_weight.shape[2] == intermediate_size // 2)
assert (w2_weight_scale.dim() == 3
and w2_weight_scale.shape[1] == hidden_size
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
and w13_bias.shape[1] == intermediate_size * 2)
assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
and w2_bias.shape[1] == hidden_size)
# Swap w1 and w3 as the defenition of
# swiglu is different in the trtllm-gen
w13_weight_scale_ = w13_weight_scale.clone()
w13_weight_ = w13_weight.clone()
w13_bias_ = w13_bias.clone()
w13_weight[:, :intermediate_size, :].copy_(
w13_weight_[:, intermediate_size:, :])
w13_weight[:, intermediate_size:, :].copy_(
w13_weight_[:, :intermediate_size, :])
w13_weight_scale[:, :intermediate_size, :].copy_(
w13_weight_scale_[:, intermediate_size:, :])
w13_weight_scale[:, intermediate_size:, :].copy_(
w13_weight_scale_[:, :intermediate_size, :])
w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])
# Interleave the weights and scaling factors for activation
w13_weight_interleaved = []
w13_weight_scale_interleaved = []
w13_bias_interleaved = []
for i in range(num_experts):
w13_weight_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_weight[i].clone()))
w13_weight_scale_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()))
w13_bias_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1,
1)))
w13_weight = torch.stack(w13_weight_interleaved).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2)
w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
num_experts, 2 * intermediate_size, hidden_size // 32)
w13_bias = torch.stack(w13_bias_interleaved).reshape(
num_experts, 2 * intermediate_size)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_shuffled = []
gemm1_scales_shuffled = []
gemm2_weights_shuffled = []
gemm2_scales_shuffled = []
gemm1_bias_shuffled = []
gemm2_bias_shuffled = []
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
for i in range(num_experts):
gemm1_weights_shuffled.append(
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m))
gemm1_scales_shuffled.append(
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm2_weights_shuffled.append(
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m))
gemm2_scales_shuffled.append(
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm1_bias_shuffled.append(
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
gemm2_bias_shuffled.append(
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
w13_weight = torch.stack(gemm1_weights_shuffled)
w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
num_experts, 2 * intermediate_size,
hidden_size // sf_block_size).view(torch.float8_e4m3fn)
w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
w2_weight = torch.stack(gemm2_weights_shuffled)
w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
num_experts, hidden_size,
intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
tg_result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
gemm1_weights=w13_weight,
gemm1_weights_scale=w13_weight_scale,
gemm1_bias=w13_bias,
gemm1_alpha=alpha,
gemm1_beta=beta,
gemm1_clamp_limit=limit,
gemm2_weights=w2_weight,
gemm2_weights_scale=w2_weight_scale,
gemm2_bias=w2_bias,
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=num_experts,
top_k=topk,
n_group=None,
topk_group=None,
intermediate_size=intermediate_size,
local_expert_offset=0,
local_num_experts=num_experts,
routed_scaling_factor=None,
tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
routing_method_type=1, # renormalize
do_finalize=True)[0]
return tg_result
def check_accuracy(a, b, atol, rtol, percent):
"""Allow a mismatch percentage of 1 - percent."""
if torch.any(torch.isnan(a)):
raise Exception("NaN in reference output")
if torch.any(torch.isnan(b)):
raise Exception("NaN in actual output")
if torch.any(torch.isinf(a)):
raise Exception("Inf in reference output")
if torch.any(torch.isinf(b)):
raise Exception("Inf in actual output")
assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"
left = torch.abs(a - b)
right = atol + rtol * torch.abs(b)
count = torch.sum(left > right)
mismatch_percent = count / a.numel()
if mismatch_percent > 1 - percent:
raise Exception(
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
f"(threshold: {1-percent:.4f})")
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
@pytest.mark.skipif(
not TRTLLM_GEN_MXFP4_AVAILABLE,
reason="nvidia gpu and compute capability sm100 is required for this test")
def test_trtllm_gen_mxfp4_fused_moe(
topk: int,
num_experts: int,
num_tokens: int,
intermediate_size: int,
hidden_size: int,
alpha: float,
beta: float,
limit: Optional[float],
act_type: str,
):
seed = 42
torch.manual_seed(seed)
hidden_states = torch.randn(num_tokens,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16)
w13 = (torch.randn(num_experts,
intermediate_size * 2,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16))
w2 = (torch.randn(num_experts,
hidden_size,
intermediate_size,
device="cuda:0",
dtype=torch.bfloat16))
bias13 = torch.randn(num_experts, intermediate_size * 2,
device="cuda:0") * 10
bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
router_logits = torch.rand(num_tokens, num_experts,
dtype=torch.float32).cuda()
w13, w13_scale = fp4_quantize(w13,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False)
w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
num_experts, intermediate_size * 2, hidden_size // 32)
w2, w2_scale = fp4_quantize(w2,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False)
w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 32)
if act_type == 'mxfp8':
hidden_states, hidden_states_scale = mxfp8_quantize(
hidden_states, is_sf_swizzled_layout=False)
hidden_states_scale = hidden_states_scale.view(
torch.float8_e4m3fn).reshape(-1)
else:
hidden_states_scale = None
# reference result
ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
bias13_ref = bias13
bias2_ref = bias2
if act_type == 'mxfp8':
hidden_states_ref = mxfp8_dequantize(
hidden_states, hidden_states_scale).to(torch.float32)
else:
hidden_states_ref = hidden_states.to(torch.float32)
# Process tokens in chunks of 32 to reduce memory usage
chunk_size = 32
num_chunks = (num_tokens + chunk_size - 1) // chunk_size
for i in range(num_chunks):
start_idx = i * chunk_size
end_idx = min(start_idx + chunk_size, num_tokens)
chunk_result = reference_moe(
router_logits[start_idx:end_idx].to(torch.float32),
topk,
num_experts,
hidden_states_ref[start_idx:end_idx],
w13_ref,
bias13_ref,
w2_ref,
bias2_ref,
alpha,
beta,
limit,
act_type,
)
ref_result[start_idx:end_idx].copy_(chunk_result)
# trtllm-gen result
if alpha is not None:
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
if limit is not None:
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
if beta is not None:
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
tg_result = tg_mxfp4_moe(router_logits,
topk,
num_experts,
intermediate_size,
hidden_size,
hidden_states,
hidden_states_scale,
w13,
w13_scale,
bias13,
w2,
w2_scale,
bias2,
act_type,
alpha=alpha,
beta=beta,
limit=limit)
# relatively loose check since the mxfp4 quantization is less accurate
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up from vllm.utils import round_up
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch from .parallel_utils import ProcessGroupInfo, parallel_launch
requires_pplx = pytest.mark.skipif( requires_pplx = pytest.mark.skipif(
...@@ -452,6 +453,7 @@ def _pplx_prepare_finalize( ...@@ -452,6 +453,7 @@ def _pplx_prepare_finalize(
@pytest.mark.parametrize("use_internode", [False]) @pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional @pytest.mark.optional
@requires_pplx @requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize_slow( def test_pplx_prepare_finalize_slow(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],
e: int, e: int,
...@@ -740,6 +742,7 @@ def _pplx_moe( ...@@ -740,6 +742,7 @@ def _pplx_moe(
@pytest.mark.parametrize("use_internode", [False]) @pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional @pytest.mark.optional
@requires_pplx @requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe_slow( def test_pplx_moe_slow(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],
e: int, e: int,
...@@ -880,6 +883,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, ...@@ -880,6 +883,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False]) @pytest.mark.parametrize("use_internode", [False])
@requires_pplx @requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize( def test_pplx_prepare_finalize(
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_internode: bool, use_internode: bool,
...@@ -893,6 +897,7 @@ def test_pplx_prepare_finalize( ...@@ -893,6 +897,7 @@ def test_pplx_prepare_finalize(
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False]) @pytest.mark.parametrize("use_internode", [False])
@requires_pplx @requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe( def test_pplx_moe(
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
use_internode: bool, use_internode: bool,
......
...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe) fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
dg_available = has_deep_gemm() dg_available = has_deep_gemm()
...@@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, ...@@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), @pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
reason="Not E8M0 scale MOE")
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch): monkeypatch):
......
...@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ...@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids, 'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale, 'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale, 'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token, 'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale 'a1_scale': None #moe_tensors.a_scale
} }
...@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8( ...@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
topk_ids[0][1] = 1 topk_ids[0][1] = 1
workspace13_shape = (m * topk, max(2 * n, k)) workspace13_shape = (m * topk, max(2 * n, k))
workspace2_shape = (m * topk, n) workspace2_shape = (m * topk, max(n, k))
output_shape = (m * topk, k) output_shape = (m, k)
workspace13 = torch.empty(prod(workspace13_shape), workspace13 = torch.empty(prod(workspace13_shape),
device="cuda", device="cuda",
...@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8( ...@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts)) expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn, torch.float8_e4m3fn,
...@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8( ...@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
func = lambda output: run_cutlass_moe_fp8( func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype, a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
per_act_token, per_out_channel, False) workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False, topk_weights)
workspace13.random_() workspace13.random_()
output_random_workspace = torch.empty(output_shape, output_random_workspace = torch.empty(output_shape,
......
...@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, ...@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
atol=0, atol=0,
rtol=0) rtol=0)
# check mindice # check mindice
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) # current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if align_block_size is not None:
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token # check permuted_hidden_states, only valid token
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx], permuted_hidden_states[valid_row_idx],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment