Unverified commit e483ab6d authored by Enrique Shockwave, committed by GitHub

enable marlin fp8 blockwise (#8990)

parent 720cd308
@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
+    can_auto_enable_marlin_fp8,
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
     input_to_float8,
@@ -209,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (
-            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
-        )
-        # Disable marlin for ROCm
-        if _is_hip:
-            self.use_marlin = False
+        self.use_marlin = False
+        if _is_cuda and MARLIN_FP8_AVAILABLE:
+            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+            auto_enable = can_auto_enable_marlin_fp8()
+            self.use_marlin = force_marlin or auto_enable
 
         self.block_quant = self.quant_config.weight_block_size is not None
-        if self.block_quant:
-            # Marlin doesn't support block-wise fp8
-            self.use_marlin = False
 
         self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
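Note: with this change Marlin FP8 is considered only on CUDA builds that ship the kernel, and it is enabled either explicitly through SGLANG_FORCE_FP8_MARLIN or automatically on GPUs without native FP8 support; the block-quant case no longer switches it off. A minimal standalone sketch of the same decision using plain torch (would_use_marlin_fp8 and the env-var parsing below are illustrative, not sglang APIs):

import os

import torch


def would_use_marlin_fp8(marlin_fp8_available: bool = True) -> bool:
    # Marlin FP8 only applies on CUDA builds that ship the kernel.
    if not (torch.cuda.is_available() and marlin_fp8_available):
        return False
    # Explicit opt-in, mirroring the SGLANG_FORCE_FP8_MARLIN env var.
    force = os.environ.get("SGLANG_FORCE_FP8_MARLIN", "").lower() in ("1", "true")
    # Auto-enable on GPUs without native FP8 tensor cores (SM 8.0-8.8).
    major, minor = torch.cuda.get_device_capability()
    auto = 80 <= major * 10 + minor < 89
    return force or auto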
@@ -332,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if _is_fp8_fnuz:
@@ -342,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
-
                 layer.input_scale = None
             elif _is_cpu:
                 assert (
@@ -352,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase):
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-            layer.weight_scale_inv = torch.nn.Parameter(
-                weight_scale, requires_grad=False
-            )
-            return
-
-        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
-
-        # If checkpoint not serialized fp8, quantize the weights.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            if self.cutlass_fp8_supported or self.use_marlin:
-                # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
-                qweight, weight_scale = per_token_group_quant_fp8(
-                    layer.weight, layer.weight.shape[-1]
-                )
-                weight_scale = weight_scale.t().contiguous()
-            else:
-                # per-tensor quantization
-                qweight, weight_scale = input_to_float8(layer.weight)
-
-            # Update the layer with the new values.
-            layer.weight = Parameter(qweight.t(), requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            layer.input_scale = None
-
-        # If checkpoint is fp8, handle that there are N scales for N
-        # shards in a fused module
-        else:
-            layer.weight_scale = torch.nn.Parameter(
-                layer.weight_scale.data, requires_grad=False
-            )
-            if (
-                hasattr(self.quant_config, "activation_scheme")
-                and self.quant_config.activation_scheme == "static"
-            ) or (
-                hasattr(self.quant_config, "linear_activation_scheme")
-                and self.quant_config.linear_activation_scheme == "static"
-            ):
-                layer.input_scale = torch.nn.Parameter(
-                    layer.input_scale.data, requires_grad=False
-                )
-
-            # cutlass sgl-kernel and marlin only support per-channel scale
-            if self.cutlass_fp8_supported or self.use_marlin:
-                weight = layer.weight
-                weight_scale = convert_to_channelwise(
-                    layer.weight_scale, layer.logical_widths
-                )
-            else:
-                # Dequant -> Quant with max scale so we can run per tensor.
-                weight = layer.weight
-                weight_scale = layer.weight_scale
-
-                # If ROCm, normalize the weights and scales to e4m3fnuz
-                if _is_fp8_fnuz:
-                    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=weight,
-                        weight_scale=weight_scale,
-                        input_scale=layer.input_scale,
-                    )
-                    if input_scale is not None:
-                        layer.input_scale = Parameter(input_scale, requires_grad=False)
-
-                weight_scale, weight = requantize_with_max_scale(
-                    weight=weight,
-                    weight_scale=weight_scale,
-                    logical_widths=layer.logical_widths,
-                )
-
-            # Update layer with new values.
-            layer.weight = Parameter(weight.t(), requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            if (
-                hasattr(self.quant_config, "activation_scheme")
-                and self.quant_config.activation_scheme == "static"
-            ) or (
-                hasattr(self.quant_config, "linear_activation_scheme")
-                and self.quant_config.linear_activation_scheme == "static"
-            ):
-                layer.input_scale = Parameter(
-                    layer.input_scale.max(), requires_grad=False
-                )
+            layer.weight = Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.data, requires_grad=False)
+
+            # If checkpoint not serialized fp8, quantize the weights.
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                if self.cutlass_fp8_supported or self.use_marlin:
+                    # apply per-channel quantization default as
+                    # cutlass sgl-kernel and marlin only support per-channel scale
+                    qweight, weight_scale = per_token_group_quant_fp8(
+                        layer.weight, layer.weight.shape[-1]
+                    )
+                    weight_scale = weight_scale.t().contiguous()
+                else:
+                    # per-tensor quantization
+                    qweight, weight_scale = input_to_float8(layer.weight)
+
+                # Update the layer with the new values.
+                layer.weight = Parameter(qweight.t(), requires_grad=False)
+                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+                layer.input_scale = None
+
+            # If checkpoint is fp8, handle that there are N scales for N
+            # shards in a fused module
+            else:
+                layer.weight_scale = Parameter(
+                    layer.weight_scale.data, requires_grad=False
+                )
+                if (
+                    hasattr(self.quant_config, "activation_scheme")
+                    and self.quant_config.activation_scheme == "static"
+                ) or (
+                    hasattr(self.quant_config, "linear_activation_scheme")
+                    and self.quant_config.linear_activation_scheme == "static"
+                ):
+                    layer.input_scale = Parameter(
+                        layer.input_scale.data, requires_grad=False
+                    )
+
+                # cutlass sgl-kernel and marlin only support per-channel scale
+                if self.cutlass_fp8_supported or self.use_marlin:
+                    weight = layer.weight
+                    weight_scale = convert_to_channelwise(
+                        layer.weight_scale, layer.logical_widths
+                    )
+                else:
+                    # Dequant -> Quant with max scale so we can run per tensor.
+                    weight = layer.weight
+                    weight_scale = layer.weight_scale
+
+                    # If ROCm, normalize the weights and scales to e4m3fnuz
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, input_scale = (
+                            normalize_e4m3fn_to_e4m3fnuz(
+                                weight=weight,
+                                weight_scale=weight_scale,
+                                input_scale=layer.input_scale,
+                            )
+                        )
+                        if input_scale is not None:
+                            layer.input_scale = Parameter(
+                                input_scale, requires_grad=False
+                            )
+
+                    weight_scale, weight = requantize_with_max_scale(
+                        weight=weight,
+                        weight_scale=weight_scale,
+                        logical_widths=layer.logical_widths,
+                    )
+
+                # Update layer with new values.
+                layer.weight = Parameter(weight.t(), requires_grad=False)
+                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+                if (
+                    hasattr(self.quant_config, "activation_scheme")
+                    and self.quant_config.activation_scheme == "static"
+                ) or (
+                    hasattr(self.quant_config, "linear_activation_scheme")
+                    and self.quant_config.linear_activation_scheme == "static"
+                ):
+                    layer.input_scale = Parameter(
+                        layer.input_scale.max(), requires_grad=False
+                    )
 
         if self.use_marlin:
-            prepare_fp8_layer_for_marlin(layer)
+            if self.block_quant:
+                layer.weight_block_size = self.quant_config.weight_block_size
+            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
             # Activations not quantized for marlin.
             del layer.input_scale
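For block-quantized checkpoints, the layer now records its weight_block_size before the Marlin repack, and prepare_fp8_layer_for_marlin receives not self.block_quant, presumably to distinguish per-channel from per-block scale layouts. The distinction matters because block-wise FP8 stores one scale per weight tile rather than per output channel; a rough sketch of that bookkeeping (block_scale_shape is illustrative, not an sglang helper), assuming the common 128x128 block size:

import math


def block_scale_shape(n: int, k: int, block=(128, 128)) -> tuple[int, int]:
    # One scale per (block[0] x block[1]) tile of an (n, k) FP8 weight.
    return math.ceil(n / block[0]), math.ceil(k / block[1])


# Example: a 4096x4096 projection quantized with [128, 128] blocks keeps a
# 32x32 scale tensor alongside the FP8 weight.
assert block_scale_shape(4096, 4096) == (32, 32)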
@@ -789,3 +789,12 @@ def apply_fp8_linear(
             bias,
             input.dtype,
         )
+
+
+def can_auto_enable_marlin_fp8() -> bool:
+    try:
+        major, minor = get_device_capability()
+        sm = major * 10 + minor
+        return 80 <= sm < 89
+    except Exception:
+        return False
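The [80, 89) window auto-enables Marlin only on Ampere-class GPUs, which lack native FP8 tensor cores, while leaving out Ada (SM 8.9) and Hopper (SM 9.0 and up), which have them. A small worked example of the same arithmetic (the capability pairs are the published values for these parts):

# (major, minor) -> sm = major * 10 + minor -> auto-enable?
examples = {
    "A100": (8, 0),             # sm 80 -> True
    "RTX A6000": (8, 6),        # sm 86 -> True
    "L40S / RTX 4090": (8, 9),  # sm 89 -> False (Ada has native FP8)
    "H100": (9, 0),             # sm 90 -> False (Hopper has native FP8)
}
for name, (major, minor) in examples.items():
    sm = major * 10 + minor
    print(f"{name}: sm={sm}, auto_enable={80 <= sm < 89}")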