Unverified commit 63c08894, authored by Roy Wang, committed by GitHub
Browse files

[Misc] Fix flashinfer related tests (#33462)


Signed-off-by: esmeetu <jasonailu87@gmail.com>
parent 1e86c802
...@@ -412,7 +412,7 @@ def test_naive_block_assignment_moe( ...@@ -412,7 +412,7 @@ def test_naive_block_assignment_moe(
monkeypatch, monkeypatch,
workspace_init, workspace_init,
): ):
current_platform.seed_everything(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
......
...@@ -74,7 +74,7 @@ def get_ref_results( ...@@ -74,7 +74,7 @@ def get_ref_results(
@pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["cutlass", "trtllm"]) @pytest.mark.parametrize("backend", ["cutlass", "cudnn", "trtllm"])
@pytest.mark.parametrize("autotune", [False, True]) @pytest.mark.parametrize("autotune", [False, True])
@torch.inference_mode() @torch.inference_mode()
def test_flashinfer_nvfp4_gemm( def test_flashinfer_nvfp4_gemm(
......
...@@ -174,7 +174,7 @@ def test_static_fp8_quant_group_2d( ...@@ -174,7 +174,7 @@ def test_static_fp8_quant_group_2d(
f"group_shape ({group_shape[0]}, {group_shape[1]})" f"group_shape ({group_shape[0]}, {group_shape[1]})"
) )
current_platform.seed_everything(seed) set_random_seed(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale = scaled_quantize( ref_out, scale = scaled_quantize(
...@@ -202,7 +202,7 @@ def test_static_fp8_quant_1d_scale( ...@@ -202,7 +202,7 @@ def test_static_fp8_quant_1d_scale(
group_shape: tuple[int, int], group_shape: tuple[int, int],
) -> None: ) -> None:
"""Test static FP8 quantization with 1D scale (per-token or per-channel).""" """Test static FP8 quantization with 1D scale (per-token or per-channel)."""
current_platform.seed_everything(seed) set_random_seed(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale_2d = scaled_quantize( ref_out, scale_2d = scaled_quantize(
......
...@@ -154,9 +154,10 @@ def convert_to_nvfp4_linear_kernel_format( ...@@ -154,9 +154,10 @@ def convert_to_nvfp4_linear_kernel_format(
) )
layer.weight = torch.nn.Parameter(weight, requires_grad=False) layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
elif ( elif backend in (
backend == NvFp4LinearBackend.VLLM_CUTLASS NvFp4LinearBackend.VLLM_CUTLASS,
or backend == NvFp4LinearBackend.FLASHINFER_CUTLASS NvFp4LinearBackend.FLASHINFER_CUTLASS,
NvFp4LinearBackend.FLASHINFER_CUDNN,
): ):
weight, weight_scale, weights_padding_cols = prepare_weights_for_nvfp4_cutlass( weight, weight_scale, weights_padding_cols = prepare_weights_for_nvfp4_cutlass(
layer.weight.data, layer.weight_scale.data layer.weight.data, layer.weight_scale.data
......
...@@ -521,7 +521,7 @@ def flashinfer_scaled_fp4_mm( ...@@ -521,7 +521,7 @@ def flashinfer_scaled_fp4_mm(
assert a.stride(-1) == 1 and b.stride(-1) == 1 assert a.stride(-1) == 1 and b.stride(-1) == 1
assert a.shape[1] == b.shape[1] assert a.shape[1] == b.shape[1]
if backend == "cutlass": if backend in ("cutlass", "cudnn"):
block_scale_a = block_scale_a.view(torch.uint8) block_scale_a = block_scale_a.view(torch.uint8)
block_scale_b = block_scale_b.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment