Commit 7b2122d9 authored by jujl1
Browse files

feat: w4a8

parent 76ec56bd
...@@ -1760,12 +1760,6 @@ def fused_experts_impl( ...@@ -1760,12 +1760,6 @@ def fused_experts_impl(
cache13=cache13, cache13=cache13,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_int4_w4a8=True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=w1_scale, w1_scale=w1_scale,
......
...@@ -20,7 +20,7 @@ from lmslim.layers.gemm.int8_utils import ( ...@@ -20,7 +20,7 @@ from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_group_quant_int8,
per_token_quant_int8) per_token_quant_int8)
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm.model_executor.layers.quantization.utils.w8a8_utils import apply_int8_linear
import os import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
...@@ -94,7 +94,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -94,7 +94,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config): def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0] n=layer.weight.shape[0]
...@@ -112,7 +112,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -112,7 +112,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
for key, value in configs_dict.items(): for key, value in configs_dict.items():
m=int(key.split('_')[0]) m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value) ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
else: elif self.w8a8_strategy == 3:
layer.weight.data = layer.weight.data.T
else:
weight_data=layer.weight.data weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1) _weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight layer.weight.data=_weight
...@@ -159,68 +161,14 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -159,68 +161,14 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
input_quant_args: Optional[list[torch.Tensor]] = None, input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None silu_quant_args: Optional[list[torch.Tensor]] = None
): ):
x_q, x_scale = per_token_quant_int8(x) return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
bias=bias,
w8a8_strategy=self.w8a8_strategy,
input_quant_args=input_quant_args,
silu_quant_args=silu_quant_args)
if self.w8a8_strategy==1:
m=x_q.shape[0]
k=x_q.shape[1]
n=layer.weight.shape[1]
if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16:
m_=m
elif m<=64:
m_= (m + 3) & -4 #取值到最近的4的倍数
elif m<=160:
m_ = (m // 8) * 8
elif m<200: #256
m_=160
elif m<480: #512
m_=256
elif m<960: #1024
m_=512
elif m<2048:
m_=1024
elif m<4096:
m_=2048
elif m<6000:
m_=4096
else:
m_=8192
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
else:
best_config=None
#if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return ops.triton_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,best_config=best_config)
elif self.w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
else:
return ops.rocblas_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
class SlimQuantW4A8Int8MoEMethod: class SlimQuantW4A8Int8MoEMethod:
...@@ -256,8 +204,7 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -256,8 +204,7 @@ class SlimQuantW4A8Int8MoEMethod:
self.quant_config = quant_config self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None self.moe_mk: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module)-> Optional[FusedMoEQuantConfig]: self, layer: torch.nn.Module)-> Optional[FusedMoEQuantConfig]:
...@@ -270,9 +217,8 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -270,9 +217,8 @@ class SlimQuantW4A8Int8MoEMethod:
per_act_token_quant=True, per_act_token_quant=True,
per_out_ch_quant=False, per_out_ch_quant=False,
block_shape=None, block_shape=None,
weight_dtype='int4'
) )
self.moe_quant_config._w1.dtype="int4"
self.moe_quant_config._w1.dtype="int4"
return self.moe_quant_config return self.moe_quant_config
def create_weights( def create_weights(
...@@ -354,49 +300,15 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -354,49 +300,15 @@ class SlimQuantW4A8Int8MoEMethod:
) )
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
top_k: int, topk_ids: torch.Tensor,
renormalize: bool, use_nn_moe: bool | None = False,
use_grouped_topk: bool = False, use_fused_gate: bool | None = False,
topk_group: Optional[int] = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
return fused_experts( return fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -404,10 +316,10 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -404,10 +316,10 @@ class SlimQuantW4A8Int8MoEMethod:
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
) )
...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig ...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
try: try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
except Exception: except Exception:
...@@ -147,8 +147,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -147,8 +147,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self.moe = moe self.moe = moe
self.quant_config = quant_config self.quant_config = quant_config
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None self.moe_mk: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module) : self, layer: torch.nn.Module) :
...@@ -218,46 +217,15 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -218,46 +217,15 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False) layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
top_k: int, topk_ids: torch.Tensor,
renormalize: bool, use_nn_moe: bool | None = False,
use_grouped_topk: bool = False, use_fused_gate: bool | None = False,
topk_group: Optional[int] = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# Expert selection
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin( return fused_experts_impl_w4a8_marlin(
x, x,
...@@ -268,15 +236,13 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -268,15 +236,13 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
workspace=workspace, workspace=workspace,
global_reduce_buffer=global_reduce_buffer, global_reduce_buffer=global_reduce_buffer,
inplace=True, inplace=True,
use_int4_w4a8=True, activation=layer.activation,
per_channel_quant=True, expert_map=layer.expert_map,
activation=activation, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, moe_cache_getter=get_moe_cache if envs.VLLM_USE_GLOBAL_CACHE13 else None,
global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale), w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale), w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale
use_nn_moe=use_nn_moe,
) )
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment