Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
691e29ec
Unverified
Commit
691e29ec
authored
Jun 27, 2024
by
Nick Hill
Committed by
GitHub
Jun 27, 2024
Browse files
[BugFix] Fix `MLPSpeculator` handling of `num_speculative_tokens` (#5876)
parent
3fd02bda
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
10 deletions
+18
-10
vllm/config.py
vllm/config.py
+7
-3
vllm/model_executor/models/mlp_speculator.py
vllm/model_executor/models/mlp_speculator.py
+8
-7
vllm/transformers_utils/configs/mlp_speculator.py
vllm/transformers_utils/configs/mlp_speculator.py
+3
-0
No files found.
vllm/config.py
View file @
691e29ec
...
@@ -920,15 +920,19 @@ class SpeculativeConfig:
...
@@ -920,15 +920,19 @@ class SpeculativeConfig:
max_logprobs
=
target_model_config
.
max_logprobs
,
max_logprobs
=
target_model_config
.
max_logprobs
,
)
)
if
(
draft_model_config
.
hf_config
.
model_type
==
"mlp_speculator"
draft_hf_config
=
draft_model_config
.
hf_config
if
(
draft_hf_config
.
model_type
==
"mlp_speculator"
and
target_parallel_config
.
world_size
!=
1
):
and
target_parallel_config
.
world_size
!=
1
):
# MLPSpeculator TP support will be added very soon
# MLPSpeculator TP support will be added very soon
raise
ValueError
(
raise
ValueError
(
"Speculative decoding with mlp_speculator models does not "
"Speculative decoding with mlp_speculator models does not "
"yet support distributed inferencing (TP > 1)."
)
"yet support distributed inferencing (TP > 1)."
)
n_predict
=
getattr
(
draft_model_config
.
hf_config
,
"n_predict"
,
if
(
num_speculative_tokens
is
not
None
None
)
and
hasattr
(
draft_hf_config
,
"num_lookahead_tokens"
)):
draft_hf_config
.
num_lookahead_tokens
=
num_speculative_tokens
n_predict
=
getattr
(
draft_hf_config
,
"n_predict"
,
None
)
if
n_predict
is
not
None
:
if
n_predict
is
not
None
:
if
num_speculative_tokens
is
None
:
if
num_speculative_tokens
is
None
:
# Default to max value defined in draft model config.
# Default to max value defined in draft model config.
...
...
vllm/model_executor/models/mlp_speculator.py
View file @
691e29ec
...
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
MLPSpeculatorConfig
class
MLPSpeculatorLayerNorm
(
nn
.
Module
):
class
MLPSpeculatorLayerNorm
(
nn
.
Module
):
...
@@ -48,7 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
...
@@ -48,7 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
class
MLPSpeculator
(
nn
.
Module
):
class
MLPSpeculator
(
nn
.
Module
):
def
__init__
(
self
,
config
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
config
:
MLPSpeculatorConfig
,
**
kwargs
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
n_predict
=
config
.
n_predict
self
.
n_predict
=
config
.
n_predict
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
@@ -56,8 +57,7 @@ class MLPSpeculator(nn.Module):
...
@@ -56,8 +57,7 @@ class MLPSpeculator(nn.Module):
self
.
inner_dim
=
config
.
inner_dim
if
config
.
inner_dim
!=
0
\
self
.
inner_dim
=
config
.
inner_dim
if
config
.
inner_dim
!=
0
\
else
config
.
emb_dim
else
config
.
emb_dim
self
.
max_speculative_tokens
=
getattr
(
config
,
"max_speculative_tokens"
,
self
.
max_speculative_tokens
=
config
.
num_lookahead_tokens
self
.
n_predict
)
self
.
emb
=
nn
.
ModuleList
([
self
.
emb
=
nn
.
ModuleList
([
VocabParallelEmbedding
(
config
.
vocab_size
,
VocabParallelEmbedding
(
config
.
vocab_size
,
...
@@ -137,7 +137,8 @@ class MLPSpeculator(nn.Module):
...
@@ -137,7 +137,8 @@ class MLPSpeculator(nn.Module):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
param
=
params_dict
[
name
.
replace
(
"speculator."
,
""
)]
param
=
params_dict
.
get
(
name
.
replace
(
"speculator."
,
""
))
weight_loader
=
getattr
(
param
,
"weight_loader"
,
if
param
is
not
None
:
default_weight_loader
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
(
param
,
loaded_weight
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/transformers_utils/configs/mlp_speculator.py
View file @
691e29ec
...
@@ -35,6 +35,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
...
@@ -35,6 +35,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
candidate tree.
candidate tree.
For each candidate branch in the tree, head n produces topk[n]
For each candidate branch in the tree, head n produces topk[n]
additional sub-branches.
additional sub-branches.
NOTE: This parameter is currently unused.
n_candidates: int
n_candidates: int
number of child candidates to create per sequence
number of child candidates to create per sequence
"""
"""
...
@@ -47,4 +48,6 @@ class MLPSpeculatorConfig(PretrainedConfig):
...
@@ -47,4 +48,6 @@ class MLPSpeculatorConfig(PretrainedConfig):
self
.
n_predict
=
n_predict
self
.
n_predict
=
n_predict
self
.
top_k_tokens_per_head
=
top_k_tokens_per_head
self
.
top_k_tokens_per_head
=
top_k_tokens_per_head
self
.
n_candidates
=
n_candidates
self
.
n_candidates
=
n_candidates
self
.
num_lookahead_tokens
=
n_predict
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment