Commit 82aee745 authored by zhuwenwen's avatar zhuwenwen
Browse files

optimize qwen2-moe layout

parent 01df9361
......@@ -23,6 +23,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| MiniCPM3ForCausalLM | MiniCPM3 | Yes | Yes |
| MixtralForCausalLM | Mixtral-8x7B,Mixtral-8x7B-Instruct | Yes | Yes |
| TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | Yes |
| Qwen2MoeForCausalLM | Qwen2-57B-A14B,Qwen2-57B-A14B-Instruct | Yes | Yes |
| LlavaForConditionalGeneration | LLaMA,LLaMA-2,LLaMA-3 | Yes | Yes |
| Qwen2VLForConditionalGeneration | Qwen2-VL | Yes | Yes |
| MiniCPMV | MiniCPM-V | Yes | Yes |
......
......@@ -23,7 +23,7 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel']
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
......@@ -55,6 +55,11 @@ from vllm.utils import print_warning_once
from .utils import is_pp_missing_parameter, make_layers
import os
import re
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class Qwen2MoeMLP(nn.Module):
......@@ -390,6 +395,16 @@ class Qwen2MoeForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def forward(
self,
input_ids: torch.Tensor,
......@@ -529,3 +544,41 @@ class Qwen2MoeForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"gate_up_proj.weight",
"down_proj.weight",
"mlp.gate.weight",
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight",
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attn.qkv_proj.weight"]
qkv_words = "|".join(lay_qkv_words)
lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]
qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
# weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
# if self.use_fa_pad and (re.findall(qkv_words, layername)):
# if not gemm_bank_conf(weight.data.shape[0]):
# weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
\ No newline at end of file
......@@ -1158,17 +1158,17 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
weight.data = pad_weight(weight.data, 32)
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
# weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
# if self.use_fa_pad and (re.findall(qkv_words, layername)):
# if not gemm_bank_conf(weight.data.shape[0]):
# weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment