Unverified Commit e483ab6d authored by Enrique Shockwave, committed by GitHub

enable marlin fp8 blockwise (#8990)

parent 720cd308
@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
can_auto_enable_marlin_fp8,
cutlass_fp8_supported,
dispatch_w8a8_block_fp8_linear,
input_to_float8,
@@ -209,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (
get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
)
# Disable marlin for ROCm
if _is_hip:
self.use_marlin = False
self.use_marlin = False
if _is_cuda and MARLIN_FP8_AVAILABLE:
force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
auto_enable = can_auto_enable_marlin_fp8()
self.use_marlin = force_marlin or auto_enable
self.block_quant = self.quant_config.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
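For context, a minimal self-contained sketch of how the new enablement decision resolves (standalone names; os.environ stands in for sglang's get_bool_env_var helper and torch.cuda.is_available() for the module-level _is_cuda flag, both assumptions made only for this sketch):

import os
import torch

def can_auto_enable_marlin_fp8() -> bool:
    # Same heuristic as the helper added to fp8_utils.py in the last hunk:
    # auto-enable on SM80-SM88 (Ampere) GPUs, which lack native FP8 tensor cores.
    try:
        major, minor = torch.cuda.get_device_capability()
        return 80 <= major * 10 + minor < 89
    except Exception:
        return False

def resolve_use_marlin(marlin_fp8_available: bool) -> bool:
    # Marlin FP8 is considered only on CUDA with the kernel available; it is
    # then turned on either by the env override or by the auto heuristic.
    if not (torch.cuda.is_available() and marlin_fp8_available):
        return False
    force = os.environ.get("SGLANG_FORCE_FP8_MARLIN", "").lower() in ("1", "true")
    return force or can_auto_enable_marlin_fp8()

Note that, unlike the previous code, block-wise FP8 no longer forces use_marlin off; block-quant layers now reach prepare_fp8_layer_for_marlin as well (see the last hunk of this file below).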
@@ -332,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
@@ -342,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale=layer.weight_scale_inv,
input_scale=None,
)
layer.input_scale = None
elif _is_cpu:
assert (
@@ -352,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase):
return
else:
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
weight_scale, requires_grad=False
)
return
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
else:
layer.weight = Parameter(layer.weight.data, requires_grad=False)
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
if self.cutlass_fp8_supported or self.use_marlin:
# apply per-channel quantization default as
# cutlass sgl-kernel and marlin only support per-channel scale
qweight, weight_scale = per_token_group_quant_fp8(
layer.weight, layer.weight.shape[-1]
)
weight_scale = weight_scale.t().contiguous()
else:
# per-tensor quantization
qweight, weight_scale = input_to_float8(layer.weight)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.input_scale = None
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
if self.cutlass_fp8_supported or self.use_marlin:
# apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
qweight, weight_scale = per_token_group_quant_fp8(
layer.weight, layer.weight.shape[-1]
)
weight_scale = weight_scale.t().contiguous()
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else:
# per-tensor quantization
qweight, weight_scale = input_to_float8(layer.weight)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.input_scale = None
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else:
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
if (
hasattr(self.quant_config, "activation_scheme")
and self.quant_config.activation_scheme == "static"
) or (
hasattr(self.quant_config, "linear_activation_scheme")
and self.quant_config.linear_activation_scheme == "static"
):
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
layer.weight_scale = Parameter(
layer.weight_scale.data, requires_grad=False
)
if (
hasattr(self.quant_config, "activation_scheme")
and self.quant_config.activation_scheme == "static"
) or (
hasattr(self.quant_config, "linear_activation_scheme")
and self.quant_config.linear_activation_scheme == "static"
):
layer.input_scale = Parameter(
layer.input_scale.data, requires_grad=False
)
# cutlass sgl-kernel and marlin only support per-channel scale
if self.cutlass_fp8_supported or self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
# cutlass sgl-kernel and marlin only support per-channel scale
if self.cutlass_fp8_supported or self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
weight, weight_scale, input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
input_scale=layer.input_scale,
)
)
if input_scale is not None:
layer.input_scale = Parameter(
input_scale, requires_grad=False
)
weight_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=weight_scale,
input_scale=layer.input_scale,
logical_widths=layer.logical_widths,
)
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)
weight_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=weight_scale,
logical_widths=layer.logical_widths,
)
# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
if (
hasattr(self.quant_config, "activation_scheme")
and self.quant_config.activation_scheme == "static"
) or (
hasattr(self.quant_config, "linear_activation_scheme")
and self.quant_config.linear_activation_scheme == "static"
):
layer.input_scale = Parameter(
layer.input_scale.max(), requires_grad=False
)
# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
if (
hasattr(self.quant_config, "activation_scheme")
and self.quant_config.activation_scheme == "static"
) or (
hasattr(self.quant_config, "linear_activation_scheme")
and self.quant_config.linear_activation_scheme == "static"
):
layer.input_scale = Parameter(
layer.input_scale.max(), requires_grad=False
)
if self.use_marlin:
prepare_fp8_layer_for_marlin(layer)
if self.block_quant:
layer.weight_block_size = self.quant_config.weight_block_size
prepare_fp8_layer_for_marlin(layer, not self.block_quant)
# Activations not quantized for marlin.
del layer.input_scale
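As an aside, a minimal sketch of the "Dequant -> Quant with max scale" step referenced earlier in this hunk for the per-tensor path (illustrative names and shapes, not sglang's actual requantize_with_max_scale): a fused module loads one scale per shard, so each shard is dequantized with its own scale and requantized with the shared maximum, letting a single per-tensor weight_scale serve the whole layer.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def requantize_with_max_scale_sketch(
    weight: torch.Tensor,        # fp8 weight, shards stacked along dim 0
    weight_scale: torch.Tensor,  # one float32 scale per shard
    logical_widths: list[int],   # number of output rows in each shard
):
    max_scale = weight_scale.max()
    requant = torch.empty_like(weight)
    start = 0
    for shard_scale, width in zip(weight_scale, logical_widths):
        end = start + width
        # dequantize with the shard's own scale, requantize with the max scale
        dequant = weight[start:end].to(torch.float32) * shard_scale
        requant[start:end] = (dequant / max_scale).clamp(
            -FP8_MAX, FP8_MAX
        ).to(torch.float8_e4m3fn)
        start = end
    # max_scale becomes the single per-tensor weight_scale for the layer
    return max_scale, requant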
@@ -789,3 +789,12 @@ def apply_fp8_linear(
bias,
input.dtype,
)
def can_auto_enable_marlin_fp8() -> bool:
try:
major, minor = get_device_capability()
sm = major * 10 + minor
return 80 <= sm < 89
except Exception:
return False
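A quick worked example of the capability gate (device names are illustrative; the idea is that only Ampere-class GPUs, which lack native FP8 tensor cores, auto-enable the Marlin weight-only path):

# A100        -> get_device_capability() == (8, 0) -> sm = 80 -> True
# A10G / 3090 -> (8, 6) -> sm = 86 -> True
# L40S / 4090 -> (8, 9) -> sm = 89 -> False (Ada has hardware FP8)
# H100        -> (9, 0) -> sm = 90 -> False (Hopper has hardware FP8)
# Pre-Ampere (sm < 80) also returns False, since the Marlin FP8 kernel targets SM80+.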