Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
57b7be0e
"vscode:/vscode.git/clone" did not exist on "73df49ef3a220c79abfffc36bdfb4e8dee61226b"
Unverified
Commit
57b7be0e
authored
Aug 08, 2024
by
William Lin
Committed by
GitHub
Aug 09, 2024
Browse files
[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971)
parent
99b4cf5f
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
52 additions
and
3 deletions
+52
-3
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+26
-1
vllm/lora/layers.py
vllm/lora/layers.py
+4
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+2
-2
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+3
-0
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+4
-0
vllm/spec_decode/proposer_worker_base.py
vllm/spec_decode/proposer_worker_base.py
+4
-0
vllm/spec_decode/smaller_tp_proposer_worker.py
vllm/spec_decode/smaller_tp_proposer_worker.py
+6
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+3
-0
No files found.
tests/samplers/test_sampler.py
View file @
57b7be0e
import
itertools
import
random
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
Mock
,
patch
import
pytest
import
torch
...
...
@@ -703,3 +703,28 @@ def test_sampler_repetition_penalty_mixed(device: str):
assert
tokens1
[
0
]
==
tokens2
[
1
]
assert
tokens1
[
1
]
==
tokens2
[
0
]
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_include_gpu_probs_tensor
(
device
:
str
):
set_random_seed
(
42
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampler
.
include_gpu_probs_tensor
=
True
sampler
.
should_modify_greedy_probs_inplace
=
False
sampling_params
=
SamplingParams
(
temperature
=
0
)
mock_inplace
=
Mock
()
with
patch
(
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace"
,
mock_inplace
):
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
mock_inplace
.
assert_not_called
()
assert
sampler_output
.
sampled_token_probs
is
not
None
assert
sampler_output
.
logprobs
is
not
None
assert
sampler_output
.
sampled_token_ids
is
not
None
vllm/lora/layers.py
View file @
57b7be0e
...
...
@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def
include_gpu_probs_tensor
(
self
):
return
self
.
base_layer
.
include_gpu_probs_tensor
@
property
def
should_modify_greedy_probs_inplace
(
self
):
return
self
.
base_layer
.
should_modify_greedy_probs_inplace
def
create_lora_weights
(
self
,
max_loras
:
int
,
...
...
vllm/model_executor/layers/sampler.py
View file @
57b7be0e
...
...
@@ -51,6 +51,7 @@ class Sampler(nn.Module):
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self
.
include_gpu_probs_tensor
=
False
self
.
should_modify_greedy_probs_inplace
=
False
def
_init_sampling_tensors
(
self
,
...
...
@@ -177,8 +178,7 @@ class Sampler(nn.Module):
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return
self
.
include_gpu_probs_tensor
return
self
.
should_modify_greedy_probs_inplace
def
_get_bin_counts_and_mask
(
...
...
vllm/spec_decode/medusa_worker.py
View file @
57b7be0e
...
...
@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
def
set_include_gpu_probs_tensor
(
self
):
pass
def
set_should_modify_greedy_probs_inplace
(
self
):
pass
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
...
...
vllm/spec_decode/multi_step_worker.py
View file @
57b7be0e
...
...
@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Need include_gpu_probs_tensor for MultiStepWorker
self
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
self
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
=
(
True
)
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
...
...
vllm/spec_decode/proposer_worker_base.py
View file @
57b7be0e
...
...
@@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
"""Implementation optional"""
pass
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
"""Implementation optional"""
pass
class
NonLLMProposerWorkerBase
(
ProposerWorkerBase
,
ABC
):
"""Proposer worker which does not use a model with kvcache"""
...
...
vllm/spec_decode/smaller_tp_proposer_worker.py
View file @
57b7be0e
...
...
@@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
# Need include_gpu_probs_tensor for multi_step_worker
self
.
_worker
.
set_include_gpu_probs_tensor
()
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
if
self
.
_is_dummy
:
return
self
.
_worker
.
set_should_modify_greedy_probs_inplace
()
def
load_model
(
self
)
->
None
:
if
self
.
_is_dummy
:
return
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
57b7be0e
...
...
@@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
self
.
proposer_worker
.
set_should_modify_greedy_probs_inplace
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of cache blocks to use.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment