Commit f519ea41 authored by liuyunfei's avatar liuyunfei
Browse files

实现modelopt的w8a16量化算化

parent 49a30c70
...@@ -6,13 +6,14 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union ...@@ -6,13 +6,14 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch import torch
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.nn import functional as F, init
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,int8_w8a16_moe_quant_config,
nvfp4_moe_quant_config) nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe) is_valid_flashinfer_cutlass_fused_moe)
...@@ -40,7 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -40,7 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale) GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, requantize_with_max_scale) Fp8LinearOp, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter, from vllm.model_executor.parameter import (ModelWeightParameter,ChannelQuantScaleParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import next_power_of_2 from vllm.utils import next_power_of_2
...@@ -52,7 +53,7 @@ if TYPE_CHECKING: ...@@ -52,7 +53,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
QUANT_ALGOS = ["FP8", "NVFP4"] QUANT_ALGOS = ["FP8", "NVFP4", "W8A16"]
KV_CACHE_QUANT_ALGOS = ["FP8"] KV_CACHE_QUANT_ALGOS = ["FP8"]
...@@ -145,6 +146,8 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -145,6 +146,8 @@ class ModelOptFp8Config(QuantizationConfig):
quant_method = config.get("quant_algo", "") quant_method = config.get("quant_algo", "")
kv_cache_quant_method = config.get("kv_cache_quant_algo") kv_cache_quant_method = config.get("kv_cache_quant_algo")
exclude_modules = config.get("exclude_modules") exclude_modules = config.get("exclude_modules")
if not exclude_modules:
exclude_modules = config.get("ignore")
if quant_method not in QUANT_ALGOS: if quant_method not in QUANT_ALGOS:
raise ValueError( raise ValueError(
...@@ -152,7 +155,7 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -152,7 +155,7 @@ class ModelOptFp8Config(QuantizationConfig):
"quantizations in vLLM. Please check the " "quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's " "`hf_quant_config.json` file for your model's "
"quant configuration.") "quant configuration.")
is_checkpoint_fp8_serialized = ("FP8" in quant_method) is_checkpoint_fp8_serialized = ("W8A16" in quant_method)
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
exclude_modules) exclude_modules)
...@@ -234,7 +237,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -234,7 +237,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
weight_dtype = (torch.float8_e4m3fn weight_dtype = (torch.int8
if self.quant_config.is_checkpoint_fp8_serialized else if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype) params_dtype)
weight = ModelWeightParameter(data=torch.empty( weight = ModelWeightParameter(data=torch.empty(
...@@ -248,29 +251,29 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -248,29 +251,29 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE # WEIGHT SCALE
weight_scale = PerTensorScaleParameter(data=torch.empty( weight_scale = ChannelQuantScaleParameter(output_dim=0, data=torch.empty(
len(output_partition_sizes), dtype=torch.float32), output_size_per_partition, dtype=torch.float16),
weight_loader=weight_loader) weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min weight_scale[:] = torch.finfo(torch.float16).min
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE # # INPUT SCALE
scale = PerTensorScaleParameter(data=torch.empty( # scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32), # len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader) # weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min # scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale) # layer.register_parameter("input_scale", scale)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
weight = layer.weight # weight = layer.weight
max_w_scale = layer.weight_scale.max() # max_w_scale = layer.weight_scale.max()
if not (layer.weight_scale == layer.weight_scale[0]).all(): # if not (layer.weight_scale == layer.weight_scale[0]).all():
max_w_scale, weight = requantize_with_max_scale( # max_w_scale, weight = requantize_with_max_scale(
layer.weight, layer.weight_scale, layer.logical_widths) # layer.weight, layer.weight_scale, layer.logical_widths)
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(layer.weight.detach().clone(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.weight_scale = Parameter(layer.weight_scale.detach().clone(), requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(), # layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False) # requires_grad=False)
def apply( def apply(
self, self,
...@@ -278,11 +281,14 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -278,11 +281,14 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.fp8_linear.apply(input=x, # return self.fp8_linear.apply(input=x,
weight=layer.weight, # weight=layer.weight,
weight_scale=layer.weight_scale, # weight_scale=layer.weight_scale,
input_scale=layer.input_scale, # input_scale=layer.input_scale,
bias=bias) # bias=bias)
weight_scale = layer.weight_scale.unsqueeze(1)
weights = layer.weight.view(torch.int8).to(x.dtype)*weight_scale.to(x.dtype)
return F.linear(x, weights, bias)
class ModelOptFp8MoEMethod(FusedMoEMethodBase): class ModelOptFp8MoEMethod(FusedMoEMethodBase):
...@@ -348,7 +354,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -348,7 +354,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
): ):
# Use FP8 dtype if checkpoint is serialized # Use FP8 dtype if checkpoint is serialized
weight_dtype = (torch.float8_e4m3fn weight_dtype = (torch.int8
if self.quant_config.is_checkpoint_fp8_serialized else if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype) params_dtype)
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
...@@ -381,14 +387,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -381,14 +387,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# They will be combined to a single scale after weight loading. # They will be combined to a single scale after weight loading.
w13_weight_scale = PerTensorScaleParameter( w13_weight_scale = PerTensorScaleParameter(
data=torch.full( data=torch.full(
(num_experts, 2), (num_experts, 2, intermediate_size_per_partition),
1.0, 1.0,
dtype=torch.float32, dtype=torch.float16,
), ),
weight_loader=weight_loader, weight_loader=weight_loader,
) )
w2_weight_scale = PerTensorScaleParameter( w2_weight_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32), data=torch.full((num_experts, hidden_size), 1.0, dtype=torch.float16),
weight_loader=weight_loader, weight_loader=weight_loader,
) )
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
...@@ -399,16 +405,16 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -399,16 +405,16 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# INPUT SCALES - Per-tensor scaling for ModelOpt # INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale = PerTensorScaleParameter( # w13_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32), # data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader, # weight_loader=weight_loader,
) # )
w2_input_scale = PerTensorScaleParameter( # w2_input_scale = PerTensorScaleParameter(
data=torch.full((num_experts, ), 1.0, dtype=torch.float32), # data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader=weight_loader, # weight_loader=weight_loader,
) # )
layer.register_parameter("w13_input_scale", w13_input_scale) # layer.register_parameter("w13_input_scale", w13_input_scale)
layer.register_parameter("w2_input_scale", w2_input_scale) # layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Process FP8 MoE weights after loading from serialized checkpoint. """Process FP8 MoE weights after loading from serialized checkpoint.
...@@ -462,7 +468,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -462,7 +468,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight_scale = Parameter(max_w13_scales, layer.w13_weight_scale = Parameter(max_w13_scales,
requires_grad=False) requires_grad=False)
else: else:
layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data.flatten(start_dim=1),
requires_grad=False) requires_grad=False)
if hasattr(layer, if hasattr(layer,
...@@ -491,13 +497,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -491,13 +497,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
return None return None
return fp8_w8a8_moe_quant_config( return int8_w8a16_moe_quant_config(
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, w1_zp= None,
a2_scale=layer.w2_input_scale, w2_zp = None
per_act_token_quant=False,
) )
# return fp8_w8a8_moe_quant_config(
# w1_scale=layer.w13_weight_scale,
# w2_scale=layer.w2_weight_scale,
# a1_scale=layer.w13_input_scale,
# a2_scale=layer.w2_input_scale,
# per_act_token_quant=False,
# )
def apply( def apply(
self, self,
...@@ -521,6 +533,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -521,6 +533,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
use_fused_gate: Optional[bool] = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
...@@ -594,7 +608,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -594,7 +608,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts) fused_experts)
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
return fused_experts( return fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
......
...@@ -190,7 +190,7 @@ class RocmPlatform(Platform): ...@@ -190,7 +190,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao", "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao",
"moe_wna16", "slimquant_w4a8", "w8a8_int8", "awq_marlin", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin" "moe_wna16", "slimquant_w4a8", "w8a8_int8", "awq_marlin", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "modelopt"
] ]
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment