"tests/python/pytorch/nn/test_nn.py" did not exist on "e4ddafe9512a3befbe91370300f85efd9c4acee8"
Unverified Commit a7d825fc authored by hlu1, committed by GitHub

Skip some tests on Blackwell (#9777)


Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
parent 38cd5fb1
import pytest
import torch
from sgl_kernel import cutlass_w4a8_moe_mm
+from utils import is_hopper


def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
@@ -38,6 +39,10 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
    return w_q, w_scale


+@pytest.mark.skipif(
+    not is_hopper(),
+    reason="cutlass_w4a8_moe_mm is only supported on sm90",
+)
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
def test_int4_fp8_grouped_gemm_single_expert(batch_size):
    # Test parameters
@@ -127,6 +132,10 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
        raise


+@pytest.mark.skipif(
+    not is_hopper(),
+    reason="cutlass_w4a8_moe_mm is only supported on sm90",
+)
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
@pytest.mark.parametrize("k", [512, 1024])
@pytest.mark.parametrize("n", [1024, 2048])
...
import pytest
import torch
from sgl_kernel import int8_scaled_mm
+from utils import is_sm10x


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@@ -30,6 +31,10 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
    torch.testing.assert_close(o, o1)


+@pytest.mark.skipif(
+    is_sm10x(),
+    reason="int8_scaled_mm is only supported on sm90 and lower",
+)
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
...
+import torch
+
+
+def is_sm10x():
+    return torch.cuda.get_device_capability() >= (10, 0)
+
+
+def is_hopper():
+    return torch.cuda.get_device_capability() == (9, 0)
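For context, these helpers rely on torch.cuda.get_device_capability(), which returns the compute capability as a (major, minor) tuple, so is_sm10x() is true on sm100-class (Blackwell) and newer parts while is_hopper() matches sm90 exactly. Below is a minimal, hypothetical sketch (not part of this commit) of how such a helper gates a test at collection time; the guard for machines without CUDA and the placeholder test body are assumptions added for illustration only.

import pytest
import torch


def is_hopper():
    # Illustrative variant of the helper above; also returns False when no CUDA device is present.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)


@pytest.mark.skipif(
    not is_hopper(),
    reason="placeholder sm90-only test; skipped on non-Hopper devices",
)
def test_placeholder_sm90_only():
    # Placeholder body; a real test would exercise the sm90-only kernel instead.
    x = torch.ones(4, device="cuda")
    torch.testing.assert_close(x + x, torch.full((4,), 2.0, device="cuda"))

Because get_device_capability() returns a tuple, the >= (10, 0) comparison in is_sm10x() also covers architectures newer than Blackwell (for example, (12, 0) >= (10, 0) is True), so those tests stay skipped on future parts as well.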