Commit 313d629e authored by gaoqiong's avatar gaoqiong
Browse files

增加per-channel 量化策略支持

parent 779033b7
...@@ -552,6 +552,9 @@ class FusedMoE(torch.nn.Module): ...@@ -552,6 +552,9 @@ class FusedMoE(torch.nn.Module):
if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod")): if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod")):
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition
if (self.quant_method.__class__.__name__ in ("W8A8Int8MoEMethod")):
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
......
...@@ -33,7 +33,8 @@ QUANTIZATION_METHODS: List[str] = [ ...@@ -33,7 +33,8 @@ QUANTIZATION_METHODS: List[str] = [
"quark", "quark",
"moe_wna16", "moe_wna16",
"torchao", "torchao",
"blockwise_int8" "blockwise_int8",
"w8a8_int8"
] ]
# The customized quantization methods which will be added to this dict. # The customized quantization methods which will be added to this dict.
...@@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig from .tpu_int8 import Int8TpuConfig
from .blockwise_int8 import BlockInt8Config from .blockwise_int8 import BlockInt8Config
from .w8a8_int8 import W8A8Int8Config
method_to_config: Dict[str, Type[QuantizationConfig]] = { method_to_config: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
...@@ -138,6 +140,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -138,6 +140,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"moe_wna16": MoeWNA16Config, "moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig, "torchao": TorchAOConfig,
"blockwise_int8": BlockInt8Config, "blockwise_int8": BlockInt8Config,
"w8a8_int8":W8A8Int8Config,
} }
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
...@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple ...@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from triton.language.extra import libdevice
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -109,7 +110,7 @@ def _per_token_quant_int8( ...@@ -109,7 +110,7 @@ def _per_token_quant_int8(
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127 scale_x = absmax / 127
x_q = x * (127 / absmax) x_q = x * (127 / absmax)
x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8) x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x) tl.store(scale_ptr + row_id, scale_x)
......
from typing import Any, Callable, Dict, List, Optional
import torch
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
from vllm import _custom_ops as ops
def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference scaled matrix multiply built from plain torch ops.

    Computes ``(scale_a * scale_b.T) * (a @ b)``, with both operands of the
    matrix product converted to float32 first, casts the result to
    ``out_dtype`` and then adds ``bias`` when one is supplied.

    Args:
        a: Left operand of ``torch.mm``.
        b: Right operand of ``torch.mm``.
        scale_a: Scale multiplied (with broadcasting) against ``scale_b.T``.
        scale_b: Scale that is transposed before being combined with
            ``scale_a``.
        out_dtype: Dtype of the returned tensor.
        bias: Optional tensor added after the cast to ``out_dtype``.

    Returns:
        The scaled product (plus bias, if any) in ``out_dtype``.
    """
    product = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    combined_scale = scale_a * scale_b.T
    # The cast happens before the bias is added, so the bias addition runs
    # in whatever dtype ``out_dtype`` and ``bias`` promote to.
    result = (combined_scale * product).to(out_dtype)
    if bias is None:
        return result.to(out_dtype)
    return (result + bias).to(out_dtype)
class W8A8Int8Config(QuantizationConfig):
    """Config class for W8A8 Int8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric

    The config has no options of its own: ``from_config`` ignores the
    dictionary it receives and no config filenames are declared.
    """

    def __init__(self):
        # Run the base-class initializer so that any per-instance state it
        # sets up exists on this object.
        # NOTE(review): the base class is defined outside this file.
        super().__init__()

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this scheme."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): 75 presumably encodes compute capability 7.5 --
        confirm against the base-class contract.
        """
        return 75

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "w8a8_int8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config filenames to look for (none for this scheme)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
        """Build the config; ``config`` is accepted but not read."""
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Args:
            layer: Module being configured.
            prefix: Name prefix of the layer; not used here.

        Returns:
            ``W8A8Int8LinearMethod`` for ``LinearBase`` layers,
            ``W8A8Int8MoEMethod`` for ``FusedMoE`` layers, otherwise ``None``.
        """
        if isinstance(layer, LinearBase):
            return W8A8Int8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            return W8A8Int8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for this scheme)."""
        return []
class W8A8Int8LinearMethod(LinearMethodBase):
    """Linear-layer method for the ``w8a8_int8`` scheme.

    Registers an int8 weight and a float32 scale with one entry per output
    row; at run time the activation is quantized with
    ``per_token_quant_int8`` and multiplied through ``ops.triton_scaled_mm``.
    """

    def __init__(self, quantization_config: W8A8Int8Config):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded tensors as plain non-trainable Parameters.

        ``layer.weight`` is replaced by its transpose; ``layer.weight_scale``
        keeps its data unchanged.
        """
        transposed_weight = layer.weight.t()
        layer.weight = Parameter(transposed_weight, requires_grad=False)
        scale_data = layer.weight_scale.data
        layer.weight_scale = Parameter(scale_data, requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the ``weight`` and ``weight_scale`` parameters on ``layer``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Size of the weight's second dimension.
            output_partition_sizes: Sizes whose sum is the weight's first
                dimension; also stored as ``self.logical_widths``.
            input_size: Not used here.
            output_size: Not used here.
            params_dtype: Not used here; the weight is always int8 and the
                scale always float32.
            **extra_weight_attrs: Only the ``"weight_loader"`` entry is read.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes
        total_output_rows = sum(output_partition_sizes)

        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    total_output_rows,
                    input_size_per_partition,
                    dtype=torch.int8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # One scale per output row, kept as a column of shape (rows, 1).
        layer.register_parameter(
            "weight_scale",
            ChannelQuantScaleParameter(
                data=torch.empty((total_output_rows, 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        """Quantize ``x`` and run the scaled int8 matmul against the layer.

        Args:
            layer: Module holding ``weight`` and ``weight_scale``.
            x: Activation tensor; its dtype is used as the output dtype.
            bias: Optional bias forwarded to the matmul.

        Returns:
            The result of ``ops.triton_scaled_mm``.
        """
        quantized_x, activation_scale = per_token_quant_int8(x)
        return ops.triton_scaled_mm(
            quantized_x,
            layer.weight,
            scale_a=activation_scale,
            scale_b=layer.weight_scale,
            out_dtype=x.dtype,
            bias=bias,
            best_config=None,
        )
class W8A8Int8MoEMethod:
    """MoE method for INT8 with per-channel weight scales.

    Registers int8 expert weights (``w13_weight``, ``w2_weight``) together
    with float32 weight scales that hold one value per output row of each
    expert.  The input scales ``w13_input_scale`` / ``w2_input_scale`` are
    registered as ``None``.

    Instances are not created from this class directly: ``__new__`` builds a
    same-named class deriving from ``FusedMoEMethodBase`` and returns an
    instance of that class instead.

    Args:
        quant_config: The quantization config.
    """

    def __new__(cls, *args, **kwargs):
        """Return an instance of a rebuilt class based on FusedMoEMethodBase.

        Nothing in this class sets ``_initialized``, so unless it is set from
        outside the first branch is the one that runs.
        """
        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            # Re-create the class under the same name with
            # FusedMoEMethodBase as its only base, copying every attribute
            # of this class except the ``__dict__`` descriptor.
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            # ``obj`` is not an instance of ``cls``, so ``__init__`` is not
            # invoked automatically and has to be called explicitly.
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weights, weight scales and input scales.

        Args:
            layer: Module that receives the parameters.
            num_experts: Leading dimension of every registered tensor.
            hidden_size: Last dimension of ``w13_weight`` and middle
                dimension of ``w2_weight``.
            intermediate_size: ``w13_weight`` has ``2 * intermediate_size``
                rows per expert; ``w2_weight`` has ``intermediate_size``
                columns per expert.
            params_dtype: Not used here; weights are int8 and scales float32.
            **extra_weight_attrs: Attributes attached to every registered
                weight and scale via ``set_weight_attrs``.
        """
        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT SCALES: one float32 value per output row, initialized to 1.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # The "quant_method" attribute is added only now, so it is attached
        # to the scales below but not to the weights registered above.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT SCALES: registered as None.
        w13_input_scale = None
        layer.register_parameter("w13_input_scale", w13_input_scale)
        w2_input_scale = None
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded weights and scales as non-trainable Parameters."""
        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(
            layer.w13_weight_scale.data, requires_grad=False
        )
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        use_nn_moe: Optional[bool] = False,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run the fused int8 expert computation.

        The routing arguments are forwarded to ``FusedMoE.select_experts``;
        the resulting top-k weights and ids, the layer's weights and scales,
        and the remaining arguments are forwarded to ``fused_experts`` with
        ``use_int8_w8a8=True``, ``per_channel_quant=True`` and
        ``inplace=True``.

        Returns:
            The tensor returned by ``fused_experts``.
        """
        # Imported here rather than at module level.
        from vllm.model_executor.layers.fused_moe import fused_experts

        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_int8_w8a8=True,
            per_channel_quant=True,
            activation=activation,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=use_nn_moe,
        )
...@@ -130,7 +130,7 @@ class RocmPlatform(Platform): ...@@ -130,7 +130,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf", "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8" "fbgemm_fp8", "gguf", "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8"
] ]
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment