Commit af7b564d authored by zhuwenwen's avatar zhuwenwen
Browse files

[fix]fix tests of kernels

parent 1faa2c78
......@@ -20,6 +20,9 @@ if not current_platform.is_rocm():
from vllm.attention.backends.xformers import _make_alibi_bias
if current_platform.is_rocm():
from flash_attn import vllm_flash_attn_with_kvcache
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
......
......@@ -25,7 +25,7 @@ def clear_cache():
_cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"] if not current_platform.is_rocm() else ["cpu", "hip"])
def test_mha_attn_platform(device: str):
"""
Test the attention selector between different platform and device.
......
......@@ -15,7 +15,7 @@ BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
None, torch.float8_e4m3fnuz
None # , torch.float8_e4m3fnuz
]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
......
......@@ -234,93 +234,93 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
rtol=rtol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
"seq_len_chunk_size_cases",
[
# small-ish chunk_size (8)
(64, 8, 2, [(64, 32), (64, 32)]),
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
(64, 8, 2, [(4, 4), (4, 4), (4, 4),
(4, 4)]), # chunk_size larger than cont batches
(64, 8, 5, [
(64, 32, 16, 8, 8),
(8, 16, 32, 16, 8),
(8, 8, 16, 32, 16),
]), # mode examples with varied lengths
# large-ish chunk_size (256)
(64, 256, 1, [(5, ), (1, ), (1, ),
(1, )]), # irregular sizes with small sequences
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences
# we also need to test some large seqlen
# to catch errors with init states decay
(768, 128, 2, [(138, 225), (138, 225)]),
])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype):
# this test with multiple examples in a continuous batch
# (i.e. chunked prefill)
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# This test can have larger error for longer sequences
if seqlen > 256:
atol, rtol = 1e-2, 5e-3
else:
atol, rtol = 5e-3, 5e-3
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None
for Y_min, cu_seqlens, seq_idx, (
A, dt, X, B, C) in generate_continuous_batched_examples(
cases, num_examples, seqlen, last_taken, exhausted, n_heads,
d_head, itype):
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
Y = torch.empty_like(X)
new_states = mamba_chunk_scan_combined(
X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=states,
out=Y,
)
# just test the last in sequence
for i in range(num_examples):
# just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
# update states
states = new_states
for i, clear in exhausted.items():
if clear:
states[i].fill_(0.)
exhausted[i] = False
# @pytest.mark.parametrize("itype", [torch.float32, torch.float16])
# @pytest.mark.parametrize("n_heads", [4, 8, 13])
# @pytest.mark.parametrize("d_head", [5, 16, 21, 32])
# @pytest.mark.parametrize(
# "seq_len_chunk_size_cases",
# [
# # small-ish chunk_size (8)
# (64, 8, 2, [(64, 32), (64, 32)]),
# (64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
# (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
# (64, 8, 2, [(4, 4), (4, 4), (4, 4),
# (4, 4)]), # chunk_size larger than cont batches
# (64, 8, 5, [
# (64, 32, 16, 8, 8),
# (8, 16, 32, 16, 8),
# (8, 8, 16, 32, 16),
# ]), # mode examples with varied lengths
# # large-ish chunk_size (256)
# (64, 256, 1, [(5, ), (1, ), (1, ),
# (1, )]), # irregular sizes with small sequences
# (64, 256, 2, [(5, 30), (1, 2), (1, 2),
# (1, 2)]), # irregular sizes with small sequences
# # we also need to test some large seqlen
# # to catch errors with init states decay
# (768, 128, 2, [(138, 225), (138, 225)]),
# ])
# def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# itype):
# # this test with multiple examples in a continuous batch
# # (i.e. chunked prefill)
# seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# # This test can have larger error for longer sequences
# if seqlen > 256:
# atol, rtol = 1e-2, 5e-3
# else:
# atol, rtol = 5e-3, 5e-3
# # hold state during the cutting process so we know if an
# # example has been exhausted and needs to cycle
# last_taken: dict = {} # map: eg -> pointer to last taken sample
# exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
# states = None
# for Y_min, cu_seqlens, seq_idx, (
# A, dt, X, B, C) in generate_continuous_batched_examples(
# cases, num_examples, seqlen, last_taken, exhausted, n_heads,
# d_head, itype):
# chunk_indices, chunk_offsets = \
# _query_start_loc_to_chunk_indices_offsets(
# cu_seqlens, chunk_size, cu_seqlens[-1])
# Y = torch.empty_like(X)
# new_states = mamba_chunk_scan_combined(
# X,
# dt,
# A,
# B,
# C,
# chunk_size,
# D=None,
# cu_seqlens=cu_seqlens,
# seq_idx=seq_idx,
# chunk_indices=chunk_indices,
# chunk_offsets=chunk_offsets,
# return_varlen_states=True,
# initial_states=states,
# out=Y,
# )
# # just test the last in sequence
# for i in range(num_examples):
# # just test one dim and dstate
# Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
# Y_min_eg = Y_min[i][:, 0, 0]
# torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
# # update states
# states = new_states
# for i, clear in exhausted.items():
# if clear:
# states[i].fill_(0.)
# exhausted[i] = False
......@@ -93,7 +93,7 @@ class BatchedMMTensors:
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16] if not current_platform.is_rocm() else [torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
......@@ -205,7 +205,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16] if not current_platform.is_rocm() else [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
......
......@@ -192,6 +192,7 @@ def test_fused_moe(
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_int4_w4a8=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=None)
......@@ -349,6 +350,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
renormalize=False,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
use_int4_w4a8=weight_bits == 4,
global_num_experts=e,
expert_map=e_map,
w1_scale=w1_scales,
......@@ -369,7 +371,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
"use_rocm_aiter", [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
monkeypatch):
......@@ -410,12 +412,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
).cuda()
# Load the weights
if not current_platform.is_rocm():
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
else:
vllm_moe.gate.weight.data[:] = (hf_moe.gate.weight.data).T
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
if not current_platform.is_rocm():
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
else:
vllm_moe.experts.w13_weight[i][:] = (torch.cat(weights, dim=0)).T
vllm_moe.experts.w2_weight[i][:] = (hf_moe.experts[i].w2.weight.data).T
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn(
......
......@@ -13,7 +13,7 @@ import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
from vllm.platforms import current_platform
from ..utils import models_path_prefix
from ...utils import models_path_prefix
# GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
# GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
......
......@@ -40,7 +40,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
(output, input, scale, azp))
@pytest.mark.skipif(current_platform(),
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Currently, there is not supported on ROCm.")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
......@@ -65,7 +65,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
opcheck_int8_quant_dynamic(ops_out, x)
@pytest.mark.skipif(current_platform(),
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Currently, there is not supported on ROCm.")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
......
......@@ -4,6 +4,7 @@
Run `pytest tests/kernels/test_triton_scaled_mm.py`.
"""
import os
import importlib
from typing import Optional
......@@ -11,6 +12,7 @@ import pytest
import torch
from vllm.platforms import current_platform
from ...utils import models_path_prefix
device = "cuda"
......@@ -45,7 +47,7 @@ def get_8bit_types():
# This test is to check regressions for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
os.path.join(models_path_prefix, "neuralmagic/Llama-3.2-1B-quantized.w8a8"),
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""
import os
import random
import numpy as np
......@@ -10,6 +11,7 @@ import torch
from packaging import version
from vllm import SamplingParams
from ..utils import models_path_prefix
from ..models.utils import check_embeddings_close
......@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
This test compares the outputs from the FlexAttention backend with
the default backend, ensuring they are identical when using the same seed.
"""
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model_name = os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct")
seed = 42
max_tokens = 24
prompts = [
......
......@@ -9,7 +9,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
QUANT_DTYPES = [current_platform.fp8_dtype()]
QUANT_DTYPES = [current_platform.fp8_dtype()] if not current_platform.is_rocm() else [None]
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
SEEDS = [0]
......
......@@ -60,26 +60,26 @@ class ReferenceAttention:
ref_out = ref_out.transpose(1, 2).clone()
return ref_out
def fwd_fp8(self, q_quantized, k_quantized, v_quantized):
q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to(
self.dtype)
k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to(
self.dtype)
v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to(
self.dtype)
result = self.fwd(q, k, v)
if self.input_metadata.o_scale is not None:
result, _ = scale_fp8(result, self.input_metadata.o_scale)
return result
def fwd_fp8_kv(self, q, k_quantized, v_quantized):
k_descale, v_descale = (self.input_metadata.k_descale,
self.input_metadata.v_descale)
k_dequantized = (k_quantized.to(torch.float32) *
k_descale.to(torch.float32)).to(self.dtype)
v_dequantized = (v_quantized.to(torch.float32) *
v_descale.to(torch.float32)).to(self.dtype)
return self.fwd(q, k_dequantized, v_dequantized)
# def fwd_fp8(self, q_quantized, k_quantized, v_quantized):
# q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to(
# self.dtype)
# k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to(
# self.dtype)
# v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to(
# self.dtype)
# result = self.fwd(q, k, v)
# if self.input_metadata.o_scale is not None:
# result, _ = scale_fp8(result, self.input_metadata.o_scale)
# return result
# def fwd_fp8_kv(self, q, k_quantized, v_quantized):
# k_descale, v_descale = (self.input_metadata.k_descale,
# self.input_metadata.v_descale)
# k_dequantized = (k_quantized.to(torch.float32) *
# k_descale.to(torch.float32)).to(self.dtype)
# v_dequantized = (v_quantized.to(torch.float32) *
# v_descale.to(torch.float32)).to(self.dtype)
# return self.fwd(q, k_dequantized, v_dequantized)
def varlen_fwd(self, q, k, v, is_mqa=False):
ref_out = torch.empty_like(q)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment