Commit b3312eec authored by yangql
Browse files

分离fuse moe awq算子到lmslim上

parent 483acdc4
...@@ -17,7 +17,10 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -17,7 +17,10 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
from lmslim.layers.fused_moe.fuse_moe_int4 import fused_experts_w4a16
os.environ['W4A16_MOE_CUDA'] = os.environ.get('W4A16_MOE_CUDA', '0') os.environ['W4A16_MOE_CUDA'] = os.environ.get('W4A16_MOE_CUDA', '0')
os.environ['W4A16_MOE_LMSLIM'] = os.environ.get('W4A16_MOE_LMSLIM', '1')
if os.environ['W4A16_MOE_CUDA'] == '1': if os.environ['W4A16_MOE_CUDA'] == '1':
from vllm.model_executor.layers.quantization.utils.fused_moe_cuda import fused_experts_cuda from vllm.model_executor.layers.quantization.utils.fused_moe_cuda import fused_experts_cuda
...@@ -180,7 +183,11 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -180,7 +183,11 @@ class MoeWNA16Method(FusedMoEMethodBase):
def __init__(self, quant_config: MoeWNA16Config): def __init__(self, quant_config: MoeWNA16Config):
self.quant_config = quant_config self.quant_config = quant_config
self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1' self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1'
self.use_w4a16_cuda = os.environ['W4A16_MOE_CUDA'] == '1' self.use_w4a16_cuda = 0
self.use_moe_lmslim = 0
if self.use_w4a16_moe_sz:
self.use_w4a16_cuda = os.environ['W4A16_MOE_CUDA'] == '1'
self.use_moe_lmslim = os.environ['W4A16_MOE_LMSLIM'] == "1"
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
...@@ -352,6 +359,24 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -352,6 +359,24 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp has_zp = self.quant_config.has_zp
if self.use_moe_lmslim:
return fused_experts_w4a16(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int4_w4a16=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
block_shape=[0, layer.group_size])
if self.use_w4a16_cuda: if self.use_w4a16_cuda:
m = topk_ids.shape[0] m = topk_ids.shape[0]
if m <= 512: if m <= 512:
...@@ -380,6 +405,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -380,6 +405,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation,
use_int4_w4a16=weight_bits == 4, use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8, use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
......
...@@ -17,8 +17,8 @@ from vllm.platforms import current_platform ...@@ -17,8 +17,8 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size) moe_align_block_size)
from grouped_gemm import moe_gemm_w4a16 from grouped_gemm_int4 import moe_gemm_w4a16
from grouped_gemm.ops import permute as permute_topK, unpermute as unpermute_topK from grouped_gemm_int4.ops import permute as permute_topK, unpermute as unpermute_topK
import torch.nn.functional as F import torch.nn.functional as F
logger = init_logger(__name__) logger = init_logger(__name__)
device_name = current_platform.get_device_name() device_name = current_platform.get_device_name()
...@@ -315,7 +315,7 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor, ...@@ -315,7 +315,7 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor,
num_tokens_post_padded, # 实际专家数 num_tokens_post_padded, # 实际专家数
expert_ids, # expert_id_vec expert_ids, # expert_id_vec
w1_scale, # scale_zero w1_scale, # scale_zero
64, # group_size block_shape[1], # group_size
topk=topk, # topk topk=topk, # topk
mode=mode_1) # mode=gemm1_mode mode=mode_1) # mode=gemm1_mode
...@@ -329,10 +329,12 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor, ...@@ -329,10 +329,12 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor,
expert_ids, # expert_id_vec expert_ids, # expert_id_vec
w2_scale, # scale_zero w2_scale, # scale_zero
topk_weights, # topk_weights topk_weights, # topk_weights
64, # group_size block_shape[1], # group_size
topk=topk, # topk topk=topk, # topk
mode=mode_2) # mode=gemm2_mode mode=mode_2) # mode=gemm2_mode
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states)
return out_hidden_states return out_hidden_states
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment