Commit 7e63ef82 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
model_name: "RedHatAI/Qwen3-30B-A3B-FP8-dynamic"
accuracy_threshold: 0.85
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "latency"
model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
VLLM_TEST_FORCE_FP8_MARLIN: "1"
model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_FP4: "1"
VLLM_FLASHINFER_MOE_BACKEND: "latency"
model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
VLLM_TEST_FORCE_FP8_MARLIN: "1"
model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml
Llama-4-Scout-Fp8-ModelOpt-fi-trtllm.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml
Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml
Qwen3-30B-A3B-NvFp4-CT-marlin.yaml
Qwen3-30B-A3B-NvFp4-CT-fi-trtllm.yaml
Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
Qwen3-30B-A3B-NvFp4-CT-fi-cutlass-dp-ep.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-marlin.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass-dp-ep.yaml
Mixtral-8x7B-Fp8-AutoFp8-triton.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml
Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml
Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml
Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
Qwen3-30B-A3B-Fp8-CT-Channel-marlin.yaml
Qwen3-30B-A3B-Fp8-CT-Channel-vllm-cutlass.yaml
Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml
Llama-4-Scout-Fp8-ModelOpt-marlin.yaml
Llama-4-Scout-Fp8-ModelOpt-triton.yaml
Qwen3-30B-A3B-NvFp4-CT-marlin.yaml
\ No newline at end of file
...@@ -11,14 +11,12 @@ def pytest_addoption(parser): ...@@ -11,14 +11,12 @@ def pytest_addoption(parser):
default="configs/models-small.txt", default="configs/models-small.txt",
help="File containing list of config files to test", help="File containing list of config files to test",
) )
parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size")
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
"""Generate test parameters from config files.""" """Generate test parameters from config files."""
if "config_filename" in metafunc.fixturenames: if "config_filename" in metafunc.fixturenames:
config_list_file = metafunc.config.getoption("--config-list-file") config_list_file = metafunc.config.getoption("--config-list-file")
tp_size = metafunc.config.getoption("--tp-size")
# Handle both relative and absolute paths # Handle both relative and absolute paths
config_list_path = Path(config_list_file) config_list_path = Path(config_list_file)
...@@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc): ...@@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc):
# Generate test parameters # Generate test parameters
if config_files: if config_files:
metafunc.parametrize( metafunc.parametrize(
["config_filename", "tp_size"], "config_filename",
[(config_file, int(tp_size)) for config_file in config_files], config_files,
ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files], ids=[config_file.stem for config_file in config_files],
) )
else: else:
print("No config files found, test will be skipped") print("No config files found, test will be skipped")
...@@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script. ...@@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control. Replacement for lm-eval-harness with better performance and control.
Usage: Usage:
pytest -s -v test_gsm8k_correctness.py \ pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt \ --config-list-file=configs/models-small.txt
--tp-size=1
""" """
import shlex
import yaml import yaml
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from .gsm8k_eval import evaluate_gsm8k from .gsm8k_eval import evaluate_gsm8k
RTOL = 0.08 # Relative tolerance for accuracy comparison TOL = 0.08 # Absolute tolerance for accuracy comparison
def launch_gsm8k_eval(eval_config, server_url, tp_size): def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
"""Launch GSM8K evaluation using our isolated script.""" """Run GSM8K evaluation using our isolated script."""
# Extract host and port from server URL # Extract host and port from server URL
if "://" in server_url: if "://" in server_url:
server_url = server_url.split("://")[1] server_url = server_url.split("://")[1]
host_port = server_url.split("/")[0] # Remove path if present host_port = server_url.split("/")[0] # Remove path if present
if ":" in host_port: if ":" in host_port:
host, port = host_port.split(":") host, p = host_port.split(":")
port = int(port) port = int(p)
else: else:
host = host_port host = host_port
port = 8000 port = 8000
...@@ -48,46 +49,59 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size): ...@@ -48,46 +49,59 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size):
return results return results
def test_gsm8k_correctness_param(config_filename, tp_size): def test_gsm8k_correctness(config_filename):
"""Test GSM8K correctness for a given model configuration.""" """Test GSM8K correctness for a given model configuration."""
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
# Server arguments # Parse server arguments from config (use shlex to handle quoted strings)
server_args = [ server_args_str = eval_config.get("server_args", "")
"--max-model-len", server_args = shlex.split(server_args_str) if server_args_str else []
str(eval_config.get("max_model_len", 4096)),
"--enforce-eager", # Add standard server arguments
"--trust-remote-code", server_args.extend(
"--tensor-parallel-size", [
str(tp_size), "--trust-remote-code",
] "--disable-uvicorn-access-log",
]
)
env_dict = eval_config.get("env", None) env_dict = eval_config.get("env", None)
print(f"Starting GSM8K evaluation for model: {eval_config['model_name']}")
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
print(f"Number of questions: {eval_config['num_questions']}")
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
print(f"Server args: {' '.join(server_args)}")
print(f"Environment variables: {env_dict}")
# Launch server and run evaluation # Launch server and run evaluation
with RemoteOpenAIServer( with RemoteOpenAIServer(
eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480 eval_config["model_name"],
server_args,
env_dict=env_dict,
max_wait_seconds=eval_config.get("startup_max_wait_seconds", 600),
) as remote_server: ) as remote_server:
server_url = remote_server.url_for("v1") server_url = remote_server.url_for("v1")
print(f"Server started at: {server_url}")
results = launch_gsm8k_eval(eval_config, server_url, tp_size) results = run_gsm8k_eval(eval_config, server_url)
# Check accuracy against threshold measured_metric = results["accuracy"]
measured_accuracy = results["accuracy"] expected_metric = eval_config["accuracy_threshold"]
expected_accuracy = eval_config["accuracy_threshold"]
print(f"GSM8K Results for {eval_config['model_name']}:") print(f"GSM8K Results for {eval_config['model_name']}:")
print(f" Accuracy: {measured_accuracy:.3f}") print(f" Measured metric: {measured_metric:.4f}")
print(f" Expected: {expected_accuracy:.3f}") print(f" Expected metric: {expected_metric:.4f}")
print(f" Tolerance: {TOL:.4f}")
print(f" Questions: {results['num_questions']}") print(f" Questions: {results['num_questions']}")
print(f" Invalid rate: {results['invalid_rate']:.3f}") print(f" Invalid rate: {results['invalid_rate']:.3f}")
print(f" Latency: {results['latency']:.1f}s") print(f" Latency: {results['latency']:.1f}s")
print(f" QPS: {results['questions_per_second']:.1f}") print(f" QPS: {results['questions_per_second']:.1f}")
# Verify accuracy is within tolerance # Verify metric is within tolerance
assert measured_accuracy >= expected_accuracy - RTOL, ( assert measured_metric >= expected_metric - TOL, (
f"Accuracy too low: {measured_accuracy:.3f} < " f"GSM8K metric too low: {measured_metric:.4f} < "
f"{expected_accuracy:.3f} - {RTOL:.3f}" f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
) )
print(f"✅ GSM8K test passed for {eval_config['model_name']}") print(f"✅ GSM8K test passed for {eval_config['model_name']}")
...@@ -6,8 +6,9 @@ import pytest ...@@ -6,8 +6,9 @@ import pytest
import torch import torch
import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401 import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
NUM_HEADS = [(4, 4), (8, 2)] NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
...@@ -104,7 +105,7 @@ def test_varlen_with_paged_kv( ...@@ -104,7 +105,7 @@ def test_varlen_with_paged_kv(
if not is_flash_attn_varlen_func_available(): if not is_flash_attn_varlen_func_available():
pytest.skip("flash_attn_varlen_func required to run this test.") pytest.skip("flash_attn_varlen_func required to run this test.")
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens] kv_lens = [x[1] for x in seq_lens]
......
...@@ -9,9 +9,11 @@ import torch ...@@ -9,9 +9,11 @@ import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.layer import Attention, MultiHeadAttention from vllm.attention.layer import Attention
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.mem_utils import get_max_shared_memory_bytes from vllm.utils.mem_utils import get_max_shared_memory_bytes
from vllm.utils.torch_utils import set_random_seed
if current_platform.is_rocm(): if current_platform.is_rocm():
from flash_attn import vllm_flash_attn_with_kvcache from flash_attn import vllm_flash_attn_with_kvcache
...@@ -31,7 +33,7 @@ NUM_PREFILL_SEQS = [3] # Arbitrary values for testing ...@@ -31,7 +33,7 @@ NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# This should be sync with get_supported_head_sizes() in # This should be sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention # vllm.v1.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 80, 128, 256] HEAD_SIZES = [32, 80, 128, 256]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
...@@ -152,7 +154,7 @@ def test_paged_attention( ...@@ -152,7 +154,7 @@ def test_paged_attention(
global PARTITION_SIZE global PARTITION_SIZE
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads num_query_heads, num_kv_heads = num_heads
...@@ -445,7 +447,7 @@ def ref_multi_query_kv_attention( ...@@ -445,7 +447,7 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0) return torch.cat(ref_outputs, dim=0)
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) @pytest.mark.parametrize("attention_cls", [Attention, MMEncoderAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None: def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
head_size = 64 head_size = 64
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
......
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")] COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.bfloat16, torch.float] DTYPES = [torch.bfloat16, torch.float]
...@@ -41,93 +42,6 @@ KV_CACHE_DTYPE = ["auto"] ...@@ -41,93 +42,6 @@ KV_CACHE_DTYPE = ["auto"]
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"] RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
kv_cache_factory,
num_mappings: int,
num_layers: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
kv_cache_dtype: str,
device: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
current_platform.seed_everything(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
assert 2 * num_mappings <= num_blocks
src_blocks = random.sample(range(num_blocks), num_mappings)
remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
block_mapping: list[tuple[int, int]] = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(
num_blocks,
block_size,
num_layers,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
device,
)
# Clone the KV caches.
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel.
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device=device
).view(-1, 2)
opcheck(
torch.ops._C_cache_ops.copy_blocks,
(key_caches, value_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(head_size == HEAD_SIZES[0]),
)
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
# Run the reference implementation.
for src, dst in block_mapping:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst].copy_(cloned_value_cache[src])
# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
torch.testing.assert_close(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
torch.testing.assert_close(value_cache, cloned_value_cache)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
...@@ -152,7 +66,7 @@ def test_reshape_and_cache( ...@@ -152,7 +66,7 @@ def test_reshape_and_cache(
) -> None: ) -> None:
if kv_cache_dtype == "fp8" and head_size % 16: if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip() pytest.skip()
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
# Create a random slot mapping. # Create a random slot mapping.
...@@ -273,7 +187,7 @@ def test_reshape_and_cache_flash( ...@@ -273,7 +187,7 @@ def test_reshape_and_cache_flash(
kv_cache_layout: str, kv_cache_layout: str,
implementation: str, implementation: str,
) -> None: ) -> None:
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
assert implementation in ["cuda", "triton"] assert implementation in ["cuda", "triton"]
...@@ -357,7 +271,7 @@ def test_reshape_and_cache_flash( ...@@ -357,7 +271,7 @@ def test_reshape_and_cache_flash(
v_scale, v_scale,
) )
elif implementation == "triton": elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash, triton_reshape_and_cache_flash,
) )
...@@ -443,7 +357,7 @@ def test_swap_blocks( ...@@ -443,7 +357,7 @@ def test_swap_blocks(
if kv_cache_dtype == "fp8" and head_size % 16: if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip() pytest.skip()
current_platform.seed_everything(seed) set_random_seed(seed)
src_device = device if direction[0] == "cuda" else "cpu" src_device = device if direction[0] == "cuda" else "cpu"
dst_device = device if direction[1] == "cuda" else "cpu" dst_device = device if direction[1] == "cuda" else "cpu"
...@@ -534,7 +448,7 @@ def test_fp8_e4m3_conversion( ...@@ -534,7 +448,7 @@ def test_fp8_e4m3_conversion(
seed: int, seed: int,
device: str, device: str,
) -> None: ) -> None:
current_platform.seed_everything(seed) set_random_seed(seed)
low = -224.0 low = -224.0
high = 224.0 high = 224.0
...@@ -597,7 +511,7 @@ def test_concat_and_cache_mla( ...@@ -597,7 +511,7 @@ def test_concat_and_cache_mla(
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
) -> None: ) -> None:
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
...@@ -674,7 +588,7 @@ def test_concat_and_cache_ds_mla( ...@@ -674,7 +588,7 @@ def test_concat_and_cache_ds_mla(
if dtype.itemsize != 2: if dtype.itemsize != 2:
pytest.skip("ds_mla only supports 16-bit input") pytest.skip("ds_mla only supports 16-bit input")
kv_cache_dtype = "fp8_ds_mla" kv_cache_dtype = "fp8_ds_mla"
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
...@@ -766,73 +680,6 @@ def test_concat_and_cache_ds_mla( ...@@ -766,73 +680,6 @@ def test_concat_and_cache_ds_mla(
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1) torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
block_size: int,
num_blocks: int,
num_layers: int,
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
entry_size = kv_lora_rank + qk_rope_head_dim
kv_caches = []
for _ in range(num_layers):
kv_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
kv_caches.append(kv_cache)
ref_caches = [kv_cache.clone() for kv_cache in kv_caches]
num_mappings = min(2, num_blocks // 2)
src_blocks = random.sample(range(num_blocks), num_mappings)
remaining = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining, 2 * num_mappings)
block_mapping = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device=device
).view(-1, 2)
for src, dst in block_mapping:
for ref_cache in ref_caches:
ref_cache[dst].copy_(ref_cache[src])
opcheck(
torch.ops._C_cache_ops.copy_blocks_mla,
(kv_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.copy_blocks_mla(kv_caches, block_mapping_tensor)
for kv_cache, ref_cache in zip(kv_caches, ref_caches):
torch.testing.assert_close(kv_cache, ref_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
...@@ -852,7 +699,7 @@ def test_swap_blocks_mla( ...@@ -852,7 +699,7 @@ def test_swap_blocks_mla(
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
) -> None: ) -> None:
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
...@@ -1104,7 +951,7 @@ def test_concat_and_cache_mla_cpu( ...@@ -1104,7 +951,7 @@ def test_concat_and_cache_mla_cpu(
) -> None: ) -> None:
device = "cpu" device = "cpu"
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
total_slots = num_blocks * block_size total_slots = num_blocks * block_size
......
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -46,7 +47,7 @@ def test_merge_kernel( ...@@ -46,7 +47,7 @@ def test_merge_kernel(
dtype: torch.dtype, dtype: torch.dtype,
): ):
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) set_random_seed(0)
num_query_heads = num_heads[0] num_query_heads = num_heads[0]
num_kv_heads = num_heads[1] num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0 assert num_query_heads % num_kv_heads == 0
...@@ -110,7 +111,7 @@ CASES = [ ...@@ -110,7 +111,7 @@ CASES = [
# f'to: "{fa_version_unsupported_reason(fa_version)}"' # f'to: "{fa_version_unsupported_reason(fa_version)}"'
# ) # )
# current_platform.seed_everything(0) # set_random_seed(0)
# window_size = (-1, -1) # window_size = (-1, -1)
# scale = head_size**-0.5 # scale = head_size**-0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment