Commit 82aee745 authored by zhuwenwen's avatar zhuwenwen
Browse files

optimize qwen2-moe layout

parent 01df9361
...@@ -10,20 +10,21 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -10,20 +10,21 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
## 支持模型结构列表 ## 支持模型结构列表
| 结构 | 模型 | 模型并行 | FP16 | | 结构 | 模型 | 模型并行 | FP16 |
| :------: | :------: | :------: | :------: | | :------: | :------: | :------: | :------: |
| LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama、deepseek | Yes | Yes | | LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama、deepseek | Yes | Yes |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | | QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5 | Yes | Yes | | Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5 | Yes | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | Yes | | ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes | | DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes |
| BaiChuanForCausalLM | Baichuan2,Baichuan | Yes | Yes | | BaiChuanForCausalLM | Baichuan2,Baichuan | Yes | Yes |
| BloomForCausalLM | BLOOM | Yes | Yes | | BloomForCausalLM | BLOOM | Yes | Yes |
| InternLMForCausalLM | InternLM | Yes | Yes | | InternLMForCausalLM | InternLM | Yes | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | Yes | | InternLM2ForCausalLM | InternLM2 | Yes | Yes |
| MiniCPMForCausalLM | MiniCPM | Yes | Yes | | MiniCPMForCausalLM | MiniCPM | Yes | Yes |
| MiniCPM3ForCausalLM | MiniCPM3 | Yes | Yes | | MiniCPM3ForCausalLM | MiniCPM3 | Yes | Yes |
| MixtralForCausalLM | Mixtral-8x7B,Mixtral-8x7B-Instruct | Yes | Yes | | MixtralForCausalLM | Mixtral-8x7B,Mixtral-8x7B-Instruct | Yes | Yes |
| TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | Yes | | TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | Yes |
| LlavaForConditionalGeneration | LLaMA,LLaMA-2,LLaMA-3 | Yes | Yes | | Qwen2MoeForCausalLM | Qwen2-57B-A14B,Qwen2-57B-A14B-Instruct | Yes | Yes |
| LlavaForConditionalGeneration | LLaMA,LLaMA-2,LLaMA-3 | Yes | Yes |
| Qwen2VLForConditionalGeneration | Qwen2-VL | Yes | Yes | | Qwen2VLForConditionalGeneration | Qwen2-VL | Yes | Yes |
| MiniCPMV | MiniCPM-V | Yes | Yes | | MiniCPMV | MiniCPM-V | Yes | Yes |
| Phi3VForCausalLM | Phi-3.5-vision | Yes | Yes | | Phi3VForCausalLM | Phi-3.5-vision | Yes | Yes |
......
...@@ -23,7 +23,7 @@ def get_model_architecture( ...@@ -23,7 +23,7 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
...@@ -55,6 +55,11 @@ from vllm.utils import print_warning_once ...@@ -55,6 +55,11 @@ from vllm.utils import print_warning_once
from .utils import is_pp_missing_parameter, make_layers from .utils import is_pp_missing_parameter, make_layers
import os
import re
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):
...@@ -389,6 +394,16 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -389,6 +394,16 @@ class Qwen2MoeForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def forward( def forward(
self, self,
...@@ -529,3 +544,41 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -529,3 +544,41 @@ class Qwen2MoeForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"gate_up_proj.weight",
"down_proj.weight",
"mlp.gate.weight",
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight",
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attn.qkv_proj.weight"]
qkv_words = "|".join(lay_qkv_words)
lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]
qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
# weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
# if self.use_fa_pad and (re.findall(qkv_words, layername)):
# if not gemm_bank_conf(weight.data.shape[0]):
# weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
\ No newline at end of file
...@@ -1158,17 +1158,17 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1158,17 +1158,17 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
qkv_bias_words = "|".join(lay_qkv_bias_words) qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
if self.use_fa_pad and (re.findall(qkv_bias_words, layername)): # if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
weight.data = pad_weight(weight.data, 32) # weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]): if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32) weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)): # if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]): # if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32) # weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment