Commit 47bd229c authored by yangql
Browse files

适配deepseekv3\v2 moe awq的推理支持

parent 4a734b9d
{
"7168_2048": {
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"3": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"5": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"6": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"7": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"9": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"10": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"11": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"12": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"13": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"14": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"15": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 4,
"SPLIT_K": 8,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"SPLIT_K": 4,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 0
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"SPLIT_K": 1,
"num_stages": 0,
"num_warps": 4,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
"num_stages": 0,
"num_warps": 4,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"SPLIT_K": 1,
"num_stages": 1,
"num_warps": 8,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"SPLIT_K": 1,
"num_stages": 1,
"num_warps": 8,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
"num_stages": 1,
"num_warps": 8,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
"num_stages": 1,
"num_warps": 4,
"num_ldmatrixes": 1
}
}
}
\ No newline at end of file
...@@ -277,6 +277,10 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -277,6 +277,10 @@ class MoeWNA16Method(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = None,
start_expert: Optional[int] = None,
end_expert: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -307,7 +311,9 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -307,7 +311,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
w2_scale=layer.w2_scales, w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None, w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None, w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size]) block_shape=[0, layer.group_size],
use_nn_moe=False,
)
@staticmethod @staticmethod
def get_weight_loader(layer, weight_loader): def get_weight_loader(layer, weight_loader):
......
...@@ -27,7 +27,7 @@ from torch import nn ...@@ -27,7 +27,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
import os import os
import re import re
import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
...@@ -517,9 +517,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -517,9 +517,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1) weight.data=weight.data.reshape(ori_shape[1], -1)
else:
if self.quant_method == "awq":
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
os.environ['LLAMA_NN'] = '0'
if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
lay_key_words = [ lay_key_words = [
"self_attn.W_pack.qweight", "self_attn.W_pack.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
......
This diff is collapsed.
...@@ -29,7 +29,7 @@ from torch import nn ...@@ -29,7 +29,7 @@ from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
import os import os
import re import re
import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
...@@ -505,9 +505,11 @@ class LlamaModel(nn.Module): ...@@ -505,9 +505,11 @@ class LlamaModel(nn.Module):
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1) weight.data=weight.data.reshape(ori_shape[1], -1)
else:
if self.quant_method == "awq":
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
os.environ['LLAMA_NN'] = '0'
if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.qweight", "self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
...@@ -551,7 +553,6 @@ class LlamaModel(nn.Module): ...@@ -551,7 +553,6 @@ class LlamaModel(nn.Module):
#当为triton支持推理的时候不能进行处理 #当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors": if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0'
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
......
...@@ -30,7 +30,7 @@ from torch import nn ...@@ -30,7 +30,7 @@ from torch import nn
from transformers import Qwen2Config from transformers import Qwen2Config
import os import os
import re import re
import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
...@@ -483,9 +483,11 @@ class Qwen2Model(nn.Module): ...@@ -483,9 +483,11 @@ class Qwen2Model(nn.Module):
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
else:
if self.quant_method == "awq":
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
os.environ['LLAMA_NN'] = '0'
if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.qweight", "self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
...@@ -528,7 +530,6 @@ class Qwen2Model(nn.Module): ...@@ -528,7 +530,6 @@ class Qwen2Model(nn.Module):
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors": if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0'
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
......
...@@ -72,7 +72,7 @@ class RocmPlatform(Platform): ...@@ -72,7 +72,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf", "quark" "fbgemm_fp8", "gguf", "quark", "moe_wna16"
] ]
@classmethod @classmethod
...@@ -157,8 +157,8 @@ class RocmPlatform(Platform): ...@@ -157,8 +157,8 @@ class RocmPlatform(Platform):
if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ: if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
logger.warning( logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.") " is not set, disabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True envs.VLLM_USE_TRITON_AWQ = False
@classmethod @classmethod
def get_punica_wrapper(cls) -> str: def get_punica_wrapper(cls) -> str:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment