"vscode:/vscode.git/clone" did not exist on "59a50afa084dbd26b8a4f58b960ce337af6a4667"
Commit 520d727f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

[fix]修复单测test_mlp_correctness.py运行时的崩溃问题

See merge request dcutoolkit/deeplearing/vllm!40
parents 2f7d31f1 96dcaff9
......@@ -38,7 +38,7 @@ SPEC_MODEL = "ibm-fms/llama-160m-accelerator"
MAX_SPEC_TOKENS = 3
# precision
PRECISION = "float32"
PRECISION = "float16"
@pytest.mark.parametrize(
......
......@@ -23,7 +23,7 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM']
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
import os
import math
from typing import Iterable, List, Tuple
......@@ -11,6 +12,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs import MLPSpeculatorConfig
from vllm import _custom_ops as ops
SQRT2 = 2**0.5
......@@ -67,6 +69,9 @@ class MLPSpeculator(nn.Module):
def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
super().__init__()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
......@@ -195,3 +200,12 @@ class MLPSpeculator(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn and "head" in name:
_weight = torch.zeros_like(param.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, param.data, _weight.shape[0], _weight.shape[1])
param.data.copy_(_weight)
param.data=param.data.reshape(ori_shape[1],-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment