Unverified commit 6317a517, authored by Michael Goin and committed by GitHub
Browse files

Categorize `tests/kernels/` based on kernel type (#16799)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent aa72d9a4
...@@ -6,6 +6,7 @@ import itertools ...@@ -6,6 +6,7 @@ import itertools
import pytest import pytest
import torch import torch
from tests.kernels.utils_block import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
...@@ -13,8 +14,6 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( ...@@ -13,8 +14,6 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
w8a8_block_int8_matmul) w8a8_block_int8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .utils_block import native_w8a8_block_matmul
if current_platform.get_device_capability() < (7, 0): if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True) allow_module_level=True)
......
...@@ -7,13 +7,12 @@ Run `pytest tests/kernels/test_semi_structured.py`. ...@@ -7,13 +7,12 @@ Run `pytest tests/kernels/test_semi_structured.py`.
import pytest import pytest
import torch import torch
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported) sparse_cutlass_supported)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .utils import baseline_scaled_mm, to_fp8, to_int8
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
......
...@@ -8,13 +8,11 @@ import random ...@@ -8,13 +8,11 @@ import random
import pytest import pytest
import torch import torch
from tests.kernels.utils import opcheck from tests.kernels.utils import baseline_scaled_mm, opcheck, to_fp8, to_int8
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
from .utils import baseline_scaled_mm, to_fp8, to_int8
MNK_FACTORS = [ MNK_FACTORS = [
(1, 256, 128), (1, 256, 128),
(1, 16384, 1024), (1, 16384, 1024),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment