Commit 7b2122d9 authored by jujl1's avatar jujl1
Browse files

feat: w4a8

parent 76ec56bd
......@@ -1760,12 +1760,6 @@ def fused_experts_impl(
cache13=cache13,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_int4_w4a8=True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
......
......@@ -20,7 +20,7 @@ from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
from vllm.utils import W8a8GetCacheJSON
from vllm.model_executor.layers.quantization.utils.w8a8_utils import apply_int8_linear
import os
from vllm import _custom_ops as ops
from vllm import envs
......@@ -94,7 +94,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
......@@ -112,6 +112,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
elif self.w8a8_strategy == 3:
layer.weight.data = layer.weight.data.T
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
......@@ -159,68 +161,14 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
):
x_q, x_scale = per_token_quant_int8(x)
if self.w8a8_strategy==1:
m=x_q.shape[0]
k=x_q.shape[1]
n=layer.weight.shape[1]
if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16:
m_=m
elif m<=64:
m_= (m + 3) & -4 #取值到最近的4的倍数
elif m<=160:
m_ = (m // 8) * 8
elif m<200: #256
m_=160
elif m<480: #512
m_=256
elif m<960: #1024
m_=512
elif m<2048:
m_=1024
elif m<4096:
m_=2048
elif m<6000:
m_=4096
else:
m_=8192
return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
bias=bias,
w8a8_strategy=self.w8a8_strategy,
input_quant_args=input_quant_args,
silu_quant_args=silu_quant_args)
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
else:
best_config=None
#if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return ops.triton_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,best_config=best_config)
elif self.w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
else:
return ops.rocblas_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
class SlimQuantW4A8Int8MoEMethod:
......@@ -256,8 +204,7 @@ class SlimQuantW4A8Int8MoEMethod:
self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
self.moe_mk: Optional[FusedMoEModularKernel] = None
def get_fused_moe_quant_config(
self, layer: torch.nn.Module)-> Optional[FusedMoEQuantConfig]:
......@@ -270,9 +217,8 @@ class SlimQuantW4A8Int8MoEMethod:
per_act_token_quant=True,
per_out_ch_quant=False,
block_shape=None,
weight_dtype='int4'
)
self.moe_quant_config._w1.dtype="int4"
self.moe_quant_config._w1.dtype="int4"
return self.moe_quant_config
def create_weights(
......@@ -355,48 +301,14 @@ class SlimQuantW4A8Int8MoEMethod:
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor:
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
return fused_experts(
x,
layer.w13_weight,
......@@ -404,10 +316,10 @@ class SlimQuantW4A8Int8MoEMethod:
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=layer.activation,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
quant_config=self.moe_quant_config,
use_nn_moe=use_nn_moe,
)
......@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
except Exception:
......@@ -147,8 +147,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self.moe = moe
self.quant_config = quant_config
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
self.moe_mk: Optional[FusedMoEModularKernel] = None
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) :
......@@ -219,45 +218,14 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# Expert selection
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
......@@ -268,15 +236,13 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=layer.activation,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
moe_cache_getter=get_moe_cache if envs.VLLM_USE_GLOBAL_CACHE13 else None,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
a2_scale=layer.w2_input_scale
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment