Unverified commit 30828e71, authored by HAI and committed by GitHub

AMD: set weights and scaling numbers properly for block FP8 (#2637)

parent e0e09fce
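
Both the Fp8LinearMethod and Fp8MoEMethod hunks below delegate the ROCm conversion to normalize_e4m3fn_to_e4m3fnuz, which already exists in the fp8 utils module; its body is not part of this diff. As a rough sketch of what such a conversion typically entails (the helper name with the _sketch suffix and the bit-level details are assumptions, not taken from this patch): for the same bit pattern, e4m3fnuz decodes to half the e4m3fn value because its exponent bias is 8 instead of 7, so weight bits can be reinterpreted in place as long as the 0x80 pattern (negative zero in e4m3fn, NaN in e4m3fnuz) is cleared and every scale is doubled.

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
    # Sketch only; the repo's actual helper may differ in detail.
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.int8)
    bits[bits == -128] = 0                   # 0x80 is -0 in e4m3fn but NaN in e4m3fnuz
    weight = bits.view(torch.float8_e4m3fnuz)
    weight_scale = weight_scale * 2.0        # same bits decode to half the value, so double the scale
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
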
@@ -272,6 +272,19 @@ class Fp8LinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
# activation_scheme: dynamic
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=None,
)
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
weight_scale, requires_grad=False
)
layer.input_scale = None
return
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
@@ -369,7 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
input_scale=None,
bias=bias,
)
@@ -553,6 +566,30 @@ class Fp8MoEMethod:
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
weight_scale=layer.w13_weight_scale_inv,
input_scale=None,
)
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w2_weight,
weight_scale=layer.w2_weight_scale_inv,
input_scale=None,
)
# Reset the parameters
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w13_input_scale = None
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
layer.w2_input_scale = None
return
# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
......
@@ -22,7 +22,10 @@ import torch
import triton
import triton.language as tl
from sglang.srt.utils import get_device_name
from sglang.srt.utils import get_device_name, is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
logger = logging.getLogger(__name__)
@@ -73,7 +76,7 @@ def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.float8_e4m3fn,
dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
@@ -95,9 +98,13 @@ def per_token_group_quant_fp8(
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
if is_hip_:
fp8_max = 224.0
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
......
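
On the kernel side, the dtype and clamp range now depend on the platform: e4m3fn with the full ±448 range on CUDA, e4m3fnuz clamped to ±224.0 on ROCm (below the fnuz dtype maximum of 240, presumably to leave saturation headroom). A minimal pure-PyTorch stand-in for the same per-token-group quantization, under those assumptions (the _reference function name is made up for illustration, not the Triton kernel itself):

import torch
from sglang.srt.utils import is_hip  # same import the diff adds

_IS_HIP = is_hip()
_FP8_DTYPE = torch.float8_e4m3fnuz if _IS_HIP else torch.float8_e4m3fn

def per_token_group_quant_fp8_reference(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # Illustrative only; mirrors the range selection in the patched kernel.
    finfo = torch.finfo(_FP8_DTYPE)
    fp8_min, fp8_max = finfo.min, finfo.max
    if _IS_HIP:
        fp8_max = 224.0                     # matches the clamp in the patched kernel
        fp8_min = -fp8_max
    groups = x.reshape(-1, group_size).float()
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    x_s = amax / fp8_max                    # one scale per token group
    x_q = (groups / x_s).clamp(fp8_min, fp8_max).to(_FP8_DTYPE)
    return x_q.reshape(x.shape), x_s
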
@@ -7,6 +7,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
def normalize_e4m3fn_to_e4m3fnuz(
@@ -63,8 +66,11 @@ def input_to_float8(
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
fp8_max = finfo.max
if is_hip_:
fp8_max = 224.0
scale = fp8_max / amax
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
......
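
For completeness, input_to_float8 returns the quantized tensor together with the reciprocal of the quantization scale, so dequantization is a single multiply; the only behavioral change here is the tighter ±224.0 range on ROCm. A usage sketch (the module path and tensor shape are assumptions):

import torch
from sglang.srt.layers.quantization.fp8_utils import input_to_float8  # assumed module path

x = torch.randn(16, 128, dtype=torch.bfloat16)
x_fp8, x_inv_scale = input_to_float8(x)             # x_inv_scale == 1 / quantization scale
x_dequant = x_fp8.to(torch.float32) * x_inv_scale   # recovers x up to fp8 rounding error
print((x_dequant - x.float()).abs().max())
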