Commit f519ea41 authored by liuyunfei
Browse files

实现modelopt的w8a16量化算法

parent 49a30c70
......@@ -6,13 +6,14 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.nn import functional as F, init
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,int8_w8a16_moe_quant_config,
nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
......@@ -40,7 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
from vllm.model_executor.parameter import (ModelWeightParameter,ChannelQuantScaleParameter,
PerTensorScaleParameter)
from vllm.scalar_type import scalar_types
from vllm.utils import next_power_of_2
......@@ -52,7 +53,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
QUANT_ALGOS = ["FP8", "NVFP4"]
QUANT_ALGOS = ["FP8", "NVFP4", "W8A16"]
KV_CACHE_QUANT_ALGOS = ["FP8"]
......@@ -145,6 +146,8 @@ class ModelOptFp8Config(QuantizationConfig):
quant_method = config.get("quant_algo", "")
kv_cache_quant_method = config.get("kv_cache_quant_algo")
exclude_modules = config.get("exclude_modules")
if not exclude_modules:
exclude_modules = config.get("ignore")
if quant_method not in QUANT_ALGOS:
raise ValueError(
......@@ -152,7 +155,7 @@ class ModelOptFp8Config(QuantizationConfig):
"quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
is_checkpoint_fp8_serialized = ("W8A16" in quant_method)
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
exclude_modules)
......@@ -234,7 +237,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
weight_dtype = (torch.float8_e4m3fn
weight_dtype = (torch.int8
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = ModelWeightParameter(data=torch.empty(
......@@ -248,29 +251,29 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_scale = ChannelQuantScaleParameter(output_dim=0, data=torch.empty(
output_size_per_partition, dtype=torch.float16),
weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
weight_scale[:] = torch.finfo(torch.float16).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
# # INPUT SCALE
# scale = PerTensorScaleParameter(data=torch.empty(
# len(output_partition_sizes), dtype=torch.float32),
# weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
# scale[:] = torch.finfo(torch.float32).min
# layer.register_parameter("input_scale", scale)
def process_weights_after_loading(self, layer: Module) -> None:
weight = layer.weight
max_w_scale = layer.weight_scale.max()
if not (layer.weight_scale == layer.weight_scale[0]).all():
max_w_scale, weight = requantize_with_max_scale(
layer.weight, layer.weight_scale, layer.logical_widths)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
# weight = layer.weight
# max_w_scale = layer.weight_scale.max()
# if not (layer.weight_scale == layer.weight_scale[0]).all():
# max_w_scale, weight = requantize_with_max_scale(
# layer.weight, layer.weight_scale, layer.logical_widths)
layer.weight = Parameter(layer.weight.detach().clone(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.detach().clone(), requires_grad=False)
# layer.input_scale = Parameter(layer.input_scale.max(),
# requires_grad=False)
def apply(
self,
......@@ -278,11 +281,14 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias)
# return self.fp8_linear.apply(input=x,
# weight=layer.weight,
# weight_scale=layer.weight_scale,
# input_scale=layer.input_scale,
# bias=bias)
weight_scale = layer.weight_scale.unsqueeze(1)
weights = layer.weight.view(torch.int8).to(x.dtype)*weight_scale.to(x.dtype)
return F.linear(x, weights, bias)
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
......@@ -348,7 +354,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
):
# Use FP8 dtype if checkpoint is serialized
weight_dtype = (torch.float8_e4m3fn
weight_dtype = (torch.int8
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight_loader = extra_weight_attrs.get("weight_loader")
......@@ -381,14 +387,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# They will be combined to a single scale after weight loading.
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts, 2),
(num_experts, 2, intermediate_size_per_partition),
1.0,
dtype=torch.float32,
dtype=torch.float16,
),
weight_loader=weight_loader,
)
w2_weight_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
data=torch.full((num_experts, hidden_size), 1.0, dtype=torch.float16),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
......@@ -399,16 +405,16 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
w2_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
layer.register_parameter("w2_input_scale", w2_input_scale)
# w13_input_scale = PerTensorScaleParameter(
# data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
# weight_loader=weight_loader,
# )
# w2_input_scale = PerTensorScaleParameter(
# data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
# weight_loader=weight_loader,
# )
# layer.register_parameter("w13_input_scale", w13_input_scale)
# layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Process FP8 MoE weights after loading from serialized checkpoint.
......@@ -462,7 +468,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight_scale = Parameter(max_w13_scales,
requires_grad=False)
else:
layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data.flatten(start_dim=1),
requires_grad=False)
if hasattr(layer,
......@@ -491,13 +497,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
return None
return fp8_w8a8_moe_quant_config(
return int8_w8a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=False,
w1_zp= None,
w2_zp = None
)
# return fp8_w8a8_moe_quant_config(
# w1_scale=layer.w13_weight_scale,
# w2_scale=layer.w2_weight_scale,
# a1_scale=layer.w13_input_scale,
# a2_scale=layer.w2_input_scale,
# per_act_token_quant=False,
# )
def apply(
self,
......@@ -521,6 +533,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
use_fused_gate: Optional[bool] = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
raise NotImplementedError(
......@@ -594,7 +608,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
assert self.moe_quant_config is not None
return fused_experts(
x,
layer.w13_weight,
......
......@@ -190,7 +190,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao",
"moe_wna16", "slimquant_w4a8", "w8a8_int8", "awq_marlin", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin"
"moe_wna16", "slimquant_w4a8", "w8a8_int8", "awq_marlin", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "modelopt"
]
@classmethod
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment