Commit af7b564d authored by zhuwenwen's avatar zhuwenwen
Browse files

[fix]fix tests of kernels

parent 1faa2c78
...@@ -20,6 +20,9 @@ if not current_platform.is_rocm(): ...@@ -20,6 +20,9 @@ if not current_platform.is_rocm():
from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.backends.xformers import _make_alibi_bias
if current_platform.is_rocm():
from flash_attn import vllm_flash_attn_with_kvcache
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability. # This will change depending on the compute capability.
# - 512 as a buffer # - 512 as a buffer
......
...@@ -25,7 +25,7 @@ def clear_cache(): ...@@ -25,7 +25,7 @@ def clear_cache():
_cached_get_attn_backend.cache_clear() _cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"] if not current_platform.is_rocm() else ["cpu", "hip"])
def test_mha_attn_platform(device: str): def test_mha_attn_platform(device: str):
""" """
Test the attention selector between different platform and device. Test the attention selector between different platform and device.
......
...@@ -15,7 +15,7 @@ BLOCK_SIZES = [16] ...@@ -15,7 +15,7 @@ BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16] DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [ QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
None, torch.float8_e4m3fnuz None # , torch.float8_e4m3fnuz
] ]
# one value large enough to test overflow in index calculation. # one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check # one value small enough to test the schema op check
......
...@@ -234,93 +234,93 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, ...@@ -234,93 +234,93 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
rtol=rtol) rtol=rtol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) # @pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13]) # @pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) # @pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"seq_len_chunk_size_cases", # "seq_len_chunk_size_cases",
[ # [
# small-ish chunk_size (8) # # small-ish chunk_size (8)
(64, 8, 2, [(64, 32), (64, 32)]), # (64, 8, 2, [(64, 32), (64, 32)]),
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]), # (64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary # (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
(64, 8, 2, [(4, 4), (4, 4), (4, 4), # (64, 8, 2, [(4, 4), (4, 4), (4, 4),
(4, 4)]), # chunk_size larger than cont batches # (4, 4)]), # chunk_size larger than cont batches
(64, 8, 5, [ # (64, 8, 5, [
(64, 32, 16, 8, 8), # (64, 32, 16, 8, 8),
(8, 16, 32, 16, 8), # (8, 16, 32, 16, 8),
(8, 8, 16, 32, 16), # (8, 8, 16, 32, 16),
]), # mode examples with varied lengths # ]), # mode examples with varied lengths
# large-ish chunk_size (256) # # large-ish chunk_size (256)
(64, 256, 1, [(5, ), (1, ), (1, ), # (64, 256, 1, [(5, ), (1, ), (1, ),
(1, )]), # irregular sizes with small sequences # (1, )]), # irregular sizes with small sequences
(64, 256, 2, [(5, 30), (1, 2), (1, 2), # (64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences # (1, 2)]), # irregular sizes with small sequences
# we also need to test some large seqlen # # we also need to test some large seqlen
# to catch errors with init states decay # # to catch errors with init states decay
(768, 128, 2, [(138, 225), (138, 225)]), # (768, 128, 2, [(138, 225), (138, 225)]),
]) # ])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, # def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype): # itype):
# this test with multiple examples in a continuous batch # # this test with multiple examples in a continuous batch
# (i.e. chunked prefill) # # (i.e. chunked prefill)
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases # seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# This test can have larger error for longer sequences # # This test can have larger error for longer sequences
if seqlen > 256: # if seqlen > 256:
atol, rtol = 1e-2, 5e-3 # atol, rtol = 1e-2, 5e-3
else: # else:
atol, rtol = 5e-3, 5e-3 # atol, rtol = 5e-3, 5e-3
# hold state during the cutting process so we know if an # # hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle # # example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample # last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted # exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None # states = None
for Y_min, cu_seqlens, seq_idx, ( # for Y_min, cu_seqlens, seq_idx, (
A, dt, X, B, C) in generate_continuous_batched_examples( # A, dt, X, B, C) in generate_continuous_batched_examples(
cases, num_examples, seqlen, last_taken, exhausted, n_heads, # cases, num_examples, seqlen, last_taken, exhausted, n_heads,
d_head, itype): # d_head, itype):
chunk_indices, chunk_offsets = \ # chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets( # _query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1]) # cu_seqlens, chunk_size, cu_seqlens[-1])
Y = torch.empty_like(X) # Y = torch.empty_like(X)
new_states = mamba_chunk_scan_combined( # new_states = mamba_chunk_scan_combined(
X, # X,
dt, # dt,
A, # A,
B, # B,
C, # C,
chunk_size, # chunk_size,
D=None, # D=None,
cu_seqlens=cu_seqlens, # cu_seqlens=cu_seqlens,
seq_idx=seq_idx, # seq_idx=seq_idx,
chunk_indices=chunk_indices, # chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets, # chunk_offsets=chunk_offsets,
return_varlen_states=True, # return_varlen_states=True,
initial_states=states, # initial_states=states,
out=Y, # out=Y,
) # )
# just test the last in sequence # # just test the last in sequence
for i in range(num_examples): # for i in range(num_examples):
# just test one dim and dstate # # just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] # Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0] # Y_min_eg = Y_min[i][:, 0, 0]
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol) # torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
# update states # # update states
states = new_states # states = new_states
for i, clear in exhausted.items(): # for i, clear in exhausted.items():
if clear: # if clear:
states[i].fill_(0.) # states[i].fill_(0.)
exhausted[i] = False # exhausted[i] = False
...@@ -93,7 +93,7 @@ class BatchedMMTensors: ...@@ -93,7 +93,7 @@ class BatchedMMTensors:
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512]) @pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024]) @pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024]) @pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16] if not current_platform.is_rocm() else [torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
...@@ -205,7 +205,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, ...@@ -205,7 +205,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS) @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16] if not current_platform.is_rocm() else [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True]) @pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False]) @pytest.mark.parametrize("input_scales", [False])
......
...@@ -192,6 +192,7 @@ def test_fused_moe( ...@@ -192,6 +192,7 @@ def test_fused_moe(
use_int8_w8a8=False, use_int8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
use_int4_w4a8=False,
use_mxfp4_w4a4=False, use_mxfp4_w4a4=False,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=None) block_shape=None)
...@@ -349,6 +350,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -349,6 +350,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
renormalize=False, renormalize=False,
use_int4_w4a16=weight_bits == 4, use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8, use_int8_w8a16=weight_bits == 8,
use_int4_w4a8=weight_bits == 4,
global_num_experts=e, global_num_experts=e,
expert_map=e_map, expert_map=e_map,
w1_scale=w1_scales, w1_scale=w1_scales,
...@@ -369,7 +371,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -369,7 +371,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) "use_rocm_aiter", [True, False] if not current_platform.is_rocm() else [False])
@torch.inference_mode() @torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
monkeypatch): monkeypatch):
...@@ -410,12 +412,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, ...@@ -410,12 +412,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
).cuda() ).cuda()
# Load the weights # Load the weights
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data if not current_platform.is_rocm():
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
else:
vllm_moe.gate.weight.data[:] = (hf_moe.gate.weight.data).T
for i in range(config.num_local_experts): for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data, weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data) hf_moe.experts[i].w3.weight.data)
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) if not current_platform.is_rocm():
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
else:
vllm_moe.experts.w13_weight[i][:] = (torch.cat(weights, dim=0)).T
vllm_moe.experts.w2_weight[i][:] = (hf_moe.experts[i].w2.weight.data).T
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim] # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn( hf_inputs = torch.randn(
......
...@@ -13,7 +13,7 @@ import vllm._custom_ops as ops ...@@ -13,7 +13,7 @@ import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import models_path_prefix from ...utils import models_path_prefix
# GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") # GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
# GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample") # GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
......
...@@ -40,7 +40,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True): ...@@ -40,7 +40,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
(output, input, scale, azp)) (output, input, scale, azp))
@pytest.mark.skipif(current_platform(), @pytest.mark.skipif(current_platform.is_rocm(),
reason="Currently, there is not supported on ROCm.") reason="Currently, there is not supported on ROCm.")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
...@@ -65,7 +65,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, ...@@ -65,7 +65,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
opcheck_int8_quant_dynamic(ops_out, x) opcheck_int8_quant_dynamic(ops_out, x)
@pytest.mark.skipif(current_platform(), @pytest.mark.skipif(current_platform.is_rocm(),
reason="Currently, there is not supported on ROCm.") reason="Currently, there is not supported on ROCm.")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
Run `pytest tests/kernels/test_triton_scaled_mm.py`. Run `pytest tests/kernels/test_triton_scaled_mm.py`.
""" """
import os
import importlib import importlib
from typing import Optional from typing import Optional
...@@ -11,6 +12,7 @@ import pytest ...@@ -11,6 +12,7 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ...utils import models_path_prefix
device = "cuda" device = "cuda"
...@@ -45,7 +47,7 @@ def get_8bit_types(): ...@@ -45,7 +47,7 @@ def get_8bit_types():
# This test is to check regressions for int8 support on ROCm. # This test is to check regressions for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [ @pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8", os.path.join(models_path_prefix, "neuralmagic/Llama-3.2-1B-quantized.w8a8"),
]) ])
@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10]) @pytest.mark.parametrize("num_logprobs", [10])
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend""" """Integration tests for FlexAttention backend vs default backend"""
import os
import random import random
import numpy as np import numpy as np
...@@ -10,6 +11,7 @@ import torch ...@@ -10,6 +11,7 @@ import torch
from packaging import version from packaging import version
from vllm import SamplingParams from vllm import SamplingParams
from ..utils import models_path_prefix
from ..models.utils import check_embeddings_close from ..models.utils import check_embeddings_close
...@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ...@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
This test compares the outputs from the FlexAttention backend with This test compares the outputs from the FlexAttention backend with
the default backend, ensuring they are identical when using the same seed. the default backend, ensuring they are identical when using the same seed.
""" """
model_name = "Qwen/Qwen2.5-1.5B-Instruct" model_name = os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct")
seed = 42 seed = 42
max_tokens = 24 max_tokens = 24
prompts = [ prompts = [
......
...@@ -9,7 +9,7 @@ from vllm.model_executor.layers.activation import SiluAndMul ...@@ -9,7 +9,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16] DTYPES = [torch.bfloat16, torch.float16]
QUANT_DTYPES = [current_platform.fp8_dtype()] QUANT_DTYPES = [current_platform.fp8_dtype()] if not current_platform.is_rocm() else [None]
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
SEEDS = [0] SEEDS = [0]
......
...@@ -60,26 +60,26 @@ class ReferenceAttention: ...@@ -60,26 +60,26 @@ class ReferenceAttention:
ref_out = ref_out.transpose(1, 2).clone() ref_out = ref_out.transpose(1, 2).clone()
return ref_out return ref_out
def fwd_fp8(self, q_quantized, k_quantized, v_quantized): # def fwd_fp8(self, q_quantized, k_quantized, v_quantized):
q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to( # q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to(
self.dtype) # self.dtype)
k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to( # k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to(
self.dtype) # self.dtype)
v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to( # v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to(
self.dtype) # self.dtype)
result = self.fwd(q, k, v) # result = self.fwd(q, k, v)
if self.input_metadata.o_scale is not None: # if self.input_metadata.o_scale is not None:
result, _ = scale_fp8(result, self.input_metadata.o_scale) # result, _ = scale_fp8(result, self.input_metadata.o_scale)
return result # return result
def fwd_fp8_kv(self, q, k_quantized, v_quantized): # def fwd_fp8_kv(self, q, k_quantized, v_quantized):
k_descale, v_descale = (self.input_metadata.k_descale, # k_descale, v_descale = (self.input_metadata.k_descale,
self.input_metadata.v_descale) # self.input_metadata.v_descale)
k_dequantized = (k_quantized.to(torch.float32) * # k_dequantized = (k_quantized.to(torch.float32) *
k_descale.to(torch.float32)).to(self.dtype) # k_descale.to(torch.float32)).to(self.dtype)
v_dequantized = (v_quantized.to(torch.float32) * # v_dequantized = (v_quantized.to(torch.float32) *
v_descale.to(torch.float32)).to(self.dtype) # v_descale.to(torch.float32)).to(self.dtype)
return self.fwd(q, k_dequantized, v_dequantized) # return self.fwd(q, k_dequantized, v_dequantized)
def varlen_fwd(self, q, k, v, is_mqa=False): def varlen_fwd(self, q, k, v, is_mqa=False):
ref_out = torch.empty_like(q) ref_out = torch.empty_like(q)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment