Commit 0636f239 authored by lixh6's avatar lixh6
Browse files

feat:适配Blaslt Channelwise gemm

parent 440222e9
...@@ -19,6 +19,7 @@ from vllm.utils.torch_utils import direct_register_custom_op ...@@ -19,6 +19,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
try: try:
from lmslim import quant_ops from lmslim import quant_ops
from lmslim import quant_tools from lmslim import quant_tools
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n") print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try: try:
...@@ -1878,6 +1879,67 @@ def scaled_fp4_experts_quant( ...@@ -1878,6 +1879,67 @@ def scaled_fp4_experts_quant(
output_scales = output_scales.view(torch.float8_e4m3fn) output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales return output, output_scales
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
    output: Optional[torch.Tensor] = None,
    group_shape: Optional[tuple[int, int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The 2-D input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization. When
            given it must hold exactly one element (static per-tensor).
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value. Not allowed together
            with a caller-supplied `output`.
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case. NOTE: the per-token path below delegates to
            `per_token_quant_fp8`, which is called without an upper
            bound, so this argument is currently not applied.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.
        output: Optional preallocated FP8 tensor to receive the result.
            Its dtype must equal the platform FP8 dtype.
        group_shape: Accepted for interface compatibility; not used by
            this implementation.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor. In the dynamic per-token case the scale is
            whatever `per_token_quant_fp8` returns (presumably one row
            per input token, i.e. not padded -- TODO confirm).
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert input.ndim == 2
    num_tokens, hidden_size = input.shape

    # Platform-specific FP8 dtype (presumably torch.float8_e4m3fnuz on
    # ROCm MI300 -- confirm against current_platform).
    out_dtype: torch.dtype = current_platform.fp8_dtype()

    out_rows = num_tokens
    if num_token_padding:
        out_rows = max(num_token_padding, num_tokens)

    if output is not None:
        assert num_token_padding is None, \
            "padding not supported if output passed in"
        assert output.dtype == out_dtype

    if scale is None and use_per_token_if_dynamic:
        # Dynamic per-token quantization is delegated to lmslim.
        quantized, scale = per_token_quant_fp8(input.contiguous())
        if output is None and out_rows == num_tokens:
            # Common case: no caller buffer and no padding requested, so
            # hand back the helper's tensors without an extra copy.
            return quantized, scale
        if output is None:
            output = torch.empty((out_rows, hidden_size),
                                 device=input.device,
                                 dtype=out_dtype)
        # Honour the `output` / `num_token_padding` contract: the helper
        # allocates its own result, so copy it into the requested buffer.
        # Rows past `num_tokens` (padding) are left uninitialized.
        output[:num_tokens].copy_(quantized)
        return output, scale

    if output is None:
        output = torch.empty((out_rows, hidden_size),
                             device=input.device,
                             dtype=out_dtype)

    if scale is None:
        # Dynamic per-tensor: the kernel writes the computed scale in place.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Static per-tensor: the caller-provided scale must be a scalar.
        assert scale.numel() == 1, f"{scale.shape}"
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
    return output, scale
def silu_and_mul_scaled_fp4_experts_quant( def silu_and_mul_scaled_fp4_experts_quant(
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from typing import Optional
from vllm import envs
import torch import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter from torch.nn import Parameter
...@@ -40,7 +41,6 @@ from vllm.model_executor.parameter import ( ...@@ -40,7 +41,6 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
__all__ = ["CompressedTensorsW8A8Fp8"] __all__ = ["CompressedTensorsW8A8Fp8"]
strategy_to_parameter_type = { strategy_to_parameter_type = {
...@@ -159,8 +159,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -159,8 +159,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight, weight_scale, input_scale = process_fp8_weight_channel_strategy( weight, weight_scale, input_scale = process_fp8_weight_channel_strategy(
layer.weight, layer.weight_scale, getattr(layer, "input_scale", None) layer.weight, layer.weight_scale, getattr(layer, "input_scale", None)
) )
if envs.VLLM_W8A8_BACKEND == 3:
weight = weight.t().contiguous()
else:
weight = weight.t() weight = weight.t()
elif self.strategy == QuantizationStrategy.BLOCK: elif self.strategy == QuantizationStrategy.BLOCK:
assert self.is_static_input_scheme is False assert self.is_static_input_scheme is False
weight, weight_scale = process_fp8_weight_block_strategy( weight, weight_scale = process_fp8_weight_block_strategy(
...@@ -193,6 +195,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -193,6 +195,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None, **_,
) -> torch.Tensor: ) -> torch.Tensor:
if self.weight_block_size is not None: if self.weight_block_size is not None:
return self.w8a8_block_fp8_linear.apply( return self.w8a8_block_fp8_linear.apply(
......
...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from torch.nn import Module from torch.nn import Module
from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -1027,6 +1026,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1027,6 +1026,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.kernel is not None
assert not self.is_monolithic assert not self.is_monolithic
......
...@@ -12,7 +12,11 @@ from .ScaledMMLinearKernel import ( ...@@ -12,7 +12,11 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
) )
# Optional dependency: the hipBLASLt channel-wise W8A8 GEMM lives in lmslim.
# If it cannot be imported, keep the module importable and bind a stub so a
# later call fails with an actionable ImportError instead of a bare NameError.
try:
    from lmslim.quantize.quant_ops import hipblaslt_w8a8_channelwise_gemm
except ImportError:
    print("INFO: Please update lmslim if you want to use "
          "hipblaslt_w8a8_channelwise_gemm.\n")

    def hipblaslt_w8a8_channelwise_gemm(*args, **kwargs):
        """Stand-in used when lmslim does not provide the real kernel."""
        raise ImportError(
            "hipblaslt_w8a8_channelwise_gemm is unavailable; "
            "please install or update lmslim."
        )
from vllm import envs
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
""" """
...@@ -176,46 +180,31 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): ...@@ -176,46 +180,31 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
bias: torch.Tensor | None, bias: torch.Tensor | None,
output_shape: list, output_shape: list,
) -> torch.Tensor: ) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm m = A.shape[0]
k = A.shape[1]
# Symmetric quantized GEMM by definition computes the following: n = B.shape[0]
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations if envs.VLLM_W8A8_BACKEND == 3:
# before applying a GEMM. _, output = hipblaslt_w8a8_channelwise_gemm(
# a=A,
# In order to compute quantized operands, a quantized kernel b=B,
# will rewrite the above like so: scale_a=As,
# C = s_w * s_x * (X * W) + bias scale_b=Bs,
# m=m,
# For the scaled_mm fallback case, we break this down, since it n=n,
# does not support s_w being a vector. k=k,
transpose_flag="NT",
# Input scaling factors are no longer optional in _scaled_mm starting out_dtype=out_dtype,
# from pytorch 2.5. Allocating a dummy tensor to pass as scales bias=bias,
dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device) )
return output.view(m, n)
# GEMM else:
# This computes C = (X * W). output = triton_scaled_mm_fp8(
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(
A, A,
B, B,
scale_a=dummy_tensor, scale_a=As,
scale_b=dummy_tensor, scale_b=Bs,
out_dtype=torch.float32, out_dtype=out_dtype,
bias=bias,
) )
# A fix for discrepancy in scaled_mm which returns tuple return output.view(*output_shape)
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, output_shape[0])
x_scale = torch.narrow(As, 0, 0, output_shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * Bs.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment