Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1fe1a4b6
Commit
1fe1a4b6
authored
Dec 06, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'
[feat]并行解码支持多卡推理 See merge request dcutoolkit/deeplearing/vllm!48
parents
a1592b87
4a4e3601
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
165 additions
and
67 deletions
+165
-67
vllm/model_executor/models/mlp_speculator.py
vllm/model_executor/models/mlp_speculator.py
+50
-42
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+1
-1
vllm/spec_decode/mlp_speculator_worker.py
vllm/spec_decode/mlp_speculator_worker.py
+91
-20
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+3
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+20
-4
No files found.
vllm/model_executor/models/mlp_speculator.py
View file @
1fe1a4b6
import
os
import
math
from
typing
import
Iterable
,
List
,
Tuple
from
typing
import
Iterable
,
List
,
Tuple
,
Optional
import
torch
import
torch.nn
as
nn
...
...
@@ -10,9 +10,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.transformers_utils.configs
import
MLPSpeculatorConfig
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
tensor_model_parallel_all_gather
,
tensor_model_parallel_gather
SQRT2
=
2
**
0.5
...
...
@@ -95,8 +97,16 @@ class MLPSpeculator(nn.Module):
# the initial projection from the base model may
# have a different size, so that stays separate.
proj_first
=
nn
.
Linear
(
self
.
emb_dim
,
self
.
inner_dim
,
bias
=
False
)
proj_tied
=
nn
.
Linear
(
self
.
inner_dim
,
self
.
inner_dim
,
bias
=
False
)
# proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
# proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
proj_first
=
ColumnParallelLinear
(
input_size
=
self
.
emb_dim
,
output_size
=
self
.
inner_dim
,
bias
=
False
,
gather_output
=
True
)
proj_tied
=
ColumnParallelLinear
(
input_size
=
self
.
inner_dim
,
output_size
=
self
.
inner_dim
,
bias
=
False
,
gather_output
=
True
)
self
.
proj
=
nn
.
ModuleList
([
proj_first
]
+
[
proj_tied
]
*
(
self
.
max_speculative_tokens
-
1
))
...
...
@@ -116,9 +126,10 @@ class MLPSpeculator(nn.Module):
])
self
.
proj
=
nn
.
ModuleList
([
nn
.
Linear
((
self
.
emb_dim
if
i
==
0
else
self
.
inner_dim
),
self
.
inner_dim
,
bias
=
False
)
ColumnParallelLinear
(
input_size
=
(
self
.
emb_dim
if
i
==
0
else
self
.
inner_dim
),
output_size
=
self
.
inner_dim
,
bias
=
False
,
gather_output
=
True
)
for
i
in
range
(
self
.
max_speculative_tokens
)
])
...
...
@@ -150,28 +161,19 @@ class MLPSpeculator(nn.Module):
previous_hidden_states
:
torch
.
Tensor
,
num_predict_tokens
:
int
,
sampling_metadata
:
SamplingMetadata
,
)
->
List
[
SamplerOutput
]:
head_index
:
int
)
->
Tuple
[
Optional
[
SamplerOutput
],
Optional
[
torch
.
Tensor
]]:
if
num_predict_tokens
>
self
.
max_speculative_tokens
:
raise
ValueError
(
f
"Max speculative tokens for model is "
f
"
{
self
.
max_speculative_tokens
}
, but "
f
"
{
num_predict_tokens
}
were requested"
)
# b x 1 x d
previous_hidden_states
=
previous_hidden_states
.
unsqueeze
(
1
)
if
self
.
scale_input
:
previous_hidden_states
=
self
.
ln0
(
previous_hidden_states
)
/
SQRT2
# b x 1
last_tokens
=
input_ids
.
unsqueeze
(
1
)
next_tokens
=
[]
for
head_index
in
range
(
num_predict_tokens
):
# Project and predict
z
=
self
.
emb
[
head_index
](
last_token
s
)
# b k d
states
=
self
.
proj
[
head_index
](
previous_hidden_states
)
z
=
self
.
emb
[
head_index
](
input_id
s
)
# b k d
states
,
_
=
self
.
proj
[
head_index
](
previous_hidden_states
)
# Weighted add of state_weight*state and emb_weight*z
# Let subsequent LN take care of denominator
...
...
@@ -183,14 +185,19 @@ class MLPSpeculator(nn.Module):
# TODO: not yet supporting top_k_tokens_per_head
states
=
states
.
flatten
(
0
,
1
)
# sampling_metadata is not None indicates that driver card is running
if
sampling_metadata
is
not
None
:
logits
=
self
.
logits_processor
(
self
.
head
[
head_index
],
states
,
sampling_metadata
)
output
=
self
.
sampler
(
logits
,
sampling_metadata
)
last_tokens
=
output
.
sampled_token_ids
next_tokens
.
append
(
output
)
return
next_tokens
return
output
,
previous_hidden_states
else
:
logits
=
self
.
head
[
head_index
].
linear_method
.
apply
(
self
.
head
[
head_index
],
states
,
bias
=
None
)
logits
=
tensor_model_parallel_gather
(
logits
)
return
None
,
None
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
dict
(
self
.
named_parameters
())
...
...
@@ -201,7 +208,8 @@ class MLPSpeculator(nn.Module):
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
os
.
environ
[
'LM_NN'
]
==
'1'
and
"head"
in
name
:
if
self
.
use_llama_nn
:
if
(
os
.
environ
[
'LM_NN'
]
==
'1'
and
"head"
in
name
)
or
"proj"
in
name
:
_weight
=
torch
.
zeros_like
(
param
.
data
)
ori_shape
=
_weight
.
shape
...
...
vllm/spec_decode/medusa_worker.py
View file @
1fe1a4b6
...
...
@@ -144,7 +144,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
medusa_buffers
=
self
.
medusa_buffers
)
# create tree attn masks
if
self
.
medusa_buffers
is
not
None
:
if
self
.
is_driver_worker
and
self
.
medusa_buffers
is
not
None
:
seq_lens
=
tensor_dict
[
"seq_lens"
]
max_context_len
=
max
(
seq_lens
)
for
sampler_output
,
seq_len
in
zip
(
model_outputs
,
seq_lens
):
...
...
vllm/spec_decode/mlp_speculator_worker.py
View file @
1fe1a4b6
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Dict
import
torch
...
...
@@ -7,6 +7,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.distributed
import
broadcast_tensor_dict
class
MLPSpeculatorWorker
(
NonLLMProposerWorkerBase
,
MultiStepWorker
):
...
...
@@ -15,6 +16,58 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
Not currently compatible with LoRA or chunked prefill.
"""
def
_get_driver_input_and_broadcast
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
index
:
int
,
last_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
sampling_metadata
:
Optional
[
SamplingMetadata
]
=
None
)
->
Dict
[
str
,
torch
.
Tensor
]:
if
sampling_metadata
is
None
and
execute_model_req
is
not
None
:
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
(
input_tokens
,
seq_lens
,
query_lens
)
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
# b x 1
last_tokens
=
input_tokens
.
unsqueeze
(
1
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
,
generators
)
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
hidden_states
# b x 1 x d
previous_hidden_states
=
previous_hidden_states
.
unsqueeze
(
1
)
tensor_dict
=
{
"input_tokens"
:
last_tokens
,
"previous_hidden_states"
:
previous_hidden_states
,
"sample_len"
:
sample_len
,
"head_index"
:
index
}
if
self
.
do_metadata_broadcast
:
broadcast_tensor_dict
(
tensor_dict
,
src
=
0
)
return
tensor_dict
,
sampling_metadata
def
_get_worker_input_from_broadcast
(
self
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
""" Get the worker input from the broadcasted tensor dict. """
assert
self
.
do_metadata_broadcast
assert
not
self
.
is_driver_worker
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
return
broadcast_data
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
...
...
@@ -33,24 +86,42 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
(
input_tokens
,
seq_lens
,
query_lens
)
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
,
generators
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
input_ids
=
input_tokens
,
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
hidden_states
,
num_predict_tokens
=
sample_len
,
sampling_metadata
=
sampling_metadata
)
model_outputs
=
[]
last_tokens
=
None
previous_hidden_states
=
None
sampling_metadata
=
None
for
index
in
range
(
sample_len
):
if
self
.
is_driver_worker
:
tensor_dict
,
sampling_metadata
=
self
.
_get_driver_input_and_broadcast
(
execute_model_req
,
sample_len
,
index
,
last_tokens
,
previous_hidden_states
,
sampling_metadata
)
assert
sampling_metadata
is
not
None
output
,
previous_hidden_states
=
self
.
model_runner
.
model
.
generate_proposals
(
input_ids
=
tensor_dict
[
"input_tokens"
],
previous_hidden_states
=
tensor_dict
[
"previous_hidden_states"
],
num_predict_tokens
=
tensor_dict
[
"sample_len"
],
sampling_metadata
=
sampling_metadata
,
head_index
=
index
)
last_tokens
=
output
.
sampled_token_ids
model_outputs
.
append
(
output
)
else
:
tensor_dict
=
self
.
_get_worker_input_from_broadcast
()
if
tensor_dict
is
None
:
raise
ValueError
(
"Can not get inputs of mlp_speculator worker!!!"
)
self
.
model_runner
.
model
.
generate_proposals
(
input_ids
=
tensor_dict
[
"input_tokens"
],
previous_hidden_states
=
tensor_dict
[
"previous_hidden_states"
],
num_predict_tokens
=
tensor_dict
[
"sample_len"
],
sampling_metadata
=
None
,
head_index
=
tensor_dict
[
"head_index"
])
if
self
.
is_driver_worker
:
assert
len
(
model_outputs
)
==
sample_len
return
model_outputs
,
True
...
...
vllm/spec_decode/multi_step_worker.py
View file @
1fe1a4b6
...
...
@@ -350,6 +350,9 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if
execute_model_req
is
None
:
return
None
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
1fe1a4b6
...
...
@@ -38,6 +38,7 @@ from vllm.worker.worker import Worker
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
logger
=
init_logger
(
__name__
)
...
...
@@ -134,6 +135,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
ngram_prompt_lookup_min
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_min"
))
if
ngram_prompt_lookup_max
>
0
:
draft_parallel_config
:
ParallelConfig
=
draft_worker_kwargs
[
'parallel_config'
]
assert
draft_parallel_config
.
tensor_parallel_size
==
1
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
ngram_prompt_lookup_max
)
...
...
@@ -608,7 +612,22 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
scorer_worker
.
execute_model
()
if
not
data
[
"disable_all_speculation"
]:
if
not
self
.
tree_style_spec_decoding
:
# if not self.tree_style_spec_decoding:
# # Even if num_lookahead_slots is zero, we want to run the
# # proposer model as it may have KV.
# #
# # We run the proposer once per lookahead slot. In the future we
# # should delegate how many times it runs to the proposer.
# for _ in range(max(num_lookahead_slots, 1)):
# self.proposer_worker.execute_model()
# else:
# if not data["no_spec"]:
# self.proposer_worker.sampler_output(None, None, None)
if
issubclass
(
type
(
self
.
proposer_worker
),
NonLLMProposerWorkerBase
):
if
not
data
[
"no_spec"
]:
self
.
proposer_worker
.
sampler_output
(
None
,
num_lookahead_slots
,
None
)
else
:
# Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV.
#
...
...
@@ -616,9 +635,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# should delegate how many times it runs to the proposer.
for
_
in
range
(
max
(
num_lookahead_slots
,
1
)):
self
.
proposer_worker
.
execute_model
()
else
:
if
not
data
[
"no_spec"
]:
self
.
proposer_worker
.
sampler_output
(
None
,
None
,
None
)
if
not
data
[
"no_spec"
]:
self
.
scorer_worker
.
execute_model
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment