Commit 587a5c60 authored by zhaosong1's avatar zhaosong1
Browse files

[feature] add scaled_fp8_quant_weight for online ptpc_fp8 quant.

parent f1eb27b8
...@@ -1386,6 +1386,68 @@ def scaled_fp8_quant( ...@@ -1386,6 +1386,68 @@ def scaled_fp8_quant(
optional padding of the output tensors for downstream kernels that optional padding of the output tensors for downstream kernels that
will benefit from padding. will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[tuple[int, int], torch.Size] = input.shape
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None:
if use_per_token_if_dynamic:
scale = torch.empty((shape[0], 1),
device=input.device,
dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
output, scale = per_token_quant_fp8(input.contiguous())
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
def scaled_fp8_quant_weight(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args: Args:
input: The input tensor to be quantized to FP8 input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization scale: Optional scaling factor for the FP8 quantization
...@@ -1421,7 +1483,7 @@ def scaled_fp8_quant( ...@@ -1421,7 +1483,7 @@ def scaled_fp8_quant(
dtype=torch.float32) dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant( torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input.contiguous(), scale, scale_ub) output, input.contiguous(), scale, scale_ub)
# per_token_quant_fp8 has precision problem. # per_token_quant_fp8 has precision problem for online weight quant.
# output, scale = per_token_quant_fp8(input.contiguous()) # output, scale = per_token_quant_fp8(input.contiguous())
else: else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32) scale = torch.zeros(1, device=input.device, dtype=torch.float32)
...@@ -1431,6 +1493,7 @@ def scaled_fp8_quant( ...@@ -1431,6 +1493,7 @@ def scaled_fp8_quant(
torch.ops._C.static_scaled_fp8_quant(output, input, scale) torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale return output, scale
# gptq allspark # gptq allspark
def allspark_repack_weight( def allspark_repack_weight(
qweight: torch.Tensor, qweight: torch.Tensor,
......
...@@ -107,7 +107,7 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): ...@@ -107,7 +107,7 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
assert layer.weight.data.dtype == torch.bfloat16, \ assert layer.weight.data.dtype == torch.bfloat16, \
f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501
# Quantize the weights. # Quantize the weights.
qweight, weight_scale = ops.scaled_fp8_quant( qweight, weight_scale = ops.scaled_fp8_quant_weight(
layer.weight, scale=None, use_per_token_if_dynamic=True) layer.weight, scale=None, use_per_token_if_dynamic=True)
# Update the layer with the new values. # Update the layer with the new values.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment