Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96dcaff9
Commit
96dcaff9
authored
Nov 20, 2024
by
王敏
Browse files
[fix]修复单测test_mlp_correctness.py运行时的崩溃问题
parent
215f33b0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
2 deletions
+16
-2
tests/spec_decode/e2e/test_mlp_correctness.py
tests/spec_decode/e2e/test_mlp_correctness.py
+1
-1
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
vllm/model_executor/models/mlp_speculator.py
vllm/model_executor/models/mlp_speculator.py
+14
-0
No files found.
tests/spec_decode/e2e/test_mlp_correctness.py
View file @
96dcaff9
...
...
@@ -38,7 +38,7 @@ SPEC_MODEL = "ibm-fms/llama-160m-accelerator"
MAX_SPEC_TOKENS
=
3
# precision
PRECISION
=
"float
32
"
PRECISION
=
"float
16
"
@
pytest
.
mark
.
parametrize
(
...
...
vllm/model_executor/model_loader/utils.py
View file @
96dcaff9
...
...
@@ -23,7 +23,7 @@ def get_model_architecture(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'Qwen2MoeForCausalLM'
,
'Qwen2VLForConditionalGeneration'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
]
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'Qwen2MoeForCausalLM'
,
'Qwen2VLForConditionalGeneration'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
...
...
vllm/model_executor/models/mlp_speculator.py
View file @
96dcaff9
import
os
import
math
from
typing
import
Iterable
,
List
,
Tuple
...
...
@@ -11,6 +12,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.transformers_utils.configs
import
MLPSpeculatorConfig
from
vllm
import
_custom_ops
as
ops
SQRT2
=
2
**
0.5
...
...
@@ -67,6 +69,9 @@ class MLPSpeculator(nn.Module):
def
__init__
(
self
,
config
:
MLPSpeculatorConfig
,
**
kwargs
)
->
None
:
super
().
__init__
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
n_predict
=
config
.
n_predict
self
.
vocab_size
=
config
.
vocab_size
self
.
emb_dim
=
config
.
emb_dim
...
...
@@ -195,3 +200,12 @@ class MLPSpeculator(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
"head"
in
name
:
_weight
=
torch
.
zeros_like
(
param
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
param
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
param
.
data
.
copy_
(
_weight
)
param
.
data
=
param
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment