"vllm/executor/mp_distributed_executor.py" did not exist on "eb6d3c264d0cd8e44dec16bca7947fbe96415ce9"
Commit 53250530 authored by gaoqiong's avatar gaoqiong
Browse files

Update w8a8_int8.py

parent 40b94473
from typing import Any, Callable, Dict, List, Optional

import torch
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_group_quant_int8,
    per_token_quant_int8)
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
import os

# Module-level handle on the Triton GEMM config cache. The linear method's
# `apply` reads tuned kernel configs from its `triton_json_dict`.
# NOTE(review): the linear method also builds its own `W8a8GetCacheJSON()`
# and fills that one at weight-loading time; this only works if the class
# hands back shared state (presumably a singleton) - confirm in vllm.utils.
W8A8_TRITONJSON = W8a8GetCacheJSON()
def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference scaled matrix multiply built from plain PyTorch ops.

    Computes ``(scale_a * scale_b.T) * (a @ b)`` with the product taken in
    float32, casts to ``out_dtype``, then adds ``bias`` when one is given.

    Args:
        a: Left operand; must be 2-D because it is passed to ``torch.mm``.
        b: Right operand; must be 2-D for the same reason.
        scale_a: Scale for ``a``. It has to broadcast against the product
            together with ``scale_b.T`` (presumably one value per row of
            ``a`` - confirm against callers).
        scale_b: Scale for ``b``; it is transposed before use (presumably
            one value per column of ``b`` - confirm against callers).
        out_dtype: dtype of the returned tensor.
        bias: Optional tensor added to the scaled product.

    Returns:
        The scaled (and optionally biased) product in ``out_dtype``.
    """
    scales = scale_a * scale_b.T
    # Accumulate in float32 so the integer inputs do not overflow or round.
    gemmout = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    output = (scales * gemmout).to(out_dtype)
    if bias is not None:
        output = output + bias
    # The bias add may promote the dtype, so cast once more on the way out.
    return output.to(out_dtype)
class W8A8Int8Config(QuantizationConfig):
    """Config class for W8A8 Int8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        # Let the base class set up whatever state it keeps per config
        # instead of silently skipping its initializer.
        super().__init__()

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this scheme accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (75)."""
        return 75

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "w8a8_int8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config files to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
        """Build a config object; nothing is read from ``config``."""
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Returns a linear method for ``LinearBase`` layers, a MoE method for
        ``FusedMoE`` layers, and ``None`` for every other layer type.
        ``prefix`` is not used.
        """
        if isinstance(layer, LinearBase):
            return W8A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W8A8Int8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of activations that need scaling (none)."""
        return []
class W8A8Int8LinearMethod(LinearMethodBase):
    """Linear-layer method for int8 weights with per-token int8 activations.

    The GEMM backend is chosen by the ``W8A8_SUPPORT_METHODS`` environment
    variable, read once in ``__init__``: ``1`` (the default) uses
    ``ops.triton_scaled_mm``, ``2`` uses ``ops.cutlass_scaled_mm`` and any
    other value uses ``ops.rocblas_scaled_mm``.
    """

    def __init__(self, quantization_config: W8A8Int8Config):
        self.quantization_config = quantization_config
        # Cache object holding the tuned Triton configs and seen shapes.
        self.tritonsingleton = W8a8GetCacheJSON()
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    @staticmethod
    def _bucket_m(m: int) -> int:
        """Map a row count ``m`` to the M value used as a config-key prefix."""
        if m <= 16:
            return m
        if m <= 64:
            # Round up to the nearest multiple of 4.
            return (m + 3) & -4
        if m <= 160:
            # Round up to the nearest multiple of 8.
            return (m + 7) & -8
        if m < 200:
            return 160
        if m < 480:
            return 256
        if m < 960:
            return 512
        if m < 2048:
            return 1024
        if m < 4096:
            return 2048
        if m < 6000:
            return 4096
        return 8192

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare the loaded weight and scale for the selected backend.

        For the Triton strategy, the tuned configs for this layer's
        ``(n, k)`` shape are loaded once and handed to
        ``ops.triton_int8_gemm_helper`` for every M found in the cache.
        For the other strategies the weight storage is re-laid-out first.
        In all cases ``layer.weight`` ends up as the transposed view and
        both parameters are re-wrapped as plain ``Parameter`` objects.
        """
        n = layer.weight.shape[0]
        k = layer.weight.shape[1]
        if self.w8a8_strategy == 1:
            # An ordered tuple keeps (n, k) and (k, n) apart. An unordered
            # set would treat them as the same shape and skip loading the
            # configs for whichever one is seen second.
            shape_key = (n, k)
            if shape_key not in self.tritonsingleton.weight_shapes:
                self.tritonsingleton.weight_shapes.append(shape_key)
                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
                configs_dict = self.tritonsingleton.get_triton_cache(
                    json_file, n, k)
                if configs_dict:
                    self.tritonsingleton.triton_json_dict.update(configs_dict)
                    for key, value in configs_dict.items():
                        # Keys look like "<m>_<n>_<k>"; the prefix is M.
                        m = int(key.split('_')[0])
                        ops.triton_int8_gemm_helper(
                            m=m,
                            n=n,
                            k=k,
                            per_token_act_quant=True,
                            per_out_channel_weight_quant=True,
                            use_bias=False,
                            device=layer.weight.device,
                            best_config=value)
        else:
            # Store the transposed data contiguously while keeping the
            # (n, -1) shape, so the `.t()` below yields the layout these
            # backends expect.
            weight_data = layer.weight.data
            layer.weight.data = weight_data.T.contiguous().reshape(n, -1)
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data,
                                       requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the int8 ``weight`` and float32 ``weight_scale`` on ``layer``.

        ``weight`` has shape ``(sum(output_partition_sizes),
        input_size_per_partition)`` and ``weight_scale`` has one value per
        output row. ``input_size``, ``output_size`` and ``params_dtype``
        are not used.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes
        out_features = sum(output_partition_sizes)

        weight = ModelWeightParameter(
            data=torch.empty(out_features,
                             input_size_per_partition,
                             dtype=torch.int8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((out_features, 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        """Quantize ``x`` per token and run the scaled int8 GEMM.

        With the Triton strategy a tuned config is looked up by the
        bucketed row count and the layer's ``(n, k)``. When no config is
        cached for that key, ``best_config=None`` is passed on.
        """
        x_q, x_scale = per_token_quant_int8(x)
        if self.w8a8_strategy == 1:
            m = x_q.shape[0]
            k = x_q.shape[1]
            n = layer.weight.shape[1]
            tuned_configs = W8A8_TRITONJSON.triton_json_dict
            best_config = None
            # A "1_<n>_<k>" entry signals that configs were loaded for this
            # shape. `.get` avoids a KeyError if one bucket is missing.
            if f"1_{n}_{k}" in tuned_configs:
                best_config = tuned_configs.get(
                    f"{self._bucket_m(m)}_{n}_{k}")
            return ops.triton_scaled_mm(x_q,
                                        layer.weight,
                                        scale_a=x_scale,
                                        scale_b=layer.weight_scale,
                                        out_dtype=x.dtype,
                                        bias=bias,
                                        best_config=best_config)
        elif self.w8a8_strategy == 2:
            return ops.cutlass_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
        else:
            return ops.rocblas_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
class W8A8Int8MoEMethod:
    """MoE method for INT8.

    Supports loading INT8 checkpoints with static weight scale and
    dynamic/static activation scale.
    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __new__(cls, *args, **kwargs):
        # Build, at construction time, a class with the same name and
        # attributes that derives from FusedMoEMethodBase, and return an
        # instance of it. Nothing in this class sets `_initialized`, so
        # this branch is taken on every construction.
        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items()
                       if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            # The returned object is not an instance of `cls`, so Python
            # will not call __init__ for us; do it explicitly.
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the int8 expert weights and float32 scales on ``layer``.

        Creates ``w13_weight`` ``(E, 2 * intermediate_size, hidden_size)``,
        ``w2_weight`` ``(E, hidden_size, intermediate_size)`` and one scale
        per output channel for each. The input scales are registered as
        ``None``. ``params_dtype`` is not used.
        """
        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        2 * intermediate_size,
                        hidden_size,
                        dtype=torch.int8),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size,
                        dtype=torch.int8),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT SCALES
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts,
                       2 * intermediate_size,
                       1,
                       dtype=torch.float32),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Only the scales carry the channel-wise quant_method attribute;
        # the weights above were tagged before this update on purpose.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT SCALES: none are stored.
        w13_input_scale = None
        layer.register_parameter("w13_input_scale", w13_input_scale)
        w2_input_scale = None
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded weights and scales as frozen ``Parameter`` objects."""
        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
                                           requires_grad=False)
        layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
                                          requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        use_fused_gate: Optional[bool] = False,
    ) -> torch.Tensor:
        """Route tokens to experts and run the fused int8 expert kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        resulting weights and ids are passed to ``fused_experts`` with
        ``use_int8_w8a8=True``, ``per_channel_quant=True`` and
        ``inplace=True``.
        """
        # Imported here rather than at module level, as in the original.
        from vllm.model_executor.layers.fused_moe import fused_experts

        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            use_fused_gate=use_fused_gate,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_int8_w8a8=True,
            per_channel_quant=True,
            activation=activation,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=use_nn_moe,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment