Commit 96dcaff9 authored by 王敏's avatar 王敏
Browse files

[fix]修复单测test_mlp_correctness.py运行时的崩溃问题

parent 215f33b0
...@@ -38,7 +38,7 @@ SPEC_MODEL = "ibm-fms/llama-160m-accelerator" ...@@ -38,7 +38,7 @@ SPEC_MODEL = "ibm-fms/llama-160m-accelerator"
MAX_SPEC_TOKENS = 3 MAX_SPEC_TOKENS = 3
# precision # precision
PRECISION = "float32" PRECISION = "float16"
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -23,7 +23,7 @@ def get_model_architecture( ...@@ -23,7 +23,7 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
import os
import math import math
from typing import Iterable, List, Tuple from typing import Iterable, List, Tuple
...@@ -11,6 +12,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -11,6 +12,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs import MLPSpeculatorConfig from vllm.transformers_utils.configs import MLPSpeculatorConfig
from vllm import _custom_ops as ops
SQRT2 = 2**0.5 SQRT2 = 2**0.5
...@@ -67,6 +69,9 @@ class MLPSpeculator(nn.Module): ...@@ -67,6 +69,9 @@ class MLPSpeculator(nn.Module):
def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None: def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
super().__init__() super().__init__()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.n_predict = config.n_predict self.n_predict = config.n_predict
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim self.emb_dim = config.emb_dim
...@@ -195,3 +200,12 @@ class MLPSpeculator(nn.Module): ...@@ -195,3 +200,12 @@ class MLPSpeculator(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and "head" in name:
_weight = torch.zeros_like(param.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, param.data, _weight.shape[0], _weight.shape[1])
param.data.copy_(_weight)
param.data=param.data.reshape(ori_shape[1],-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment