Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
57b7be0e
Unverified
Commit
57b7be0e
authored
Aug 08, 2024
by
William Lin
Committed by
GitHub
Aug 09, 2024
Browse files
[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971)
parent
99b4cf5f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
52 additions
and
3 deletions
+52
-3
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+26
-1
vllm/lora/layers.py
vllm/lora/layers.py
+4
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+2
-2
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+3
-0
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+4
-0
vllm/spec_decode/proposer_worker_base.py
vllm/spec_decode/proposer_worker_base.py
+4
-0
vllm/spec_decode/smaller_tp_proposer_worker.py
vllm/spec_decode/smaller_tp_proposer_worker.py
+6
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+3
-0
No files found.
tests/samplers/test_sampler.py
View file @
57b7be0e
import
itertools
import
itertools
import
random
import
random
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
Mock
,
patch
import
pytest
import
pytest
import
torch
import
torch
...
@@ -703,3 +703,28 @@ def test_sampler_repetition_penalty_mixed(device: str):
...
@@ -703,3 +703,28 @@ def test_sampler_repetition_penalty_mixed(device: str):
assert
tokens1
[
0
]
==
tokens2
[
1
]
assert
tokens1
[
0
]
==
tokens2
[
1
]
assert
tokens1
[
1
]
==
tokens2
[
0
]
assert
tokens1
[
1
]
==
tokens2
[
0
]
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_include_gpu_probs_tensor
(
device
:
str
):
set_random_seed
(
42
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
sampler
.
include_gpu_probs_tensor
=
True
sampler
.
should_modify_greedy_probs_inplace
=
False
sampling_params
=
SamplingParams
(
temperature
=
0
)
mock_inplace
=
Mock
()
with
patch
(
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace"
,
mock_inplace
):
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
mock_inplace
.
assert_not_called
()
assert
sampler_output
.
sampled_token_probs
is
not
None
assert
sampler_output
.
logprobs
is
not
None
assert
sampler_output
.
sampled_token_ids
is
not
None
vllm/lora/layers.py
View file @
57b7be0e
...
@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def
include_gpu_probs_tensor
(
self
):
def
include_gpu_probs_tensor
(
self
):
return
self
.
base_layer
.
include_gpu_probs_tensor
return
self
.
base_layer
.
include_gpu_probs_tensor
@
property
def
should_modify_greedy_probs_inplace
(
self
):
return
self
.
base_layer
.
should_modify_greedy_probs_inplace
def
create_lora_weights
(
def
create_lora_weights
(
self
,
self
,
max_loras
:
int
,
max_loras
:
int
,
...
...
vllm/model_executor/layers/sampler.py
View file @
57b7be0e
...
@@ -51,6 +51,7 @@ class Sampler(nn.Module):
...
@@ -51,6 +51,7 @@ class Sampler(nn.Module):
# containing the sampled token ids and probabilities. This is used by
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
# speculative decoding.
self
.
include_gpu_probs_tensor
=
False
self
.
include_gpu_probs_tensor
=
False
self
.
should_modify_greedy_probs_inplace
=
False
def
_init_sampling_tensors
(
def
_init_sampling_tensors
(
self
,
self
,
...
@@ -177,8 +178,7 @@ class Sampler(nn.Module):
...
@@ -177,8 +178,7 @@ class Sampler(nn.Module):
This is used by speculative decoding, which requires that the sampling
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
method be encoded into the probability distribution.
"""
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return
self
.
should_modify_greedy_probs_inplace
return
self
.
include_gpu_probs_tensor
def
_get_bin_counts_and_mask
(
def
_get_bin_counts_and_mask
(
...
...
vllm/spec_decode/medusa_worker.py
View file @
57b7be0e
...
@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
def
set_include_gpu_probs_tensor
(
self
):
def
set_include_gpu_probs_tensor
(
self
):
pass
pass
def
set_should_modify_greedy_probs_inplace
(
self
):
pass
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
sampler_output
(
def
sampler_output
(
self
,
self
,
...
...
vllm/spec_decode/multi_step_worker.py
View file @
57b7be0e
...
@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Need include_gpu_probs_tensor for MultiStepWorker
# Need include_gpu_probs_tensor for MultiStepWorker
self
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
self
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
self
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
=
(
True
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
sampler_output
(
def
sampler_output
(
self
,
self
,
...
...
vllm/spec_decode/proposer_worker_base.py
View file @
57b7be0e
...
@@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
...
@@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
"""Implementation optional"""
"""Implementation optional"""
pass
pass
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
"""Implementation optional"""
pass
class
NonLLMProposerWorkerBase
(
ProposerWorkerBase
,
ABC
):
class
NonLLMProposerWorkerBase
(
ProposerWorkerBase
,
ABC
):
"""Proposer worker which does not use a model with kvcache"""
"""Proposer worker which does not use a model with kvcache"""
...
...
vllm/spec_decode/smaller_tp_proposer_worker.py
View file @
57b7be0e
...
@@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
...
@@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
# Need include_gpu_probs_tensor for multi_step_worker
# Need include_gpu_probs_tensor for multi_step_worker
self
.
_worker
.
set_include_gpu_probs_tensor
()
self
.
_worker
.
set_include_gpu_probs_tensor
()
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
if
self
.
_is_dummy
:
return
self
.
_worker
.
set_should_modify_greedy_probs_inplace
()
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
if
self
.
_is_dummy
:
if
self
.
_is_dummy
:
return
return
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
57b7be0e
...
@@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
"""
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
)
=
True
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
self
.
proposer_worker
.
set_should_modify_greedy_probs_inplace
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of cache blocks to use.
"""Determine the number of cache blocks to use.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment