Commit 94c8a620 authored by wanglong3's avatar wanglong3 Committed by zhuwenwen
Browse files

feat: Support fp8 channel-wise matmul.

parent 4dcfd0ae
...@@ -16,6 +16,7 @@ from vllm.utils import direct_register_custom_op ...@@ -16,6 +16,7 @@ from vllm.utils import direct_register_custom_op
try: try:
from lmslim import quant_ops from lmslim import quant_ops
from lmslim import quant_tools from lmslim import quant_tools
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n") print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try: try:
...@@ -1350,70 +1351,67 @@ def scaled_fp4_experts_quant( ...@@ -1350,70 +1351,67 @@ def scaled_fp4_experts_quant(
output_scales = output_scales.view(torch.float8_e4m3fn) output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales return output, output_scales
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
    output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8. Must be 2-D
            (batch and token dimensions already flattened).
        scale: Optional scaling factor for the FP8 quantization. When
            given it must hold exactly one element (static per-tensor).
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.
        output: Optional pre-allocated FP8 destination. Cannot be combined
            with ``num_token_padding``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.

    Note:
        The dynamic per-token path delegates to ``per_token_quant_fp8`` and
        returns whatever that helper produces. On that path ``output``,
        ``num_token_padding`` and ``scale_ub`` are validated but NOT
        forwarded to the helper.
        NOTE(review): confirm callers of the per-token path do not rely on
        in-place writes to ``output``, on padding, or on ``scale_ub``.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[tuple[int, int], torch.Size] = input.shape
    # The FP8 dtype is platform dependent (the ROCm/MI300 build uses its own
    # FP8 variant), so always ask the platform instead of hard-coding it.
    out_dtype: torch.dtype = current_platform.fp8_dtype()
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])

    # A caller-supplied buffer is validated up front, whichever path runs.
    if output is not None:
        assert num_token_padding is None, \
            "padding not supported if output passed in"
        assert output.dtype == out_dtype

    if scale is None and use_per_token_if_dynamic:
        # Dynamic per-token quantization is handled entirely by the lmslim
        # helper, which returns its own tensors. Nothing is pre-allocated
        # here because any local buffer would be discarded immediately.
        # Replaces: torch.ops._C.dynamic_per_token_scaled_fp8_quant(
        #     output, input.contiguous(), scale, scale_ub)
        quantized, token_scale = per_token_quant_fp8(input.contiguous())
        return quantized, token_scale

    if output is None:
        output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        # Dynamic per-tensor: the kernel writes the computed scale into this
        # zero-initialised one-element buffer.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Static per-tensor: only a single scalar scale is supported.
        assert scale.numel() == 1, f"{scale.shape}"
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
    return output, scale
# gptq allspark # gptq allspark
def allspark_repack_weight( def allspark_repack_weight(
qweight: torch.Tensor, qweight: torch.Tensor,
......
...@@ -171,6 +171,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -171,6 +171,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
**_,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -139,7 +139,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -139,7 +139,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def apply_weights(self, def apply_weights(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None,input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None, **_,
) -> torch.Tensor:
if layer.weight_block_size is not None: if layer.weight_block_size is not None:
return apply_fp8_block_linear( return apply_fp8_block_linear(
......
...@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer ...@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
try: try:
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.quantize import quant_ops
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
...@@ -291,7 +292,37 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -291,7 +292,37 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
output = output.view(*output_shape) output = output.view(*output_shape)
return output return output
def hipblaslt_w8a8_channelwise_scaled_mm(
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
) -> torch.Tensor:
    """Run a scaled W8A8 GEMM through lmslim's hipBLASLt channel-wise kernel.

    Args:
        qinput: Quantized activations; must be contiguous and share dtype
            and last dimension with ``weight``.
        weight: Quantized weights; must be contiguous. Its first dimension
            is used as the output width.
        out_dtype: Dtype requested from the kernel for the result.
        scale_a: Scale tensor passed through for ``qinput``.
        scale_b: Scale tensor passed through for ``weight``.
        bias: Bias passed through to the kernel.
        output_shape: Accepted for signature parity with the other
            ``*_scaled_mm`` backends; not used by this implementation.

    Returns:
        The kernel output viewed as
        ``(qinput.shape[0], weight.shape[0])``.
    """
    # Layout preconditions checked before calling the kernel.
    assert qinput.is_contiguous() and weight.is_contiguous()
    assert qinput.shape[-1] == weight.shape[-1]
    assert qinput.dtype == weight.dtype

    rows = qinput.shape[0]
    inner = qinput.shape[1]
    cols = weight.shape[0]

    # NOTE(review): the first element of the returned pair (a success flag,
    # judging by the original local name) is not checked -- confirm whether
    # a failure should be surfaced to the caller.
    gemm_result = quant_ops.hipblaslt_w8a8_channelwise_gemm(
        a=qinput,
        b=weight,
        scale_a=scale_a,
        scale_b=scale_b,
        m=rows,
        n=cols,
        k=inner,
        transpose_flag="NT",
        out_dtype=out_dtype,
        bias=bias,
    )
    _, result = gemm_result
    return result.view(rows, cols)
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
...@@ -336,11 +367,9 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -336,11 +367,9 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
output = output + bias output = output + bias
return output.to(out_dtype).view(*output_shape) return output.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm( def dispatch_w8a8_scaled_mm(
preferred_backend: str, per_tensor_weights: bool, preferred_backend: str, per_tensor_weights: bool,
per_tensor_activations: bool) -> Callable[..., torch.Tensor]: per_tensor_activations: bool) -> Callable[..., torch.Tensor]:
if per_tensor_weights and per_tensor_activations: if per_tensor_weights and per_tensor_activations:
if preferred_backend == "rocm": if preferred_backend == "rocm":
return rocm_per_tensor_w8a8_scaled_mm return rocm_per_tensor_w8a8_scaled_mm
...@@ -353,6 +382,9 @@ def dispatch_w8a8_scaled_mm( ...@@ -353,6 +382,9 @@ def dispatch_w8a8_scaled_mm(
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if preferred_backend == "cutlass" or preferred_backend == "flashinfer": if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
return cutlass_w8a8_scaled_mm return cutlass_w8a8_scaled_mm
if preferred_backend == "blaslt":
return hipblaslt_w8a8_channelwise_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs) # If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if not per_tensor_weights and not per_tensor_activations \ if not per_tensor_weights and not per_tensor_activations \
...@@ -378,7 +410,11 @@ class Fp8LinearOp: ...@@ -378,7 +410,11 @@ class Fp8LinearOp:
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
pad_output: Optional[bool] = None): pad_output: Optional[bool] = None):
if current_platform.is_rocm(): if current_platform.is_rocm():
self.preferred_backend = "rocm" if envs.VLLM_W8A8_BACKEND == 3:
self.preferred_backend = "blaslt"
else:
self.preferred_backend = "rocm"
elif current_platform.is_cuda() and cutlass_fp8_supported(): elif current_platform.is_cuda() and cutlass_fp8_supported():
if has_flashinfer() and current_platform.has_device_capability( if has_flashinfer() and current_platform.has_device_capability(
100): 100):
...@@ -429,11 +465,12 @@ class Fp8LinearOp: ...@@ -429,11 +465,12 @@ class Fp8LinearOp:
# If input not quantized # If input not quantized
# TODO(luka) remove this path if not used anymore # TODO(luka) remove this path if not used anymore
if input.dtype != current_platform.fp8_dtype(): if input.dtype != current_platform.fp8_dtype():
qinput, x_scale = self.quant_fp8( qinput, x_scale = ops.scaled_fp8_quant(
input_2d, input = input_2d,
input_scale, scale = input_scale,
input_scale_ub, num_token_padding = self.output_padding,
) scale_ub = input_scale_ub,
use_per_token_if_dynamic = True)
else: else:
qinput, x_scale = input_2d, input_scale qinput, x_scale = input_2d, input_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment