Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2416b26e
Unverified
Commit
2416b26e
authored
Jul 10, 2024
by
Abhinav Goyal
Committed by
GitHub
Jul 09, 2024
Browse files
[Speculative Decoding] Medusa Implementation with Top-1 proposer (#4978)
parent
d3a24513
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
587 additions
and
4 deletions
+587
-4
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+226
-0
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/medusa.py
vllm/model_executor/models/medusa.py
+159
-0
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+127
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+5
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+4
-2
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/medusa.py
vllm/transformers_utils/configs/medusa.py
+60
-0
vllm/worker/worker.py
vllm/worker/worker.py
+3
-2
No files found.
tests/spec_decode/e2e/test_medusa_correctness.py
0 → 100644
View file @
2416b26e
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, Medusa would not break the
correctess for the target model outputs.
"""
import
pytest
from
.conftest
import
run_greedy_equality_correctness_test
# main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
# OOM in CI pipeline, so using a smaller model.
MAIN_MODEL
=
"JackFram/llama-68m"
# speculative model
SPEC_MODEL
=
"abhigoyal/vllm-medusa-llama-68m-random"
# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS
=
5
# precision
PRECISION
=
"float32"
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
8
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
k
,
}
# Try a range of num. speculative tokens
for
k
in
range
(
1
,
1
+
MAX_SPEC_TOKENS
)
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_disable_by_batch_size"
:
4
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
if
__name__
==
"__main__"
:
import
pytest
pytest
.
main
([
__file__
])
vllm/model_executor/models/__init__.py
View file @
2416b26e
...
@@ -64,6 +64,7 @@ _GENERATION_MODELS = {
...
@@ -64,6 +64,7 @@ _GENERATION_MODELS = {
"ArcticForCausalLM"
:
(
"arctic"
,
"ArcticForCausalLM"
),
"ArcticForCausalLM"
:
(
"arctic"
,
"ArcticForCausalLM"
),
"XverseForCausalLM"
:
(
"xverse"
,
"XverseForCausalLM"
),
"XverseForCausalLM"
:
(
"xverse"
,
"XverseForCausalLM"
),
"Phi3SmallForCausalLM"
:
(
"phi3_small"
,
"Phi3SmallForCausalLM"
),
"Phi3SmallForCausalLM"
:
(
"phi3_small"
,
"Phi3SmallForCausalLM"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
)
"JambaForCausalLM"
:
(
"jamba"
,
"JambaForCausalLM"
)
}
}
...
...
vllm/model_executor/models/medusa.py
0 → 100644
View file @
2416b26e
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs.medusa
import
MedusaConfig
class
ResidualBlock
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_layers
:
int
)
->
None
:
super
().
__init__
()
self
.
layers
=
nn
.
ModuleList
([
nn
.
Linear
(
hidden_size
,
hidden_size
,
bias
=
False
)
for
_
in
range
(
num_layers
)
])
self
.
act
=
nn
.
SiLU
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
for
layer
in
self
.
layers
:
x
=
x
+
self
.
act
(
layer
(
x
))
return
x
class
Medusa
(
nn
.
Module
):
def
__init__
(
self
,
config
:
MedusaConfig
,
**
_
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
blocks
=
nn
.
ModuleList
([
ResidualBlock
(
hidden_size
=
self
.
config
.
hidden_size
,
num_layers
=
self
.
config
.
num_hidden_layers
)
for
_
in
range
(
self
.
config
.
num_heads
)
])
self
.
orig_vocab_size
=
config
.
vocab_size
self
.
truncated_vocab_size
=
config
.
truncated_vocab_size
self
.
unpadded_vocab_size
=
self
.
truncated_vocab_size
self
.
lm_heads
=
nn
.
ModuleList
([
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
self
.
truncated_vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
)
for
_
in
range
(
self
.
config
.
num_heads
)
])
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
truncated_vocab_size
,
logit_scale
)
self
.
token_map
=
None
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
List
[
torch
.
Tensor
]:
return
[
block
(
hidden_states
)
for
block
in
self
.
blocks
]
def
compute_logits
(
self
,
hidden_states
:
List
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
)
->
List
[
torch
.
Tensor
]:
logits
=
[]
for
hs
,
lm_head
in
zip
(
hidden_states
,
self
.
lm_heads
):
_logits
=
self
.
logits_processor
(
lm_head
,
hs
,
sampling_metadata
)
if
self
.
token_map
is
None
:
logits
.
append
(
_logits
)
else
:
logits
.
append
(
-
torch
.
inf
*
torch
.
ones
(
size
=
(
*
_logits
.
shape
[:
-
1
],
self
.
orig_vocab_size
),
device
=
_logits
.
device
,
dtype
=
_logits
.
dtype
))
logits
[
-
1
][...,
self
.
token_map
]
=
_logits
return
logits
def
sample
(
self
,
logits
:
List
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
List
[
SamplerOutput
]:
logits
=
torch
.
stack
(
logits
,
dim
=
0
).
float
()
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
)
token_ids
=
logits
.
argmax
(
-
1
)
# support only top-1 for now
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
token_id_list
=
[]
token_prob_list
=
[]
token_logprob_list
=
[]
for
idx
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
token_id_list
.
append
(
token_ids
[:,
seq_group
.
sample_indices
])
token_prob_list
.
append
(
probs
[:,
seq_group
.
sample_indices
])
token_logprob_list
.
append
(
logprobs
[:,
seq_group
.
sample_indices
])
outputs
:
List
[
Optional
[
SamplerOutput
]]
=
[]
for
idx
in
range
(
len
(
sampling_metadata
.
seq_groups
)):
outputs
.
append
(
SamplerOutput
(
outputs
=
None
,
sampled_token_probs
=
token_prob_list
[
idx
].
squeeze
(
1
),
logprobs
=
token_logprob_list
[
idx
].
squeeze
(
1
),
sampled_token_ids
=
token_id_list
[
idx
].
squeeze
(
1
),
))
return
outputs
def
generate_proposals
(
self
,
previous_hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
List
[
SamplerOutput
]:
return
self
.
sample
(
logits
=
self
.
compute_logits
(
hidden_states
=
self
.
forward
(
previous_hidden_states
),
sampling_metadata
=
sampling_metadata
,
),
sampling_metadata
=
sampling_metadata
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
dict
(
self
.
named_parameters
())
weights_map
=
{}
for
name
,
loaded_weight
in
weights
:
name
=
name
.
replace
(
"medusa_heads."
,
""
)
if
name
==
"token_map"
:
if
self
.
truncated_vocab_size
<
self
.
orig_vocab_size
:
self
.
token_map
=
nn
.
Parameter
(
loaded_weight
,
requires_grad
=
False
)
elif
name
in
params_dict
:
weights_map
[
name
]
=
loaded_weight
for
name
,
loaded_weight
in
weights_map
.
items
():
if
"lm_head"
in
name
and
self
.
token_map
is
not
None
and
\
loaded_weight
.
shape
[
0
]
>
self
.
token_map
.
shape
[
0
]:
loaded_weight
=
loaded_weight
[
self
.
token_map
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
token_map
is
not
None
:
self
.
token_map
.
to
(
device
=
self
.
lm_heads
[
0
].
weight
.
device
)
assert
(
self
.
truncated_vocab_size
==
self
.
orig_vocab_size
)
or
(
self
.
token_map
is
not
None
)
vllm/spec_decode/medusa_worker.py
0 → 100644
View file @
2416b26e
import
weakref
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.model_executor
import
SamplingMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
class
MedusaWorker
(
NonLLMProposerWorkerBase
,
Worker
):
"""Worker for Medusa.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
# Lazy initialization list.
self
.
_proposer
:
Top1Proposer
def
init_device
(
self
):
super
().
init_device
()
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
)
def
set_include_gpu_probs_tensor
(
self
):
pass
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For medusa worker, this indicator shall be False.
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
seq_lens
,
query_lens
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
hidden_states
,
sampling_metadata
=
sampling_metadata
)
return
model_outputs
,
False
def
_prepare_input_tensors
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
List
[
int
],
List
[
int
]]:
if
not
seq_group_metadata_list
:
return
[],
[]
seq_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
is_prompt
=
seq_group_metadata
.
is_prompt
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seq_data_len
=
seq_data
.
get_len
()
if
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
seq_len
=
min
(
seq_data_len
,
context_len
+
seq_group_metadata
.
token_chunk_size
)
seq_lens
.
append
(
seq_len
)
query_lens
.
append
(
seq_len
-
context_len
)
else
:
seq_lens
.
append
(
seq_data_len
)
query_lens
.
append
(
1
)
return
seq_lens
,
query_lens
def
get_spec_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_spec_proposals
(
execute_model_req
)
def
_raise_if_unsupported
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
None
:
"""MedusaWorker does not yet implement support for cache swap
operations or beam search.
"""
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
"MedusaWorker does not support cache operations"
)
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
"MedusaWorker does not support beam search."
)
vllm/spec_decode/spec_decode_worker.py
View file @
2416b26e
...
@@ -18,6 +18,7 @@ from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
...
@@ -18,6 +18,7 @@ from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.medusa_worker
import
MedusaWorker
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.mlp_speculator_worker
import
MLPSpeculatorWorker
from
vllm.spec_decode.mlp_speculator_worker
import
MLPSpeculatorWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
...
@@ -129,6 +130,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -129,6 +130,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"model_config"
].
hf_config
.
model_type
==
"mlp_speculator"
:
"model_config"
].
hf_config
.
model_type
==
"mlp_speculator"
:
disable_bonus_tokens
=
False
disable_bonus_tokens
=
False
proposer_worker
=
MLPSpeculatorWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
MLPSpeculatorWorker
(
**
draft_worker_kwargs
)
elif
draft_worker_kwargs
[
"model_config"
].
hf_config
.
model_type
==
"medusa"
:
disable_bonus_tokens
=
False
proposer_worker
=
MedusaWorker
(
**
draft_worker_kwargs
)
else
:
else
:
if
draft_tp
==
1
:
if
draft_tp
==
1
:
draft_worker_kwargs
[
draft_worker_kwargs
[
...
...
vllm/transformers_utils/config.py
View file @
2416b26e
...
@@ -6,8 +6,9 @@ from transformers import GenerationConfig, PretrainedConfig
...
@@ -6,8 +6,9 @@ from transformers import GenerationConfig, PretrainedConfig
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
JAISConfig
,
MLPSpeculatorConfig
,
JAISConfig
,
MedusaConfig
,
MPTConfig
,
RWConfig
)
MLPSpeculatorConfig
,
MPTConfig
,
RWConfig
)
if
VLLM_USE_MODELSCOPE
:
if
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
from
modelscope
import
AutoConfig
...
@@ -24,6 +25,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
...
@@ -24,6 +25,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"RefinedWebModel"
:
RWConfig
,
# For tiiuae/falcon-7b(-instruct)
"RefinedWebModel"
:
RWConfig
,
# For tiiuae/falcon-7b(-instruct)
"jais"
:
JAISConfig
,
"jais"
:
JAISConfig
,
"mlp_speculator"
:
MLPSpeculatorConfig
,
"mlp_speculator"
:
MLPSpeculatorConfig
,
"medusa"
:
MedusaConfig
,
}
}
for
name
,
cls
in
_CONFIG_REGISTRY
.
items
():
for
name
,
cls
in
_CONFIG_REGISTRY
.
items
():
...
...
vllm/transformers_utils/configs/__init__.py
View file @
2416b26e
...
@@ -5,6 +5,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
...
@@ -5,6 +5,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
# `FalconConfig` class from the official HuggingFace transformers library.
# `FalconConfig` class from the official HuggingFace transformers library.
from
vllm.transformers_utils.configs.falcon
import
RWConfig
from
vllm.transformers_utils.configs.falcon
import
RWConfig
from
vllm.transformers_utils.configs.jais
import
JAISConfig
from
vllm.transformers_utils.configs.jais
import
JAISConfig
from
vllm.transformers_utils.configs.medusa
import
MedusaConfig
from
vllm.transformers_utils.configs.mlp_speculator
import
MLPSpeculatorConfig
from
vllm.transformers_utils.configs.mlp_speculator
import
MLPSpeculatorConfig
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
...
@@ -14,5 +15,6 @@ __all__ = [
...
@@ -14,5 +15,6 @@ __all__ = [
"MPTConfig"
,
"MPTConfig"
,
"RWConfig"
,
"RWConfig"
,
"JAISConfig"
,
"JAISConfig"
,
"MedusaConfig"
,
"MLPSpeculatorConfig"
,
"MLPSpeculatorConfig"
,
]
]
vllm/transformers_utils/configs/medusa.py
0 → 100644
View file @
2416b26e
import
os
from
typing
import
Optional
,
Union
from
transformers
import
PretrainedConfig
class
MedusaConfig
(
PretrainedConfig
):
model_type
=
"medusa"
def
__init__
(
self
,
hidden_size
:
int
=
4096
,
vocab_size
:
int
=
32001
,
num_heads
:
int
=
5
,
num_hidden_layers
:
int
=
1
,
max_paths
:
int
=
64
,
topk
:
int
=
10
,
truncated_vocab_size
:
Optional
[
int
]
=
None
,
**
kwargs
):
self
.
hidden_size
=
hidden_size
self
.
vocab_size
=
vocab_size
self
.
num_heads
=
num_heads
self
.
num_hidden_layers
=
num_hidden_layers
self
.
max_paths
=
max_paths
self
.
topk
=
topk
self
.
max_seq_len
=
int
(
2
**
20
)
self
.
truncated_vocab_size
=
vocab_size
if
truncated_vocab_size
is
None
\
else
truncated_vocab_size
if
"architectures"
not
in
kwargs
:
kwargs
[
"architectures"
]
=
[
"MedusaModel"
]
super
().
__init__
(
**
kwargs
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
,
)
->
"MedusaConfig"
:
config_dict
,
kwargs
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
,
**
kwargs
)
for
k
in
list
(
config_dict
.
keys
()):
if
'num'
in
k
:
if
'heads'
in
k
:
config_dict
[
"num_heads"
]
=
config_dict
.
pop
(
k
)
elif
'layers'
in
k
:
config_dict
[
"num_hidden_layers"
]
=
config_dict
.
pop
(
k
)
return
cls
.
from_dict
(
config_dict
,
**
kwargs
)
@
property
def
num_attention_heads
(
self
):
return
0
@
property
def
num_lookahead_tokens
(
self
):
return
self
.
num_heads
@
num_lookahead_tokens
.
setter
def
num_lookahead_tokens
(
self
,
num_lookahead_tokens
:
int
):
self
.
num_heads
=
num_lookahead_tokens
vllm/worker/worker.py
View file @
2416b26e
...
@@ -78,8 +78,9 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -78,8 +78,9 @@ class Worker(LocalOrDistributedWorkerBase):
speculative_args
=
{}
if
speculative_config
is
None
\
speculative_args
=
{}
if
speculative_config
is
None
\
or
(
speculative_config
.
draft_model_config
.
model
==
or
(
speculative_config
.
draft_model_config
.
model
==
model_config
.
model
)
\
model_config
.
model
)
\
or
(
speculative_config
.
draft_model_config
.
hf_config
.
model_type
!=
or
(
speculative_config
.
draft_model_config
.
hf_config
.
model_type
"mlp_speculator"
)
else
{
"return_hidden_states"
:
True
}
not
in
[
"medusa"
,
"mlp_speculator"
])
\
else
{
"return_hidden_states"
:
True
}
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
if
model_runner_cls
is
not
None
:
if
model_runner_cls
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment