Commit e150cf11 authored by zhuwenwen's avatar zhuwenwen
Browse files

added support for kernels tests with torch 2.3

parent a3d96521
......@@ -3,13 +3,13 @@ from typing import Type
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
NewGELU, QuickGELU,
SiluAndMul)
from vllm.utils import seed_everything
from .allclose_default import get_default_atol, get_default_rtol
from .utils import torch_version
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
......@@ -49,6 +49,11 @@ def test_act_and_mul(
fn = torch.ops._C.gelu_tanh_and_mul
out = layer(x)
ref_out = layer.forward_native(x)
if torch_version.startswith("2.3"):
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
......@@ -57,6 +62,8 @@ def test_act_and_mul(
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
opcheck(fn, (out, x))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
......@@ -83,6 +90,14 @@ def test_activation(
fn = activation[1]
out = layer(x)
ref_out = layer.forward_native(x)
if torch_version.startswith("2.3"):
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
......@@ -90,3 +105,5 @@ def test_activation(
out = torch.empty_like(x)
opcheck(fn, (out, x))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
......@@ -4,11 +4,11 @@ from typing import List, Optional, Tuple
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
from .utils import torch_version
if not is_hip():
from xformers import ops as xops
......@@ -186,6 +186,25 @@ def test_paged_attention(
# Call the paged attention kernel.
output = torch.empty_like(query)
if version == "v1":
if torch_version.startswith("2.3"):
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
ops.paged_attention_v1(
output,
query,
......@@ -209,6 +228,8 @@ def test_paged_attention(
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
elif version in ("v2", "rocm"):
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
......@@ -224,6 +245,28 @@ def test_paged_attention(
)
max_logits = torch.empty_like(exp_sums)
if version == "v2":
if torch_version.startswith("2.3"):
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
ops.paged_attention_v2(
output,
exp_sums,
......@@ -251,8 +294,32 @@ def test_paged_attention(
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
else:
if torch_version.startswith("2.3"):
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
ops.paged_attention_rocm(
output,
exp_sums,
......@@ -280,6 +347,8 @@ def test_paged_attention(
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
else:
raise AssertionError(f"Unknown version: {version}")
......
......@@ -6,14 +6,12 @@ import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
from vllm.utils import is_hip
# @pytest.mark.parametrize(
# "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
# @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@pytest.mark.parametrize(
"name", ["ROCM_FLASH"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"] if not is_hip() else ["ROCM_FLASH"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
def test_env(name: str, device: str, monkeypatch):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
......
......@@ -8,6 +8,7 @@ import torch
from vllm.model_executor.layers.quantization.awq_triton import (
AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
from vllm.utils import seed_everything
from .utils import torch_version
device = "cuda"
......@@ -64,6 +65,8 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
# qweights - [R , C // 8], int32
# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
@pytest.mark.skipif(torch_version.startswith("2.3"),
reason="Need triton3.0.")
@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
......@@ -111,6 +114,8 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
# qweight - [K, M // 8]
# qzeros - [K // G, M // 8]
# scales - [K // G, M]
@pytest.mark.skipif(torch_version.startswith("2.3"),
reason="Need triton3.0.")
@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
@pytest.mark.parametrize("K", [128])
@pytest.mark.parametrize("M", [16, 24, 32])
......
......@@ -4,9 +4,11 @@ from typing import List, Tuple
import pytest
import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS
from vllm import _custom_ops as ops
from vllm.utils import seed_everything
from vllm.utils import is_hip
from .utils import torch_version
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
......@@ -88,6 +90,23 @@ def test_copy_blocks(
dtype=torch.int64,
device=device).view(-1, 2)
if torch_version.startswith("2.3"):
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
for src, dst in block_mapping:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst].copy_(cloned_value_cache[src])
# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
torch.allclose(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
assert torch.allclose(value_cache, cloned_value_cache)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
opcheck(torch.ops._C_cache_ops.copy_blocks,
(key_caches, value_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
......@@ -107,6 +126,8 @@ def test_copy_blocks(
for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
torch.testing.assert_close(value_cache, cloned_value_cache)
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
......@@ -163,6 +184,45 @@ def test_reshape_and_cache(
# Using default kv_scale
k_scale = v_scale = 1.0
if torch_version.startswith("2.3"):
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale,v_scale)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache, key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache, value_cache)
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()
for i in range(num_tokens):
block_idx = block_indicies[i]
block_offset = block_offsets[i]
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
if kv_cache_dtype == "fp8":
assert torch.allclose(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
assert torch.allclose(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
else:
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
......@@ -201,6 +261,8 @@ def test_reshape_and_cache(
else:
torch.testing.assert_close(key_cache, cloned_key_cache)
torch.testing.assert_close(value_cache, cloned_value_cache)
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
......@@ -272,6 +334,30 @@ def test_reshape_and_cache_flash(
# Using default kv_scale
k_scale = v_scale = 1.0
if torch_version.startswith("2.3"):
# Clone the KV caches.
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
# Call the reshape_and_cache kernel.
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)
# Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
block_indicies = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()
for i in range(num_tokens):
block_idx = block_indicies[i]
block_offset = block_offsets[i]
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
......@@ -309,6 +395,8 @@ def test_reshape_and_cache_flash(
else:
torch.testing.assert_close(key_cache, cloned_key_cache)
torch.testing.assert_close(value_cache, cloned_value_cache)
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
......@@ -371,6 +459,20 @@ def test_swap_blocks(
src_key_caches_clone = src_key_caches[0].clone()
src_value_caches_clone = src_value_caches[0].clone()
if torch_version.startswith("2.3"):
# Call the swap_blocks kernel.
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
block_mapping_tensor)
for src, dst in block_mapping:
assert torch.allclose(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu())
assert torch.allclose(src_value_caches_clone[src].cpu(),
dist_value_caches[0][dst].cpu())
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
# Call the swap_blocks kernel.
do_opcheck = (head_size == HEAD_SIZES[0])
opcheck(torch.ops._C_cache_ops.swap_blocks,
......@@ -390,37 +492,41 @@ def test_swap_blocks(
dist_key_caches[0][dst].cpu())
torch.testing.assert_close(src_value_caches_clone[src].cpu(),
dist_value_caches[0][dst].cpu())
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.skipif(is_hip(),
reason="FP8 is not supported on ROCm.")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
seed_everything(seed)
low = -224.0
high = 224.0
shape = (num_blocks, num_heads, head_size, block_size)
cache = torch.empty(shape, dtype=dtype, device=device)
cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache_fp8, cache)
converted_cache = torch.empty_like(cache)
ops.convert_fp8(converted_cache, cache_fp8)
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_fp8_e4m3_conversion(
# num_heads: int,
# head_size: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# seed_everything(seed)
# low = -224.0
# high = 224.0
# shape = (num_blocks, num_heads, head_size, block_size)
# cache = torch.empty(shape, dtype=dtype, device=device)
# cache.uniform_(low, high)
# cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
# ops.convert_fp8(cache_fp8, cache)
# converted_cache = torch.empty_like(cache)
# ops.convert_fp8(converted_cache, cache_fp8)
# torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
......@@ -7,16 +7,16 @@ from typing import Optional, Type
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
#from vllm.platforms import current_platform
from vllm.platforms import current_platform
from .utils import torch_version
CUDA_DEVICES = [
f"cuda:{0}" #for i in range(1 if torch.cuda.device_count() == 1 else 2)
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
#capability = current_platform.get_device_capability()
capability = 90#capability[0] * 10 + capability[1]
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.Tensor):
......@@ -75,10 +75,16 @@ def cutlass_fp8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
if torch_version.startswith("2.3"):
assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
def cutlass_int8_gemm_helper(m: int,
......@@ -116,10 +122,15 @@ def cutlass_int8_gemm_helper(m: int,
# print("out:",out[0:5][0:5])
# print("baseline:",baseline[0:5][0:5])
if torch_version.startswith("2.3"):
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
# opcheck(torch.ops._C.cutlass_scaled_mm,
# (out, a, b, scale_a, scale_b, bias))
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
# @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
......@@ -350,18 +361,25 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
# rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
# atol = 1e-3
# torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
# torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
# if azp_per_token:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
# func_bias))
# else:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
# func_bias))
# atol = 1e-3
# if torch_version.startswith("2.3"):
# assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
# assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
# elif torch_version.startswith("2.4"):
# from tests.kernels.utils import opcheck
# torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
# torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
# if azp_per_token:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
# func_bias))
# else:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
# func_bias))
# else:
# print(f"PyTorch version {torch_version} is not specifically handled.")
# Test working with a subset of A and B
......
......@@ -8,8 +8,8 @@ if is_hip():
import flash_attn
else:
import vllm.attention.backends.flash_attn # noqa: F401
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything
from .utils import torch_version
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
......@@ -132,6 +132,8 @@ if not is_hip():
else:
test_utils = ["test_faketensor"]
if torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(),
kwargs=dict(
......@@ -253,6 +255,8 @@ def test_varlen_with_paged_kv(
test_utils = ["test_faketensor"]
if not is_hip():
if torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(),
kwargs=dict(
......
......@@ -2,10 +2,10 @@ import pytest
import torch
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from tests.kernels.utils import opcheck
from vllm._custom_ops import scaled_int8_quant
from vllm.utils import seed_everything
from vllm.utils import is_hip
from .utils import torch_version
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
......@@ -15,7 +15,11 @@ SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
def opcheck_int8_quant_static(output, input, scale, azp=None):
if torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
def opcheck_int8_quant_static(output, input, scale, azp=None):
if azp is None:
opcheck(torch.ops._C.static_scaled_int8_quant,
(output, input, scale, None))
......@@ -24,7 +28,7 @@ def opcheck_int8_quant_static(output, input, scale, azp=None):
(output, input, scale, azp))
def opcheck_int8_quant_dynamic(output, input, symmetric=True):
def opcheck_int8_quant_dynamic(output, input, symmetric=True):
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
......@@ -56,11 +60,18 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# kernel
ops_out, ops_scales, _ = scaled_int8_quant(x)
if torch_version.startswith("2.3"):
torch.allclose(ops_scales, ref_scales)
torch.allclose(ops_out, ref_out, atol=1, rtol=0.0)
elif torch_version.startswith("2.4"):
torch.testing.assert_close(ops_scales, ref_scales)
# big atol to account for rounding errors
torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0)
opcheck_int8_quant_dynamic(ops_out, x)
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.skipif(is_hip(),
reason="Currently, there is not supported on ROCm.")
......@@ -97,6 +108,11 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
if (not torch.allclose(scales_out, scales)):
print(torch.argmax(torch.abs(scales_out - scales)))
if torch_version.startswith("2.3"):
torch.allclose(scales_out, scales)
torch.allclose(azp_out, azps, atol=1, rtol=0.0)
torch.allclose(ops_out, torch_out, atol=2, rtol=0.0)
elif torch_version.startswith("2.4"):
torch.testing.assert_close(scales_out, scales)
# big atol to account for rounding errors
torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0)
......@@ -104,6 +120,8 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0)
opcheck_int8_quant_dynamic(ops_out, x, False)
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
......@@ -125,10 +143,15 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits.max).to(torch.int8)
out2, _, _ = scaled_int8_quant(x, scale_arg)
if torch_version.startswith("2.3"):
torch.allclose(out1, out2, atol=1, rtol=0.0)
elif torch_version.startswith("2.4"):
# big atol to account for rounding errors
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
opcheck_int8_quant_static(out2, x, scale_arg)
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
......@@ -155,10 +178,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
if torch_version.startswith("2.3"):
torch.allclose(out1, out2, atol=1, rtol=0.0)
elif torch_version.startswith("2.4"):
# big atol to account for rounding errors
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
opcheck_int8_quant_static(out2, x, scale_arg, azp_arg)
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("is_max", [True, False])
......@@ -190,4 +218,9 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
out = torch.empty_like(expected)
torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
if torch_version.startswith("2.3"):
torch.allclose(expected, out, atol=0, rtol=0)
elif torch_version.startswith("2.4"):
torch.testing.assert_close(expected, out, atol=0, rtol=0)
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import seed_everything
from .utils import torch_version
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
......@@ -47,6 +47,14 @@ def test_rms_norm(
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
if torch_version.startswith("2.3"):
if add_residual:
torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
else:
torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
if add_residual:
torch.testing.assert_close(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
......@@ -59,3 +67,5 @@ def test_rms_norm(
else:
opcheck(torch.ops._C.rms_norm,
(out, x, layer.weight.data, layer.variance_epsilon))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
......@@ -9,7 +9,6 @@ import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
......@@ -22,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types
from vllm.utils import seed_everything
from .utils import torch_version
from vllm.utils import is_hip
def torch_moe(a, w1, w2, score, topk):
......@@ -76,7 +77,12 @@ def test_fused_moe(
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
if torch_version.startswith("2.3"):
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
elif torch_version.startswith("2.4"):
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("dtype",
......@@ -120,11 +126,18 @@ def test_mixtral_moe(dtype: torch.dtype):
torch.float16: 1e-3,
torch.bfloat16: 1e-2,
}
if torch_version.startswith("2.3"):
assert torch.allclose(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
elif torch_version.startswith("2.4"):
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
def stack_and_dev(tensors: List[torch.Tensor]):
......@@ -137,6 +150,8 @@ def compute_max_diff(output, output_ref):
torch.abs(output_ref))
@pytest.mark.skipif(is_hip(),
reason="Currently, there is not supported on ROCm.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
......@@ -256,6 +271,8 @@ def test_fused_marlin_moe(
dtype=torch.int32,
device=a.device)
if torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
opcheck(torch.ops._moe_C.topk_softmax, (
topk_weights,
topk_ids,
......@@ -274,12 +291,16 @@ def test_fused_marlin_moe(
device="cuda",
requires_grad=False)
if torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, g_idx1, sort_indices1, workspace, quant_type, m,
2 * n, k, True, e, topk, block_size_m, True, False))
@pytest.mark.skipif(is_hip(),
reason="Currently, there is not supported on ROCm.")
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
......@@ -373,7 +394,8 @@ def test_moe_align_block_size_opcheck():
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
if torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
opcheck(torch.ops._C.moe_align_block_size,
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
num_tokens_post_pad))
......@@ -8,6 +8,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.utils import seed_everything
from .allclose_default import get_default_atol, get_default_rtol
from .utils import torch_version
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
......@@ -18,7 +19,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing
SEQ_LENS = [11, 8192] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1)
]
......@@ -67,6 +68,16 @@ def test_rotary_embedding(
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
# Compare the results.
if torch_version.startswith("2.3"):
torch.allclose(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.allclose(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
elif torch_version.startswith("2.4"):
torch.testing.assert_close(out_query,
ref_query,
atol=get_default_atol(out_query),
......@@ -75,6 +86,8 @@ def test_rotary_embedding(
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
......@@ -126,6 +139,16 @@ def test_batched_rotary_embedding(
offsets=torch.zeros(batch_size * seq_len,
dtype=torch.long,
device=device))
if torch_version.startswith("2.3"):
torch.allclose(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.allclose(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
elif torch_version.startswith("2.4"):
# Compare the results.
torch.testing.assert_close(out_query,
ref_query,
......@@ -135,6 +158,8 @@ def test_batched_rotary_embedding(
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
......@@ -195,6 +220,16 @@ def test_batched_rotary_embedding_multi_lora(
query_offsets)
out_query, out_key = rope.forward(positions, query, key,
query_offsets.flatten())
if torch_version.startswith("2.3"):
torch.allclose(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.allclose(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
elif torch_version.startswith("2.4"):
# Compare the results.
torch.testing.assert_close(out_query,
ref_query,
......@@ -204,7 +239,8 @@ def test_batched_rotary_embedding_multi_lora(
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
@torch.inference_mode()
def test_rope_module_cache():
......
......@@ -7,8 +7,11 @@ from typing import Optional
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from .utils import torch_version
if torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
def rotary_embedding_opcheck(rot,
......@@ -30,6 +33,8 @@ def rotary_embedding_opcheck(rot,
rot.is_neox_style))
@pytest.mark.skipif(torch_version.startswith("2.3"),
reason="Need torch2.4.")
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
......
......@@ -5,14 +5,17 @@ Tests for miscellaneous utilities
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
from .utils import torch_version
if torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck
def test_convert_fp8_opcheck():
data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
# def test_convert_fp8_opcheck():
# data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
# result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
# opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@pytest.mark.skipif(not current_platform.is_cuda(),
......
......@@ -3,13 +3,22 @@ import torch
from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols
from .utils import torch_version
@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
if torch_version.startswith("2.3"):
x = torch.randn(shape, dtype=dtype).cuda()
perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
y = permute_cols(x, perm)
torch.allclose(y, x[:, perm])
elif torch_version.startswith("2.4"):
x = torch.randn(shape, dtype=dtype).cuda()
perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
opcheck(torch.ops._C.permute_cols, (x, perm))
y = permute_cols(x, perm)
torch.testing.assert_close(y, x[:, perm])
else:
print(f"PyTorch version {torch_version} is not specifically handled.")
\ No newline at end of file
......@@ -30,6 +30,8 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
"test_aot_dispatch_dynamic",
)
torch_version = torch.__version__
class QKVInputs(NamedTuple):
'''
......@@ -974,9 +976,10 @@ def fp8_allclose(
equal_nan=equal_nan)).item())
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
if torch_version.startswith("2.4"):
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef],
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment