Unverified Commit 8ef50d9a authored by Roberto L. Castro's avatar Roberto L. Castro Committed by GitHub
Browse files

[Kernel][Performance] Enable smaller Scaling Factor tiling for NVFP4 small-batch decoding (#30885)


Signed-off-by: default avatarLopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: default avatarRoberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Signed-off-by: default avatarLopezCastroRoberto <rocastro@redhat.com>
parent 2a60ac91
...@@ -951,7 +951,7 @@ steps: ...@@ -951,7 +951,7 @@ steps:
# Whisper needs spawn method to avoid deadlock # Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
- label: Blackwell Test # 21 min - label: Blackwell Test # 23 min
timeout_in_minutes: 30 timeout_in_minutes: 30
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
gpu: b200 gpu: b200
...@@ -991,6 +991,8 @@ steps: ...@@ -991,6 +991,8 @@ steps:
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/kernels/moe/test_cutedsl_moe.py - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py
# e2e
- pytest -v -s tests/models/quantization/test_nvfp4.py
- label: Blackwell Fusion and Compile Tests # 30 min - label: Blackwell Fusion and Compile Tests # 30 min
timeout_in_minutes: 40 timeout_in_minutes: 40
......
...@@ -23,8 +23,26 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): ...@@ -23,8 +23,26 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
return out[0:m, 0:k] return out[0:m, 0:k]
def convert_swizzled_8x4_layout_to_linear(
    a_sf_swizzled: torch.Tensor, m, k, block_size
):
    """Rearrange scale factors stored in 8x4 tiles into a row-major matrix.

    The input is interpreted as ``m_tiles * k_tiles`` tiles, each holding
    8 rows by 4 scale factors; one tile spans ``4 * block_size`` elements
    along the k dimension.

    Args:
        a_sf_swizzled: Tiled scale factors. Must hold exactly
            ``ceil(m / 8) * ceil(k / (4 * block_size)) * 32`` elements.
        m: Number of rows kept in the result.
        k: Upper bound applied to the column slice of the result.
        block_size: Divisor relating ``k`` to the number of scale columns.

    Returns:
        The leading ``m`` rows (and at most ``k`` columns) of the
        ``(m_tiles * 8, k_tiles * 4)`` linear scale matrix.
    """
    tile_rows, tile_cols = 8, 4
    k_span = block_size * tile_cols
    num_m_tiles = (m + tile_rows - 1) // tile_rows
    num_k_tiles = (k + k_span - 1) // k_span

    tiled = a_sf_swizzled.reshape(
        1, num_m_tiles, num_k_tiles, tile_rows, tile_cols
    )
    # Bring the in-tile row axis ahead of the k-tile axis so that the
    # (k-tile, in-tile column) pair collapses into one row of scales.
    row_major = tiled.permute(0, 1, 3, 2, 4)
    linear = row_major.reshape(
        num_m_tiles * tile_rows, num_k_tiles * k_span // block_size
    )
    # NOTE(review): the column bound is `k`, not `k // block_size`, which
    # mirrors convert_swizzled_to_linear; padded columns are therefore kept
    # whenever k >= k_tiles * 4 -- confirm this is intended.
    return linear[:m, :k]
def dequantize_nvfp4_to_dtype( def dequantize_nvfp4_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16,
is_sf_128x4_layout=True,
): ):
"""Dequantize the fp4 tensor back to high precision.""" """Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8. # Two fp4 values are packed into one uint8.
...@@ -34,7 +52,11 @@ def dequantize_nvfp4_to_dtype( ...@@ -34,7 +52,11 @@ def dequantize_nvfp4_to_dtype(
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn) tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) if is_sf_128x4_layout:
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
else:
tensor_sf = convert_swizzled_8x4_layout_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor # scale the tensor
......
...@@ -11,7 +11,9 @@ from nvfp4_utils import ( ...@@ -11,7 +11,9 @@ from nvfp4_utils import (
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
)
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
if not current_platform.has_device_capability(100): if not current_platform.has_device_capability(100):
...@@ -22,8 +24,14 @@ if not current_platform.has_device_capability(100): ...@@ -22,8 +24,14 @@ if not current_platform.has_device_capability(100):
DTYPES = [torch.float16, torch.bfloat16] DTYPES = [torch.float16, torch.bfloat16]
# m, n, k # m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] SHAPES = [
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] (128, 128, 64),
(128, 128, 128),
(256, 128, 64),
(128, 256, 128),
(1, 128, 128),
]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96), (2, 128, 64), (3, 128, 96)]
SHAPES.extend(PAD_SHAPES) SHAPES.extend(PAD_SHAPES)
SEEDS = [42] SEEDS = [42]
...@@ -42,12 +50,19 @@ def get_ref_results( ...@@ -42,12 +50,19 @@ def get_ref_results(
dtype, dtype,
block_size, block_size,
device, device,
is_sf_128x4_layout,
): ):
_, m_k = a_fp4.shape _, m_k = a_fp4.shape
_, n_k = b_fp4.shape _, n_k = b_fp4.shape
assert m_k == n_k assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype( a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size,
is_sf_128x4_layout=is_sf_128x4_layout,
) )
b_in_dtype = dequantize_nvfp4_to_dtype( b_in_dtype = dequantize_nvfp4_to_dtype(
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
...@@ -70,7 +85,7 @@ def test_flashinfer_nvfp4_gemm( ...@@ -70,7 +85,7 @@ def test_flashinfer_nvfp4_gemm(
backend: str, backend: str,
autotune: bool, autotune: bool,
) -> None: ) -> None:
if backend == "trtllm" and dtype == torch.float16: if "trtllm" in backend and dtype == torch.float16:
pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
set_random_seed(seed) set_random_seed(seed)
...@@ -87,11 +102,14 @@ def test_flashinfer_nvfp4_gemm( ...@@ -87,11 +102,14 @@ def test_flashinfer_nvfp4_gemm(
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32) ).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale) alpha = 1.0 / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights # ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales. # from checkpoints are in linear scales.
# So instead of needing to swizzle for cutlass as in modelopt.py, # So instead of needing to swizzle for cutlass as in modelopt.py,
# we need to unswizzle for trtllm here. # we need to unswizzle for trtllm here.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend)
is_sf_128x4_layout = not (backend == "trtllm" and m <= 32)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
# get_ref_results unswizzles the scales internally. # get_ref_results unswizzles the scales internally.
...@@ -107,14 +125,14 @@ def test_flashinfer_nvfp4_gemm( ...@@ -107,14 +125,14 @@ def test_flashinfer_nvfp4_gemm(
dtype, dtype,
block_size, block_size,
device, device,
is_sf_128x4_layout,
) )
import flashinfer import flashinfer
if backend == "trtllm": if "trtllm" in backend:
epilogue_tile_m = 128 epilogue_tile_m = 128
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m) b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m)
b_scale_interleaved = convert_swizzled_to_linear( b_scale_interleaved = convert_swizzled_to_linear(
b_scale_interleaved, n, k, block_size b_scale_interleaved, n, k, block_size
) )
......
...@@ -14,6 +14,8 @@ from transformers import AutoTokenizer ...@@ -14,6 +14,8 @@ from transformers import AutoTokenizer
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
os.environ["TOKENIZERS_PARALLELISM"] = "true" os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024 MAX_MODEL_LEN = 1024
...@@ -83,3 +85,27 @@ def test_models(example_prompts, model_name) -> None: ...@@ -83,3 +85,27 @@ def test_models(example_prompts, model_name) -> None:
assert expected_str == generated_str, ( assert expected_str == generated_str, (
f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}" f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}"
) )
EAGER = [True, False]
@pytest.mark.skipif(
    not current_platform.has_device_capability(100),
    reason="modelopt_fp4 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"])
@pytest.mark.parametrize("eager", EAGER)
@pytest.mark.parametrize(
    "backend",
    [
        "flashinfer-cudnn",
        "flashinfer-trtllm",  # the small seq_len ensures trtllm_8x4_layout backend is used
        "flashinfer-cutlass",
    ],
)
def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch):
    """Greedy-decode a short counting prompt under each NVFP4 GEMM backend.

    The backend is selected through the VLLM_NVFP4_GEMM_BACKEND environment
    variable, and the test is run with and without eager mode.
    """
    monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend)

    prompts = ["1 2 3 4 5"]
    with vllm_runner(model, enforce_eager=eager) as llm:
        generations = llm.generate_greedy(prompts, max_tokens=2)

    # Second field of the first result is compared against the full text
    # (prompt plus the expected continuation).
    assert generations[0][1] == "1 2 3 4 5 6"
...@@ -9,6 +9,9 @@ import vllm.envs as envs ...@@ -9,6 +9,9 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType from vllm.scalar_type import ScalarType
from vllm.utils.flashinfer import (
flashinfer_quant_nvfp4_8x4_sf_layout,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1563,7 +1566,9 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: ...@@ -1563,7 +1566,9 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
# fp4 # fp4
def scaled_fp4_quant( def scaled_fp4_quant(
input: torch.Tensor, input_global_scale: torch.Tensor input: torch.Tensor,
input_global_scale: torch.Tensor,
backend: str = "none",
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Quantize input tensor to FP4 and return quantized tensor and scale. Quantize input tensor to FP4 and return quantized tensor and scale.
...@@ -1577,6 +1582,7 @@ def scaled_fp4_quant( ...@@ -1577,6 +1582,7 @@ def scaled_fp4_quant(
Args: Args:
input: The input tensor to be quantized to FP4 input: The input tensor to be quantized to FP4
input_global_scale: A scalar scaling factor for the entire tensor. input_global_scale: A scalar scaling factor for the entire tensor.
backend: Name of the GEMM backend. When it contains "trtllm" and m <= 32,
the 8x4 layout is used for the scaling factors instead of the 128x4 layout.
Returns: Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
...@@ -1596,23 +1602,31 @@ def scaled_fp4_quant( ...@@ -1596,23 +1602,31 @@ def scaled_fp4_quant(
f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
) )
# Two fp4 values will be packed into an uint8. use_8x4_sf_layout = True if "trtllm" in backend and m <= 32 else False # noqa: SIM210
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
# We use the rounded values to store the swizzled values. Due to the if use_8x4_sf_layout:
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales. output, output_scale = flashinfer_quant_nvfp4_8x4_sf_layout(
# So, we first pad the scales to multiples of 128 and 4. Then, the scales input, input_global_scale
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More: )
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x else:
round_up = lambda x, y: (x + y - 1) // y * y # Two fp4 values will be packed into an uint8.
rounded_m = round_up(m, 128) output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
scale_n = n // block_size
rounded_n = round_up(scale_n, 4) # We use the rounded values to store the swizzled values. Due to the
output_scale = torch.empty( # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32 # So, we first pad the scales to multiples of 128 and 4. Then, the scales
) # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(m, 128)
scale_n = n // block_size
rounded_n = round_up(scale_n, 4)
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale)
torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale)
output_scale = output_scale.view(torch.float8_e4m3fn) output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale return output, output_scale
......
...@@ -1444,7 +1444,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1444,7 +1444,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices( "VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
"VLLM_NVFP4_GEMM_BACKEND", "VLLM_NVFP4_GEMM_BACKEND",
None, None,
["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "cutlass"], [
"flashinfer-cudnn",
"flashinfer-trtllm",
"flashinfer-cutlass",
"cutlass",
],
), ),
# Controls garbage collection during CUDA graph capture. # Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time. # If set to 0 (default), enables GC freezing to speed up capture time.
......
...@@ -23,7 +23,10 @@ from vllm.model_executor.parameter import ( ...@@ -23,7 +23,10 @@ from vllm.model_executor.parameter import (
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -187,7 +190,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -187,7 +190,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]] output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) x_fp4, x_blockscale = scaled_fp4_quant(
x, layer.input_global_scale, self.backend
)
mm_args = ( mm_args = (
x_fp4, x_fp4,
......
...@@ -1291,7 +1291,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1291,7 +1291,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
output_shape = [x.shape[0], layer.weight.shape[0]] output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend)
# validate dtypes of quantized input, input block scale, # validate dtypes of quantized input, input block scale,
# weight and weight_blockscale # weight and weight_blockscale
......
...@@ -406,12 +406,21 @@ if has_flashinfer(): ...@@ -406,12 +406,21 @@ if has_flashinfer():
B_scale: torch.Tensor, B_scale: torch.Tensor,
g_scale: torch.Tensor, g_scale: torch.Tensor,
dtype: torch.dtype, dtype: torch.dtype,
use_8x4_sf_layout: bool,
backend: str, backend: str,
) -> torch.Tensor: ) -> torch.Tensor:
from flashinfer import mm_fp4 as flashinfer_mm_fp4_ from flashinfer import mm_fp4 as flashinfer_mm_fp4_
return flashinfer_mm_fp4_( return flashinfer_mm_fp4_(
A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend A,
B,
A_scale,
B_scale,
g_scale,
dtype,
block_size=16,
use_8x4_sf_layout=use_8x4_sf_layout,
backend=backend,
) )
@torch.library.register_fake( @torch.library.register_fake(
...@@ -424,6 +433,7 @@ if has_flashinfer(): ...@@ -424,6 +433,7 @@ if has_flashinfer():
B_scale: torch.Tensor, B_scale: torch.Tensor,
g_scale: torch.Tensor, g_scale: torch.Tensor,
dtype: torch.dtype, dtype: torch.dtype,
use_8x4_sf_layout: bool,
backend: str, backend: str,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)
...@@ -460,6 +470,39 @@ if has_flashinfer(): ...@@ -460,6 +470,39 @@ if has_flashinfer():
A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
) )
@torch.library.custom_op(
    "vllm::flashinfer_nvfp4_quantize",
    mutates_args=[],
    device_types="cuda",
)
def flashinfer_nvfp4_quantize(
    a: torch.Tensor, a_global_sf: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``a`` with FlashInfer's nvfp4_quantize using the 8x4 SF layout.

    Args:
        a: Tensor to quantize.
        a_global_sf: Global scale factor forwarded to FlashInfer.

    Returns:
        Whatever FlashInfer's ``nvfp4_quantize`` returns for these arguments
        (annotated as a pair of tensors).
    """
    # Imported lazily so that FlashInfer is only needed when the op runs.
    from flashinfer import SfLayout, nvfp4_quantize

    sf_layout = SfLayout.layout_8x4
    return nvfp4_quantize(a, a_global_sf, sfLayout=sf_layout, do_shuffle=False)
@torch.library.register_fake(
    "vllm::flashinfer_nvfp4_quantize",
)
def flashinfer_nvfp4_quantize_fake(
    a: torch.Tensor, a_global_sf: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation of ``vllm::flashinfer_nvfp4_quantize``.

    Allocates uninitialized uint8 outputs with the shapes derived from
    ``a.shape``: ``(m, n // 2)`` for the first tensor, and
    ``(m rounded up to 8, n // 16 rounded up to 4)`` for the second.
    """
    rows, cols = a.shape

    # Row count padded to a multiple of 8, scale columns to a multiple of 4.
    padded_rows = (rows + 8 - 1) // 8 * 8
    scale_cols = cols // 16
    padded_scale_cols = (scale_cols + 4 - 1) // 4 * 4

    quantized = torch.empty((rows, cols // 2), dtype=torch.uint8, device=a.device)
    scales = torch.empty(
        (padded_rows, padded_scale_cols), dtype=torch.uint8, device=a.device
    )
    return quantized, scales
def flashinfer_scaled_fp4_mm( def flashinfer_scaled_fp4_mm(
a: torch.Tensor, a: torch.Tensor,
...@@ -479,6 +522,8 @@ def flashinfer_scaled_fp4_mm( ...@@ -479,6 +522,8 @@ def flashinfer_scaled_fp4_mm(
block_scale_a = block_scale_a.view(torch.uint8) block_scale_a = block_scale_a.view(torch.uint8)
block_scale_b = block_scale_b.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8)
use_8x4_sf_layout = True if backend == "trtllm" and a.shape[0] <= 32 else False # noqa: SIM210
return flashinfer_mm_fp4( return flashinfer_mm_fp4(
a, a,
b.t(), b.t(),
...@@ -486,6 +531,7 @@ def flashinfer_scaled_fp4_mm( ...@@ -486,6 +531,7 @@ def flashinfer_scaled_fp4_mm(
block_scale_b.t(), block_scale_b.t(),
alpha, alpha,
out_dtype, out_dtype,
use_8x4_sf_layout=use_8x4_sf_layout,
backend=backend, backend=backend,
) )
...@@ -520,6 +566,12 @@ def flashinfer_scaled_fp8_mm( ...@@ -520,6 +566,12 @@ def flashinfer_scaled_fp8_mm(
return output return output
def flashinfer_quant_nvfp4_8x4_sf_layout(
    a: torch.Tensor, a_global_sf: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``a`` to NVFP4 through ``flashinfer_nvfp4_quantize``.

    Args:
        a: Tensor to quantize.
        a_global_sf: Global scale factor passed through unchanged.

    Returns:
        The result of ``flashinfer_nvfp4_quantize(a, a_global_sf)``.
    """
    quantized_pair = flashinfer_nvfp4_quantize(a, a_global_sf)
    return quantized_pair
flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper( flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper(
"flashinfer.gemm", "fp8_blockscale_gemm_sm90" "flashinfer.gemm", "fp8_blockscale_gemm_sm90"
) )
...@@ -596,6 +648,7 @@ __all__ = [ ...@@ -596,6 +648,7 @@ __all__ = [
"use_trtllm_attention", "use_trtllm_attention",
"flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm", "flashinfer_scaled_fp8_mm",
"flashinfer_quant_nvfp4_8x4_sf_layout",
"flashinfer_fp8_blockscale_gemm", "flashinfer_fp8_blockscale_gemm",
"should_use_flashinfer_for_blockscale_fp8_gemm", "should_use_flashinfer_for_blockscale_fp8_gemm",
"is_flashinfer_fp8_blockscale_gemm_supported", "is_flashinfer_fp8_blockscale_gemm_supported",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment