Commit e7d022fd authored by 王敏
Browse files

1.Mixtral模型的gemm由TN改为NN

parent 82aee745
...@@ -23,7 +23,7 @@ def get_model_architecture( ...@@ -23,7 +23,7 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
import os
import re
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
...@@ -46,6 +48,8 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -46,6 +48,8 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers from .utils import is_pp_missing_parameter, make_layers
...@@ -366,6 +370,12 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -366,6 +370,12 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -483,3 +493,24 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -483,3 +493,24 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"block_sparse_moe.gate.weight",
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight",
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment