Unverified Commit 0dd6cda2 authored by HandH1998's avatar HandH1998 Committed by GitHub
Browse files

Apply sgl w8a8 fp8 kernel (#3148)

parent 9fb48f95
......@@ -250,9 +250,11 @@ class ModelConfig:
"compressed-tensors",
"experts_int8",
"w8a8_int8",
"w8a8_fp8",
]
compatible_quantization_methods = {
"w8a8_int8": ["compressed-tensors", "compressed_tensors"]
"w8a8_int8": ["compressed-tensors", "compressed_tensors"],
"w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
}
if self.quantization is not None:
self.quantization = self.quantization.lower()
......
......@@ -18,6 +18,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.parameter import (
BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
......@@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
from sglang.srt.utils import set_weight_attrs
logger = logging.getLogger(__name__)
......
......@@ -16,6 +16,7 @@ __all__ = [
"ModelWeightParameter",
"ChannelQuantScaleParameter",
"GroupQuantScaleParameter",
"BlockQuantScaleParameter",
"PackedColumnParameter",
"RowvLLMParameter",
]
......@@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter):
pass
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
block-wise quantization. Uses both column and row parallelism.
"""
pass
class PerTensorScaleParameter(BasevLLMParameter):
"""
Parameter class for scales where the number of scales is
......
......@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
......@@ -50,6 +51,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"qqq": QQQConfig,
"experts_int8": ExpertsInt8Config,
"w8a8_int8": W8A8Int8Config,
"w8a8_fp8": W8A8Fp8Config,
}
......
......@@ -13,12 +13,11 @@ from sglang.srt.layers.linear import (
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
from sglang.srt.utils import set_weight_attrs
......
......@@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
apply_fp8_linear,
convert_to_channelwise,
cutlass_fp8_supported,
per_tensor_dequantize,
requantize_with_max_scale,
)
......@@ -29,14 +27,21 @@ from sglang.srt.layers.linear import (
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_utils import (
BlockQuantScaleParameter,
apply_fp8_linear,
apply_w8a8_block_fp8_linear,
cutlass_fp8_supported,
input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import (
......@@ -305,15 +310,15 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
assert weight_scale.numel() == 1
weight_scale = convert_to_channelwise(
weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
if self.cutlass_fp8_supported or self.use_marlin:
# apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
qweight, weight_scale = per_token_group_quant_fp8(
layer.weight, layer.weight.shape[-1]
)
weight_scale = weight_scale.t().contiguous()
else:
# per-tensor quantization
qweight, weight_scale = input_to_float8(layer.weight)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
......@@ -330,23 +335,19 @@ class Fp8LinearMethod(LinearMethodBase):
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
# cutlass sgl-kernel and marlin only support per-channel scale
if self.cutlass_fp8_supported or self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
if is_hip():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
......
......@@ -29,7 +29,7 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
_is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
from sgl_kernel import sgl_per_token_group_quant_fp8
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
logger = logging.getLogger(__name__)
......@@ -70,7 +70,8 @@ def _per_token_group_quant_fp8(
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
......@@ -140,7 +141,7 @@ def per_token_group_quant_fp8(
x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
dtype: The dype of output tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
......@@ -241,6 +242,132 @@ def sglang_per_token_group_quant_fp8(
return x_q, x_s
def sglang_per_token_quant_fp8(
x: torch.Tensor,
dtype: torch.dtype = fp8_type_,
):
assert x.is_contiguous(), "`x` is not contiguous"
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_s = torch.empty(
x.shape[0],
1,
device=x.device,
dtype=torch.float32,
)
sgl_per_token_quant_fp8(x, x_q, x_s)
return x_q, x_s
@triton.jit
def _static_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
y_s_repeat_ptr,
# Stride of input
y_stride,
# Collums of input
N,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
REPEAT_SCALE: tl.constexpr,
):
"""A Triton-accelerated function to perform quantization using the given scale on a
tensor
This function converts the tensor values into float8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
if REPEAT_SCALE:
y_s_repeat_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
y_s = tl.load(y_s_ptr).to(tl.float32)
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
if REPEAT_SCALE:
tl.store(y_s_repeat_ptr, y_s)
def static_quant_fp8(
x: torch.Tensor,
x_s: torch.Tensor,
repeat_scale: bool = False,
dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform static quantization using the given scale on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tenosr with ndim >= 2.
x_s: The quantization scale.
repeat_scale: Whether to broadcast per-tensor scale to per-channel scale.
dtype: The dype of output tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
"""
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if is_hip_:
fp8_max = 224.0
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if repeat_scale:
x_s_repeat = torch.empty(
(M, 1),
device=x.device,
dtype=torch.float32,
)
else:
x_s_repeat = None
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_static_quant_fp8[(M,)](
x,
x_q,
x_s,
x_s_repeat,
N,
N,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
REPEAT_SCALE=repeat_scale,
num_warps=num_warps,
num_stages=num_stages,
)
x_s = x_s_repeat if repeat_scale else x_s
return x_q, x_s
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
......
......@@ -2,13 +2,23 @@ import os
from typing import List, Optional, Tuple
import torch
from packaging.version import Version
from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
)
from sglang.srt.utils import get_bool_env_var, is_hip
from sglang.srt.utils import (
get_bool_env_var,
get_cuda_version,
get_device_capability,
is_hip,
)
use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get(
"USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False
)
is_hip_ = is_hip()
if is_hip_ and get_bool_env_var("CK_MOE"):
......@@ -18,6 +28,25 @@ _is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
if use_vllm_cutlass_w8a8_fp8_kernel:
from vllm import _custom_ops as ops
else:
from sgl_kernel import fp8_scaled_mm
def cutlass_fp8_supported():
if not _is_cuda:
return False
major, minor = get_device_capability()
cuda_version = get_cuda_version()
if major >= 9:
return cuda_version >= (12, 0)
elif major == 8 and minor == 9:
return cuda_version >= (12, 4)
return False
def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
......@@ -158,10 +187,121 @@ def block_quant_to_tensor_quant(
return x_q_tensor, scale
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
block-wise quantization. Uses both column and row parallelism.
"""
def apply_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
input_scale_ub: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_fp8_supported: bool = True,
use_per_token_if_dynamic: bool = False,
) -> torch.Tensor:
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
# cutlass w8a8 fp8 sgl-kernel only supports per-token scale
if input_scale is not None:
assert input_scale.numel() == 1
# broadcast per-tensor scale to per-token scale when supporting cutlass
qinput, x_scale = static_quant_fp8(
input_2d, input_scale, repeat_scale=cutlass_fp8_supported
)
else:
# default use per-token quantization if dynamic
if _is_cuda:
qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
else:
qinput, x_scale = per_token_group_quant_fp8(
input_2d, group_size=input_2d.shape[1]
)
if cutlass_fp8_supported:
if use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
else:
assert (
weight_scale.numel() == weight.shape[1]
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
output = fp8_scaled_mm(
qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
)
return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else:
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = x_scale.numel() == 1
if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output = torch._scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
else:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(
qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32,
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
pass
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
......@@ -7,7 +7,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear,
convert_to_channelwise,
cutlass_fp8_supported,
requantize_with_max_scale,
)
......@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
# Initialize logger for the module
logger = logging.getLogger(__name__)
......@@ -161,6 +162,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer.weight, layer.weight_scale, layer.logical_widths
)
layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
# cutlass sgl-kernel only supports per-channel scale
if self.cutlass_fp8_supported:
max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
......
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip
class W8A8Fp8Config(QuantizationConfig):
"""Config class for W8A8 FP8 Quantization.
- Weight: static, per-channel, symmetric
- Activation: dynamic, per-token, symmetric
"""
def __init__(self):
pass
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 89
@classmethod
def get_name(self) -> str:
return "w8a8_fp8"
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
return cls()
def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.linear import LinearBase
if isinstance(layer, LinearBase):
return W8A8Fp8LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class W8A8Fp8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: W8A8Fp8Config):
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.quantization_config = quantization_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight
weight_scale = layer.weight_scale.detach()
if is_hip():
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs
):
weight_loader = extra_weight_attrs.get("weight_loader")
self.logical_widths = output_partition_sizes
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
return apply_fp8_linear(
x,
layer.weight,
layer.weight_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
)
......@@ -405,6 +405,7 @@ class ServerArgs:
"gguf",
"modelopt",
"w8a8_int8",
"w8a8_fp8",
],
help="The quantization method.",
)
......
......@@ -52,11 +52,13 @@ import triton
import zmq
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
from packaging.version import Version, parse
from starlette.routing import Mount
from torch import nn
from torch.func import functional_call
from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils.cpp_extension import CUDA_HOME
from triton.runtime.cache import (
FileCacheManager,
default_cache_dir,
......@@ -1431,6 +1433,12 @@ def rank0_print(msg: str):
print(msg, flush=True)
def get_cuda_version():
if torch.version.cuda:
return tuple(map(int, torch.version.cuda.split(".")))
return (0, 0)
def launch_dummy_health_check_server(host, port):
import uvicorn
from fastapi import FastAPI, Response
......
......@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
)
......@@ -63,7 +64,7 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
out, scale = per_token_group_quant_fp8(x, group_size)
self.assertTrue(
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20)
)
self.assertTrue(torch.allclose(scale, ref_scale))
......@@ -85,6 +86,71 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
self._per_token_group_quant_fp8(*params)
# For test
def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
"""Function to perform static quantization on an input tensor `x` using native torch.
It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
"""
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1])
x_s_inv = 1.0 / x_s
x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
return x_q, x_s
class TestStaticQuantFP8(unittest.TestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _static_quant_fp8(self, num_tokens, d, dtype, seed):
torch.manual_seed(seed)
x = torch.rand(num_tokens, d, dtype=dtype)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_s = x.max() / fp8_max
with torch.inference_mode():
ref_out, _ = native_static_quant_fp8(x, x_s)
out, _ = static_quant_fp8(x, x_s, repeat_scale=True)
self.assertTrue(
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
)
def test_static_quant_fp8(self):
for params in itertools.product(
self.NUM_TOKENS,
self.D,
self.DTYPES,
self.SEEDS,
):
with self.subTest(
num_tokens=params[0],
d=params[1],
dtype=params[2],
seed=params[3],
):
self._static_quant_fp8(*params)
# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment