Unverified commit 31d5c179, authored by Luka Govedič, committed by GitHub
Browse files

[Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf (#19830)


Signed-off-by: Luka Govedic <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 35514b68
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable
import torch
from vllm import _custom_ops as ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
# TODO(luka): use standalone_compile utility
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
def inner(*args):
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
return fn(*args)
return inner
# Large recompile limit; presumably so the many shapes benchmarked below
# (compiled with dynamic=False) do not exhaust Dynamo's recompilation
# budget -- TODO confirm.
torch._dynamo.config.recompile_limit = 8888
# NOTE(review): custom_ops=["none"] presumably disables custom-op
# implementations so the compiled QuantFP8 uses its native (torch) forward;
# verify against CustomOp dispatch.
compilation_config = CompilationConfig(custom_ops=["none"])
# QuantFP8 is constructed and compiled while this config is the current one.
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
    torch_per_token_quant_fp8 = torch.compile(
        QuantFP8(False, GroupShape.PER_TOKEN),
        fullgraph=True,
        dynamic=False,  # recompile for different shapes
    )
# First dim is explicitly dynamic to simulate vLLM usage
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
def cuda_per_token_quant_fp8(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Dynamic per-token FP8 quantization through the CUDA custom op.

    ``use_per_token_if_dynamic=True`` is passed explicitly so this baseline
    uses the same per-token scheme as the torch-compiled
    ``QuantFP8(False, GroupShape.PER_TOKEN)`` it is compared against.
    NOTE(review): the flag is assumed to default to per-tensor scales in
    ``ops.scaled_fp8_quant``; confirm against its signature.

    :param input: tensor to quantize
    :return: the pair returned by ``ops.scaled_fp8_quant``
        (quantized tensor, scales)
    """
    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
def calculate_diff(batch_size: int, seq_len: int):
    """
    Compare the torch-compiled and CUDA quantization paths on one input.

    Builds a random float16 input of shape (batch_size * seq_len, 4096) on
    the CUDA device, runs both implementations, and prints whether the
    quantized outputs and the scales agree within rtol=1e-3 / atol=1e-5.

    :param batch_size: batch size; multiplied by seq_len for the row count
    :param seq_len: sequence length; multiplied by batch_size
    """
    num_rows = batch_size * seq_len
    x = torch.rand(
        (num_rows, 4096), dtype=torch.float16, device=torch.device("cuda")
    )

    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)

    tolerances = {"rtol": 1e-3, "atol": 1e-5}
    # Outputs are compared in float32; scales are only compared when the
    # quantized outputs already match (short-circuit).
    implementations_match = torch.allclose(
        cuda_out.to(torch.float32), torch_out.to(torch.float32), **tolerances
    ) and torch.allclose(cuda_scale, torch_scale, **tolerances)

    if not implementations_match:
        print("❌ Implementations differ")
        return
    print("✅ All implementations match")
# Benchmark sweep: every (batch_size, seq_len) pairing, with batch size
# varying slowest.
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
configs = [
    (batch_size, seq_len)
    for batch_size in batch_size_range
    for seq_len in seq_len_range
]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "cuda"],
        line_names=["Torch", "CUDA"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    """
    Time one quantization implementation for one (batch_size, seq_len) point.

    :param batch_size: batch size; multiplied by seq_len for the row count
    :param seq_len: sequence length; multiplied by batch_size
    :param provider: "torch" (compiled QuantFP8) or "cuda" (custom op)
    :return: (median, 20th percentile, 80th percentile) latency in
        microseconds, i.e. in (value, min, max) order
    :raises ValueError: if ``provider`` is not a known implementation
    """
    dtype = torch.float16
    device = torch.device("cuda")
    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    # Results come back in this order: median, lower bound, upper bound.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "cuda":
        fn = lambda: cuda_per_token_quant_fp8(x.clone())
    else:
        # Previously an unknown provider surfaced as an unbound `fn`.
        raise ValueError(f"Unknown provider: {provider!r}")

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    # Convert ms -> us. The metric is a time, so the low quantile is the
    # minimum and the high quantile the maximum; the (max, min) ordering
    # seen in throughput benchmarks does not apply here.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
if __name__ == "__main__":
    # Compare the two implementations on a single shape before timing them.
    calculate_diff(batch_size=4, seq_len=4096)
    # Run the perf report over all benchmark configs and print the results.
    benchmark_quantization.run(print_data=True)
...@@ -44,7 +44,9 @@ class TestModel(torch.nn.Module): ...@@ -44,7 +44,9 @@ class TestModel(torch.nn.Module):
] ]
self.fp8_linear = Fp8LinearOp( self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_enabled, cutlass_fp8_supported=cutlass_fp8_enabled,
use_per_token_if_dynamic=True) act_quant_static=static,
act_quant_group_shape=group_shape,
)
def forward(self, x): def forward(self, x):
resid = torch.sqrt(x) resid = torch.sqrt(x)
...@@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, ...@@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) level=CompilationLevel.PIECEWISE,
vllm_config.compilation_config.pass_config = \ custom_ops=["+rms_norm", "+quant_fp8"],
PassConfig(enable_fusion=True, enable_noop=True) pass_config=PassConfig(enable_fusion=True, enable_noop=True),
))
with vllm.config.set_current_vllm_config(vllm_config): with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
......
...@@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, ...@@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
# DYNAMO_ONCE does not properly propagate shapes. # DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS, level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend_unfused", backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"],
) )
vllm_config = VllmConfig(compilation_config=compile_config) vllm_config = VllmConfig(compilation_config=compile_config)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
...@@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, ...@@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
# DYNAMO_ONCE does not properly propagate shapes. # DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS, level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend", backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"],
) )
vllm_config = VllmConfig(compilation_config=compile_config) vllm_config = VllmConfig(compilation_config=compile_config)
......
...@@ -4,33 +4,56 @@ import pytest ...@@ -4,33 +4,56 @@ import pytest
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
from vllm.platforms import current_platform
from .backend import TestBackend from .backend import TestBackend
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
def __init__(self, *args, **kwargs): def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
**kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32)
self.w = (torch.rand(
hidden_size,
hidden_size).to(dtype=current_platform.fp8_dtype()).t())
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_enabled,
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
x2 = scaled_fp8_quant(y, self.scale) x2 = self.fp8_linear.apply(y,
self.w,
self.wscale,
input_scale=self.wscale)
return x2 return x2
@pytest.mark.parametrize("num_tokens", [256]) @pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm") reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
cutlass_fp8_enabled):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
...@@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): ...@@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
pass_config=PassConfig(enable_fusion=True, enable_noop=True)) pass_config=PassConfig(enable_fusion=True, enable_noop=True))
fusion_pass = ActivationQuantFusionPass(config) fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(fusion_pass) backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
model = TestModel() model = TestModel(hidden_size, cutlass_fp8_enabled)
# First dimension dynamic # First dimension dynamic
x = torch.rand(num_tokens, hidden_size) x = torch.rand(num_tokens, hidden_size * 2)
torch._dynamo.mark_dynamic(x, 0) torch._dynamo.mark_dynamic(x, 0)
result = model(x) result = model(x)
......
...@@ -9,6 +9,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, ...@@ -9,6 +9,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
import torch import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -289,7 +291,7 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -289,7 +291,7 @@ class AttentionImpl(ABC, Generic[T]):
raise NotImplementedError raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]): group_shape: GroupShape):
""" """
Does this attention implementation support fused output quantization. Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization This is used by the AttnFusionPass to only fuse output quantization
...@@ -298,7 +300,7 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -298,7 +300,7 @@ class AttentionImpl(ABC, Generic[T]):
TODO(luka) merge parameters into QuantDescriptor TODO(luka) merge parameters into QuantDescriptor
:param dtype: quantized dtype :param dtype: quantized dtype
:param static: static or dynamic quantization :param static: static or dynamic quantization
:param group_shape: quant group shape. (-1, -1) for per-tensor. :param group_shape: quant group shape.
:return: is fusion supported for this type of quantization :return: is fusion supported for this type of quantization
""" """
return False return False
......
...@@ -19,6 +19,8 @@ from vllm.attention.ops.paged_attn import (PagedAttention, ...@@ -19,6 +19,8 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention from vllm.platforms.rocm import use_rocm_custom_paged_attention
...@@ -598,10 +600,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -598,10 +600,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim)) head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]): group_shape: GroupShape):
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype( return dtype == current_platform.fp8_dtype(
) and static and group_shape == (-1, -1) # per-tensor ) and static and group_shape == GroupShape.PER_TENSOR
# Only supported in the Triton backend # Only supported in the Triton backend
return False return False
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, ClassVar, NamedTuple, Optional from typing import Callable, NamedTuple, Optional
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
...@@ -11,6 +11,8 @@ from torch._ops import OpOverload ...@@ -11,6 +11,8 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe from .fx_utils import find_getitem_maybe
...@@ -33,27 +35,6 @@ RMS_OP = torch.ops._C.rms_norm.default ...@@ -33,27 +35,6 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int
col: int
class GroupShape(_GroupShape):
"""
This class describes the quantization group shape.
It includes static members for common shapes (per-tensor, per-token).
"""
# Aliases for common quantization group shapes
PER_TENSOR: ClassVar['GroupShape']
PER_TOKEN: ClassVar['GroupShape']
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
class QuantKey(NamedTuple): class QuantKey(NamedTuple):
""" """
Named tuple for identifying the type of quantization. Named tuple for identifying the type of quantization.
......
...@@ -111,6 +111,8 @@ def _fp8_quantize( ...@@ -111,6 +111,8 @@ def _fp8_quantize(
is provided, the output will be blocked. is provided, the output will be blocked.
""" """
if block_shape is None: if block_shape is None:
# TODO(luka): use QuantFP8 custom op
# https://github.com/vllm-project/vllm/issues/20711
A, A_scale = ops.scaled_fp8_quant( A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_act_token) A, A_scale, use_per_token_if_dynamic=per_act_token)
else: else:
......
...@@ -15,6 +15,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -15,6 +15,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear) QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, sparse_cutlass_supported) convert_to_channelwise, sparse_cutlass_supported)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
...@@ -24,6 +27,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -24,6 +27,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
__all__ = ["CompressedTensors24"] __all__ = ["CompressedTensors24"]
from vllm.platforms import current_platform
class CompressedTensors24(CompressedTensorsScheme): class CompressedTensors24(CompressedTensorsScheme):
...@@ -45,6 +50,12 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -45,6 +50,12 @@ class CompressedTensors24(CompressedTensorsScheme):
and self.model_compressor.sparsity_config.format and self.model_compressor.sparsity_config.format
== CompressionFormat.sparse_24_bitmask.value) == CompressionFormat.sparse_24_bitmask.value)
if quantized and input_quant is not None and \
self._get_quant_dtype() == current_platform.fp8_dtype():
static = not input_quant.dynamic
g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.quant_fp8 = QuantFP8(static, g_shape)
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# Only cutlass 3.x kernels are implemented so far # Only cutlass 3.x kernels are implemented so far
...@@ -232,9 +243,7 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -232,9 +243,7 @@ class CompressedTensors24(CompressedTensorsScheme):
:return: The output tensor of the layer :return: The output tensor of the layer
""" """
if self.quantized: if self.quantized:
scale = None scale = getattr(layer, 'input_scale', None)
if hasattr(layer, "input_scale"):
scale = layer.input_scale
if self.weights_dtype == torch.int8: if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale) ops_output = ops.scaled_int8_quant(x, scale=scale)
...@@ -242,11 +251,7 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -242,11 +251,7 @@ class CompressedTensors24(CompressedTensorsScheme):
input_scale = ops_output[1] input_scale = ops_output[1]
else: else:
assert self.weights_dtype == torch.float8_e4m3fn assert self.weights_dtype == torch.float8_e4m3fn
if scale is not None: q_input, input_scale = self.quant_fp8(x, scale=scale)
q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
else:
q_input, input_scale = ops.scaled_fp8_quant(
x, use_per_token_if_dynamic=True)
else: else:
# Not quantized, nothing to do with the input_scales, use as is # Not quantized, nothing to do with the input_scales, use as is
...@@ -269,7 +274,10 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -269,7 +274,10 @@ class CompressedTensors24(CompressedTensorsScheme):
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
if not self.quantized: if not self.quantized:
return params_dtype return params_dtype
return self._get_quant_dtype()
def _get_quant_dtype(self) -> torch.dtype:
assert self.quantized
assert self.weight_quant is not None assert self.weight_quant is not None
assert self.input_quant is not None assert self.input_quant is not None
......
...@@ -9,6 +9,8 @@ from torch.nn import Parameter ...@@ -9,6 +9,8 @@ from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale) requantize_with_max_scale)
...@@ -26,7 +28,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -26,7 +28,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.strategy = strategy self.strategy = strategy
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
......
...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped) GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz) Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
...@@ -37,7 +37,6 @@ class FBGEMMFp8Config(QuantizationConfig): ...@@ -37,7 +37,6 @@ class FBGEMMFp8Config(QuantizationConfig):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin # For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization # kernel for fast weight-only FP8 quantization
self.use_marlin = not current_platform.has_device_capability(89) self.use_marlin = not current_platform.has_device_capability(89)
self.fp8_linear = Fp8LinearOp()
@classmethod @classmethod
def get_name(cls) -> QuantizationMethods: def get_name(cls) -> QuantizationMethods:
...@@ -76,7 +75,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -76,7 +75,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config): def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
def create_weights( def create_weights(
......
...@@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( ...@@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin) prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped) GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
cutlass_fp8_supported, maybe_create_device_identity, cutlass_fp8_supported, maybe_create_device_identity,
...@@ -202,9 +202,17 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -202,9 +202,17 @@ class Fp8LinearMethod(LinearMethodBase):
and current_platform.is_fp8_fnuz()) and current_platform.is_fp8_fnuz())
self.block_quant = self.quant_config.weight_block_size is not None self.block_quant = self.quant_config.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
self.fp8_linear = Fp8LinearOp( self.fp8_linear = Fp8LinearOp(
# Default to using per_token quantization if cutlass is supported act_quant_static=self.act_q_static,
use_per_token_if_dynamic=cutlass_fp8_supported()) act_quant_group_shape=self.act_q_group_shape,
cutlass_fp8_supported=cutlass_fp8_supported())
def create_weights( def create_weights(
self, self,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
# FP8 dtype and value range for the current platform.
# Using the default value (240.0) from PyTorch causes accuracy issues on
# dynamic quantization models, so 224.0 is used for fnuz on ROCm.
_FP8_DTYPE = current_platform.fp8_dtype()
_FP8_FINFO = torch.finfo(_FP8_DTYPE)
_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
# Lower bound for dynamically computed scales (used as the clamp minimum in
# QuantFP8.forward_native).
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
@CustomOp.register("quant_fp8")
class QuantFP8(CustomOp):
    """
    FP8 quantization of an input tensor, exposed as a CustomOp.

    Scales are either per-tensor or per-token. Static quantization takes a
    caller-provided scale (per-tensor only); dynamic quantization computes
    the scale from the input.
    """

    def __init__(self,
                 static: bool,
                 group_shape: GroupShape,
                 num_token_padding: Optional[int] = None):
        """
        :param static: static or dynamic quantization
        :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR)
        :param num_token_padding: Pad the token dimension of output to this size
        """
        super().__init__()
        assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
        assert not static or group_shape == GroupShape.PER_TENSOR, \
            "Only per-tensor scales supported for static quantization."

        self.static = static
        self.group_shape = group_shape
        self.num_token_padding = num_token_padding
        self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN

    def _check_args(self, scale: Optional[torch.Tensor],
                    scale_ub: Optional[torch.Tensor]) -> None:
        """Argument checks shared by forward_cuda and forward_native."""
        # A scale must be passed exactly when quantization is static.
        assert (scale is not None) == self.static
        # An upper bound is only accepted for dynamic per-token quantization
        # and must hold a single element.
        assert scale_ub is None or (not self.static and self.group_shape
                                    == GroupShape.PER_TOKEN
                                    and scale_ub.numel() == 1)

    def forward_cuda(
        self,
        x: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        scale_ub: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize ``x`` by delegating to ``ops.scaled_fp8_quant``.

        :param x: tensor to quantize
        :param scale: scale for static quantization; None if dynamic
        :param scale_ub: optional single-element upper bound
            (dynamic per-token only)
        :return: the pair returned by ``ops.scaled_fp8_quant``
        """
        self._check_args(scale, scale_ub)
        return ops.scaled_fp8_quant(
            x,
            scale,
            num_token_padding=self.num_token_padding,
            scale_ub=scale_ub,
            use_per_token_if_dynamic=self.use_per_token_if_dynamic)

    def forward_native(
        self,
        x: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        scale_ub: Optional[torch.Tensor] = None,
    ):
        """
        Quantize ``x`` using torch operations only.

        :param x: tensor to quantize
        :param scale: scale for static quantization; None if dynamic
        :param scale_ub: optional single-element upper bound
            (dynamic per-token only)
        :return: (quantized tensor, scale that was applied)
        """
        self._check_args(scale, scale_ub)

        if scale is None:
            # Dynamic quantization: derive the scale from the absolute
            # maximum, per row of the last dim or over the whole tensor.
            if self.group_shape == GroupShape.PER_TOKEN:
                x_max = x.abs().max(dim=-1).values
                x_max = x_max.unsqueeze(-1).to(torch.float32)
                if scale_ub is not None:
                    x_max = x_max.clamp(max=scale_ub)
            else:
                x_max = x.abs().max().unsqueeze(-1).to(torch.float32)

            scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)

        # Even for dynamic per-token scales,
        # reciprocal performs slightly better than division
        scaled = x.to(torch.float32) * scale.reciprocal()
        out = scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)

        # This currently generates an extra Triton kernel in compilation.
        # Fortunately, we don't use padding if compiling.
        # TODO(luka): benchmark torch._scaled_mm to hopefully remove padding
        #  in general.
        if self.num_token_padding is not None:
            padding = max(self.num_token_padding - out.size(0), 0)
            out = F.pad(out, (0, 0, 0, padding), "constant", 0.0)

        return out, scale
...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( ...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported, apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped) GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, requantize_with_max_scale) Fp8LinearOp, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter, from vllm.model_executor.parameter import (ModelWeightParameter,
...@@ -102,7 +102,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -102,7 +102,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config): def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp() self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
def create_weights( def create_weights(
self, self,
......
...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, ...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
Fp8KVCacheMethod, Fp8KVCacheMethod,
Fp8LinearMethod) Fp8LinearMethod)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped) GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp) Fp8LinearOp)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -95,8 +95,10 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): ...@@ -95,8 +95,10 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
super().__init__(quant_config=quant_config) super().__init__(quant_config=quant_config)
# Force weight quantization # Force weight quantization
self.quant_config.is_checkpoint_fp8_serialized = False self.quant_config.is_checkpoint_fp8_serialized = False
self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False, self.fp8_linear = Fp8LinearOp(
use_per_token_if_dynamic=True) act_quant_static=False,
cutlass_fp8_supported=False,
act_quant_group_shape=GroupShape.PER_TOKEN)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, layer.weight = torch.nn.Parameter(layer.weight.data,
......
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
from torch.nn import Parameter from torch.nn import Parameter
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
...@@ -28,10 +30,14 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -28,10 +30,14 @@ class QuarkW8A8Fp8(QuarkScheme):
self.is_static_input_scheme = not cast( self.is_static_input_scheme = not cast(
bool, input_config.get("is_dynamic")) bool, input_config.get("is_dynamic"))
self.input_qscheme = cast(str, input_config.get("qscheme")) self.input_qscheme = cast(str, input_config.get("qscheme"))
self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
per_token = (not self.is_static_input_scheme
and self.input_qscheme == "per_channel") and self.input_qscheme == "per_channel")
self.act_quant_group_shape = GroupShape.PER_TOKEN \
if per_token else GroupShape.PER_TENSOR
self.fp8_linear = Fp8LinearOp( self.fp8_linear = Fp8LinearOp(
use_per_token_if_dynamic=self.use_per_token_if_dynamic) act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_quant_group_shape)
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
@classmethod @classmethod
...@@ -44,7 +50,7 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -44,7 +50,7 @@ class QuarkW8A8Fp8(QuarkScheme):
# tensor scales (thus N scales being passed to the kernel), # tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor # requantize so we can always run per tensor
if self.weight_qscheme == "per_tensor": if self.weight_qscheme == "per_tensor":
if current_platform.is_rocm(): if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None) input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight, weight=layer.weight,
...@@ -82,7 +88,7 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -82,7 +88,7 @@ class QuarkW8A8Fp8(QuarkScheme):
requires_grad=False) requires_grad=False)
else: else:
weight_scale = layer.weight_scale.data weight_scale = layer.weight_scale.data
if self.use_per_token_if_dynamic: if self.act_quant_group_shape == GroupShape.PER_TOKEN:
weight_scale = weight_scale.view(-1, 1) weight_scale = weight_scale.view(-1, 1)
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter # required by torch.compile to be torch.nn.Parameter
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from collections.abc import Mapping from collections.abc import Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Optional from typing import ClassVar, NamedTuple, Optional
import numpy import numpy
import torch import torch
...@@ -12,13 +12,30 @@ from vllm.model_executor.layers.quantization.qqq import ( ...@@ -12,13 +12,30 @@ from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_SUPPORTED_NUM_BITS) MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
    """Plain ``(row, col)`` pair that ``GroupShape`` is built on."""
    # Group extent along the second-to-last dimension; -1 means full extent.
    row: int
    # Group extent along the last dimension; -1 means full extent.
    col: int


class GroupShape(_GroupShape):
    """
    This class describes the quantization group shape.
    It includes static members for common shapes (per-tensor, per-token).
    """

    # Aliases for common quantization group shapes
    PER_TENSOR: ClassVar['GroupShape']
    PER_TOKEN: ClassVar['GroupShape']


# The aliases are instances of GroupShape itself, so they can only be
# attached once the class statement has finished executing.
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
# Normalize the group_shape to the full extent for any dims that are -1 # Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int, def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
int]):
# -1 means full extent # -1 means full extent
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
group_shape[1] if group_shape[1] > 0 else x.shape[-1]) group_shape[1] if group_shape[1] > 0 else x.shape[-1])
...@@ -58,7 +75,7 @@ def group_broadcast(t, shape): ...@@ -58,7 +75,7 @@ def group_broadcast(t, shape):
# (i.e. per-token-per-group) # (i.e. per-token-per-group)
def scaled_quantize( def scaled_quantize(
x: torch.Tensor, x: torch.Tensor,
group_shape: tuple[int, int], group_shape: GroupShape,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
group_shape = _normalize_quant_group_shape(x, group_shape) group_shape = _normalize_quant_group_shape(x, group_shape)
...@@ -99,7 +116,7 @@ def scaled_quantize( ...@@ -99,7 +116,7 @@ def scaled_quantize(
def scaled_dequantize( def scaled_dequantize(
x_q: torch.Tensor, x_q: torch.Tensor,
x_s: torch.Tensor, x_s: torch.Tensor,
group_shape: Optional[tuple[int, int]] = None, group_shape: Optional[GroupShape] = None,
out_dtype: torch.dtype = torch.float32, out_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if group_shape is not None: if group_shape is not None:
...@@ -332,6 +349,10 @@ def quantize_weights(w: torch.Tensor, ...@@ -332,6 +349,10 @@ def quantize_weights(w: torch.Tensor,
) )
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def gptq_quantize_weights(w: torch.Tensor, def gptq_quantize_weights(w: torch.Tensor,
quant_type: ScalarType, quant_type: ScalarType,
group_size: int, group_size: int,
......
...@@ -8,6 +8,9 @@ import torch ...@@ -8,6 +8,9 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform from vllm.platforms import current_platform
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
...@@ -271,20 +274,21 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -271,20 +274,21 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
def dispatch_w8a8_scaled_mm( def dispatch_w8a8_scaled_mm(
cutlass_fp8_supported: bool, per_tensor_weights: bool, cutlass_fp8_supported: bool, per_tensor_weights: bool,
per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool] per_tensor_activations: bool) -> Callable[..., torch.Tensor]:
) -> Callable[..., torch.Tensor]:
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if cutlass_fp8_supported: if cutlass_fp8_supported:
return cutlass_w8a8_scaled_mm return cutlass_w8a8_scaled_mm
if per_tensor_weights and per_tensor_activations: if per_tensor_weights and per_tensor_activations:
if current_platform.is_rocm(): if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only # If torch.scaled_mm supports per-channel (weights) per-token (inputs)
# so fallback to naive if per channel or per token if not per_tensor_weights and not per_tensor_activations \
if (use_per_token_if_dynamic and not per_tensor_weights and USE_ROWWISE_TORCH_SCALED_MM:
and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
return torch_per_token_w8a8_scaled_mm return torch_per_token_w8a8_scaled_mm
# Normally, torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
return torch_channelwise_w8a8_scaled_mm return torch_channelwise_w8a8_scaled_mm
...@@ -299,11 +303,11 @@ class Fp8LinearOp: ...@@ -299,11 +303,11 @@ class Fp8LinearOp:
""" """
def __init__(self, def __init__(self,
act_quant_static: bool,
cutlass_fp8_supported: bool = cutlass_fp8_supported(), cutlass_fp8_supported: bool = cutlass_fp8_supported(),
use_per_token_if_dynamic: bool = False, act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
pad_output: Optional[bool] = None): pad_output: Optional[bool] = None):
self.cutlass_fp8_supported = cutlass_fp8_supported self.cutlass_fp8_supported = cutlass_fp8_supported
self.use_per_token_if_dynamic = use_per_token_if_dynamic
# Note: we pad the input because torch._scaled_mm is more performant # Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16. # for matrices with batch dimension > 16.
...@@ -312,9 +316,16 @@ class Fp8LinearOp: ...@@ -312,9 +316,16 @@ class Fp8LinearOp:
# as it breaks with dynamic shapes. # as it breaks with dynamic shapes.
if pad_output is None: if pad_output is None:
config = get_current_vllm_config().compilation_config config = get_current_vllm_config().compilation_config
pad_output = config.level < CompilationLevel.PIECEWISE pad_output = config.level < CompilationLevel.PIECEWISE and \
self.output_padding = 17 if ( not cutlass_fp8_supported and \
pad_output and not current_platform.is_rocm()) else None not current_platform.is_rocm()
self.output_padding = 17 if pad_output else None
self.act_quant_static = act_quant_static
self.act_quant_group_shape = act_quant_group_shape
self.quant_fp8 = QuantFP8(static=act_quant_static,
group_shape=act_quant_group_shape,
num_token_padding=self.output_padding)
def apply( def apply(
self, self,
...@@ -325,8 +336,6 @@ class Fp8LinearOp: ...@@ -325,8 +336,6 @@ class Fp8LinearOp:
input_scale: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None,
input_scale_ub: Optional[torch.Tensor] = None, input_scale_ub: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
# TODO(luka) remove this parameter in favor of __init__
use_per_token_if_dynamic: Optional[bool] = None
) -> torch.Tensor: ) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant. # ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x. # If dynamic, layer.input_scale is None and x_scale computed from x.
...@@ -336,40 +345,27 @@ class Fp8LinearOp: ...@@ -336,40 +345,27 @@ class Fp8LinearOp:
input_2d = input.view(-1, input.shape[-1]) input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]] output_shape = [*input.shape[:-1], weight.shape[1]]
# TODO(luka) this is here because currently MLA only decides this
# during the forward method instead of in __init__.
if use_per_token_if_dynamic is None:
use_per_token_if_dynamic = self.use_per_token_if_dynamic
if out_dtype is None: if out_dtype is None:
out_dtype = input.dtype out_dtype = input.dtype
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A # If input not quantized
if self.cutlass_fp8_supported: # TODO(luka) remove this path if not used anymore
assert input.dtype != current_platform.fp8_dtype(
), "FP8 input to cutlass is not currently implemented"
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic)
else:
if input.dtype != current_platform.fp8_dtype(): if input.dtype != current_platform.fp8_dtype():
# Maybe apply padding to output, see comment in __init__ qinput, x_scale = self.quant_fp8(
qinput, x_scale = ops.scaled_fp8_quant(
input_2d, input_2d,
input_scale, input_scale,
num_token_padding=self.output_padding, input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic) )
else: else:
qinput, x_scale = input_2d, input_scale qinput, x_scale = input_2d, input_scale
per_tensor_weights = (weight_scale.numel() == 1) per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1)
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.cutlass_fp8_supported, per_tensor_weights, self.cutlass_fp8_supported, per_tensor_weights,
per_tensor_activations, use_per_token_if_dynamic) per_tensor_activations)
return w8a8_scaled_mm_func(qinput=qinput, return w8a8_scaled_mm_func(qinput=qinput,
weight=weight, weight=weight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment