Unverified Commit 84d96b3a authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

Move FP8 to SGLang (#2370)


Co-authored-by: default avatarHaiShaw <hixiao@gmail.com>
parent 3d32e4a3
...@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso ...@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from vllm.model_executor.layers.quantization.gguf import GGUFConfig from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
...@@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig ...@@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
...@@ -100,13 +100,13 @@ def fp8_moe_apply( ...@@ -100,13 +100,13 @@ def fp8_moe_apply(
def fp8_get_quant_method(self, layer, prefix): def fp8_get_quant_method(self, layer, prefix):
"""Enhanced get_quant_method for FP8 config.""" """Enhanced get_quant_method for FP8 config."""
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped, is_layer_skipped,
) )
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.linear import UnquantizedLinearMethod from sglang.srt.layers.linear import UnquantizedLinearMethod
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers): if is_layer_skipped(prefix, self.ignored_layers):
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
import logging
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
apply_fp8_linear,
convert_to_channelwise,
cutlass_fp8_supported,
per_tensor_dequantize,
requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.utils import (
get_bool_env_var,
is_hip,
print_warning_once,
set_weight_attrs,
)
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = logging.getLogger(__name__)
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning(
"Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change."
)
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
self.ignored_layers = ignored_layers or []
@classmethod
def get_name(cls) -> str:
return "fp8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
return cls(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
# Disable marlin for ROCm
if is_hip():
self.use_marlin = False
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (
torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized
else params_dtype
)
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition, input_size_per_partition, dtype=weight_dtype
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
assert weight_scale.numel() == 1
weight_scale = convert_to_channelwise(
weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.input_scale = None
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else:
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
if self.quant_config.activation_scheme == "static":
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
input_scale=layer.input_scale,
)
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)
weight_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=weight_scale,
logical_widths=layer.logical_widths,
)
# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = Parameter(
layer.input_scale.max(), requires_grad=False
)
if self.use_marlin:
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=False,
)
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def create_weights(
self,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
layer.num_experts, dtype=torch.float32, device=w13_weight.device
),
requires_grad=False,
)
for expert in range(layer.num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.quant_config.activation_scheme == "static":
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
)
)
w2_weight, w2_weight_scale, w2_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
)
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)
return
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: Fp8Config):
super().__init__(quant_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment