Unverified Commit 6317a517 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Categorize `tests/kernels/` based on kernel type (#16799)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent aa72d9a4
......@@ -3,11 +3,9 @@
Tests for miscellaneous utilities
"""
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
def test_convert_fp8_opcheck():
......@@ -16,10 +14,12 @@ def test_convert_fp8_opcheck():
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="Only supported for CUDA")
def test_cuda_utils_opcheck():
opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
opcheck(
torch.ops._C_cuda_utils.
get_max_shared_memory_per_block_device_attribute, (0, ))
# TODO: Add this back, currently fails with
# csrc/cuda_utils_kernels.cu:15 'invalid argument'
# @pytest.mark.skipif(not current_platform.is_cuda(),
# reason="Only supported for CUDA")
# def test_cuda_utils_opcheck():
# opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
# opcheck(
# torch.ops._C_cuda_utils.
# get_max_shared_memory_per_block_device_attribute, (0, ))
......@@ -6,11 +6,10 @@ from typing import Callable, Optional
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from .allclose_default import get_default_atol, get_default_rtol
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 112, 120, 256]
......
......@@ -6,6 +6,7 @@ import itertools
import pytest
import torch
from tests.kernels.utils_block import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
......@@ -18,8 +19,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
from .utils_block import native_w8a8_block_matmul
dg_available = False
try:
import deep_gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment