Commit 94c8a620 authored by wanglong3's avatar wanglong3 Committed by zhuwenwen
Browse files

feat: Supprot fp8 channle-wise matmul.

parent 4dcfd0ae
...@@ -16,6 +16,7 @@ from vllm.utils import direct_register_custom_op ...@@ -16,6 +16,7 @@ from vllm.utils import direct_register_custom_op
try: try:
from lmslim import quant_ops from lmslim import quant_ops
from lmslim import quant_tools from lmslim import quant_tools
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n") print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try: try:
...@@ -1350,70 +1351,67 @@ def scaled_fp4_experts_quant( ...@@ -1350,70 +1351,67 @@ def scaled_fp4_experts_quant(
output_scales = output_scales.view(torch.float8_e4m3fn) output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales return output, output_scales
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
# fp8 Args:
# def scaled_fp8_quant( input: The input tensor to be quantized to FP8
# input: torch.Tensor, scale: Optional scaling factor for the FP8 quantization
# scale: Optional[torch.Tensor] = None, scale_ub: Optional upper bound for scaling factor in dynamic
# num_token_padding: Optional[int] = None, per token case
# scale_ub: Optional[torch.Tensor] = None, num_token_padding: If specified, pad the first dimension
# use_per_token_if_dynamic: bool = False, of the output to at least this value.
# output: Optional[torch.Tensor] = None, use_per_token_if_dynamic: Whether to do per_tensor or per_token
# ) -> tuple[torch.Tensor, torch.Tensor]: in the dynamic quantization case.
# """
# Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensors for downstream kernels that
# will benefit from padding.
# Args:
# input: The input tensor to be quantized to FP8
# scale: Optional scaling factor for the FP8 quantization
# scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
# num_token_padding: If specified, pad the first dimension
# of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
# Returns:
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# """
# # This code assumes batch_dim and num_tokens are flattened
# assert (input.ndim == 2)
# shape: Union[tuple[int, int], torch.Size] = input.shape
# # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# out_dtype: torch.dtype = current_platform.fp8_dtype()
# if num_token_padding:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# if output is None:
# output = torch.empty(shape, device=input.device, dtype=out_dtype)
# else:
# assert num_token_padding is None, \
# "padding not supported if output passed in"
# assert output.dtype == out_dtype
# if scale is None:
# if use_per_token_if_dynamic:
# scale = torch.empty((shape[0], 1),
# device=input.device,
# dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input, scale, scale_ub)
# else:
# scale = torch.empty(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# assert scale.numel() == 1, f"{scale.shape}"
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# return output, scale
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[tuple[int, int], torch.Size] = input.shape
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None:
if use_per_token_if_dynamic:
scale = torch.empty((shape[0], 1),
device=input.device,
dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
output, scale = per_token_quant_fp8(input.contiguous())
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
# gptq allspark # gptq allspark
def allspark_repack_weight( def allspark_repack_weight(
qweight: torch.Tensor, qweight: torch.Tensor,
......
...@@ -171,6 +171,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -171,6 +171,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
**_,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -139,7 +139,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -139,7 +139,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def apply_weights(self, def apply_weights(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None,input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None, **_,
) -> torch.Tensor:
if layer.weight_block_size is not None: if layer.weight_block_size is not None:
return apply_fp8_block_linear( return apply_fp8_block_linear(
......
...@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer ...@@ -18,6 +18,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
try: try:
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.quantize import quant_ops
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
...@@ -291,6 +292,36 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -291,6 +292,36 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
output = output.view(*output_shape) output = output.view(*output_shape)
return output return output
def hipblaslt_w8a8_channelwise_scaled_mm(
qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
) -> torch.Tensor:
assert qinput.is_contiguous() and weight.is_contiguous()
assert qinput.shape[-1] == weight.shape[-1]
assert qinput.dtype == weight.dtype
m = qinput.shape[0]
k = qinput.shape[1]
n = weight.shape[0]
success, output = quant_ops.hipblaslt_w8a8_channelwise_gemm(
a = qinput,
b = weight,
scale_a = scale_a,
scale_b = scale_b,
m = m,
n = n,
k = k,
transpose_flag = "NT",
out_dtype = out_dtype,
bias = bias,
)
return output.view(m, n)
def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
...@@ -336,11 +367,9 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -336,11 +367,9 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
output = output + bias output = output + bias
return output.to(out_dtype).view(*output_shape) return output.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm( def dispatch_w8a8_scaled_mm(
preferred_backend: str, per_tensor_weights: bool, preferred_backend: str, per_tensor_weights: bool,
per_tensor_activations: bool) -> Callable[..., torch.Tensor]: per_tensor_activations: bool) -> Callable[..., torch.Tensor]:
if per_tensor_weights and per_tensor_activations: if per_tensor_weights and per_tensor_activations:
if preferred_backend == "rocm": if preferred_backend == "rocm":
return rocm_per_tensor_w8a8_scaled_mm return rocm_per_tensor_w8a8_scaled_mm
...@@ -354,6 +383,9 @@ def dispatch_w8a8_scaled_mm( ...@@ -354,6 +383,9 @@ def dispatch_w8a8_scaled_mm(
if preferred_backend == "cutlass" or preferred_backend == "flashinfer": if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
return cutlass_w8a8_scaled_mm return cutlass_w8a8_scaled_mm
if preferred_backend == "blaslt":
return hipblaslt_w8a8_channelwise_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs) # If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if not per_tensor_weights and not per_tensor_activations \ if not per_tensor_weights and not per_tensor_activations \
and USE_ROWWISE_TORCH_SCALED_MM: and USE_ROWWISE_TORCH_SCALED_MM:
...@@ -378,7 +410,11 @@ class Fp8LinearOp: ...@@ -378,7 +410,11 @@ class Fp8LinearOp:
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
pad_output: Optional[bool] = None): pad_output: Optional[bool] = None):
if current_platform.is_rocm(): if current_platform.is_rocm():
if envs.VLLM_W8A8_BACKEND == 3:
self.preferred_backend = "blaslt"
else:
self.preferred_backend = "rocm" self.preferred_backend = "rocm"
elif current_platform.is_cuda() and cutlass_fp8_supported(): elif current_platform.is_cuda() and cutlass_fp8_supported():
if has_flashinfer() and current_platform.has_device_capability( if has_flashinfer() and current_platform.has_device_capability(
100): 100):
...@@ -429,11 +465,12 @@ class Fp8LinearOp: ...@@ -429,11 +465,12 @@ class Fp8LinearOp:
# If input not quantized # If input not quantized
# TODO(luka) remove this path if not used anymore # TODO(luka) remove this path if not used anymore
if input.dtype != current_platform.fp8_dtype(): if input.dtype != current_platform.fp8_dtype():
qinput, x_scale = self.quant_fp8( qinput, x_scale = ops.scaled_fp8_quant(
input_2d, input = input_2d,
input_scale, scale = input_scale,
input_scale_ub, num_token_padding = self.output_padding,
) scale_ub = input_scale_ub,
use_per_token_if_dynamic = True)
else: else:
qinput, x_scale = input_2d, input_scale qinput, x_scale = input_2d, input_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment