Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82a1b1a8
Unverified
Commit
82a1b1a8
authored
Aug 05, 2024
by
Cade Daniel
Committed by
GitHub
Aug 05, 2024
Browse files
[Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963)
parent
c0d8f163
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
125 additions
and
35 deletions
+125
-35
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+48
-20
vllm/config.py
vllm/config.py
+7
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+54
-14
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+15
-0
No files found.
tests/spec_decode/test_spec_decode_worker.py
View file @
82a1b1a8
...
@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
...
@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
target_worker
=
mock_worker
()
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
draft_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
exception_secret
=
'artificial stop'
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
...
@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
...
@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
draft_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
vocab_size
=
32_000
vocab_size
=
32_000
...
@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
...
@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
metrics_collector
)
target_worker
,
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
...
@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
...
@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
metrics_collector
)
target_worker
,
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
...
@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
...
@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
proposer_worker
=
draft_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
False
,
scorer_worker
=
target_worker
,
metrics_collector
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
k
,
...
@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
...
@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
proposer_worker
=
draft_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
False
,
scorer_worker
=
target_worker
,
metrics_collector
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
k
,
...
@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
...
@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
worker
=
SpecDecodeWorker
(
False
,
metrics_collector
)
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
,
spec_decode_sampler
=
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
)
worker
.
init_device
()
worker
.
init_device
()
draft_worker
.
init_device
.
assert_called_once
()
draft_worker
.
init_device
.
assert_called_once
()
...
@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
...
@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
target_worker
=
mock_worker
()
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
proposer_worker
=
draft_worker
,
draft_worker
,
target_worker
,
scorer_worker
=
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
=
metrics_collector
)
kwargs
=
{
"num_gpu_blocks"
:
1024
,
"num_cpu_blocks"
:
1023
}
kwargs
=
{
"num_gpu_blocks"
:
1024
,
"num_cpu_blocks"
:
1023
}
worker
.
initialize_cache
(
**
kwargs
)
worker
.
initialize_cache
(
**
kwargs
)
...
@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
...
@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
accepted_token_ids
=
accepted_token_ids
,
accepted_token_ids
=
accepted_token_ids
,
target_logprobs
=
target_token_logprobs
,
target_logprobs
=
target_token_logprobs
,
k
=
k
)
k
=
k
,
stage_times
=
(
0
,
0
,
0
))
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in
# 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current
# _seq_with_bonus_token_in_last_step but were not part of the current
...
...
vllm/config.py
View file @
82a1b1a8
...
@@ -907,6 +907,7 @@ class SpeculativeConfig:
...
@@ -907,6 +907,7 @@ class SpeculativeConfig:
speculative_max_model_len
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
use_v2_block_manager
:
bool
,
disable_log_stats
:
bool
,
speculative_disable_by_batch_size
:
Optional
[
int
],
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
...
@@ -1095,7 +1096,8 @@ class SpeculativeConfig:
...
@@ -1095,7 +1096,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
disable_logprobs
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
)
)
@
staticmethod
@
staticmethod
...
@@ -1189,6 +1191,7 @@ class SpeculativeConfig:
...
@@ -1189,6 +1191,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
):
):
"""Create a SpeculativeConfig object.
"""Create a SpeculativeConfig object.
...
@@ -1221,6 +1224,8 @@ class SpeculativeConfig:
...
@@ -1221,6 +1224,8 @@ class SpeculativeConfig:
sampling, target sampling, and after accepted tokens are
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
determined. If set to False, log probabilities will be
returned.
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
"""
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
draft_parallel_config
=
draft_parallel_config
...
@@ -1235,6 +1240,7 @@ class SpeculativeConfig:
...
@@ -1235,6 +1240,7 @@ class SpeculativeConfig:
self
.
typical_acceptance_sampler_posterior_alpha
=
\
self
.
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
typical_acceptance_sampler_posterior_alpha
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_log_stats
=
disable_log_stats
self
.
_verify_args
()
self
.
_verify_args
()
...
...
vllm/engine/arg_utils.py
View file @
82a1b1a8
...
@@ -792,6 +792,7 @@ class EngineArgs:
...
@@ -792,6 +792,7 @@ class EngineArgs:
speculative_max_model_len
=
self
.
speculative_max_model_len
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
disable_log_stats
=
self
.
disable_log_stats
,
ngram_prompt_lookup_max
=
self
.
ngram_prompt_lookup_max
,
ngram_prompt_lookup_max
=
self
.
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
=
self
.
ngram_prompt_lookup_min
,
ngram_prompt_lookup_min
=
self
.
ngram_prompt_lookup_min
,
draft_token_acceptance_method
=
\
draft_token_acceptance_method
=
\
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
82a1b1a8
...
@@ -27,7 +27,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker
...
@@ -27,7 +27,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.smaller_tp_proposer_worker
import
SmallerTpProposerWorker
from
vllm.spec_decode.smaller_tp_proposer_worker
import
SmallerTpProposerWorker
from
vllm.spec_decode.target_model_runner
import
TargetModelRunner
from
vllm.spec_decode.target_model_runner
import
TargetModelRunner
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
from
vllm.spec_decode.util
import
(
Timer
,
create_sequence_group_output
,
get_all_num_logprobs
,
get_all_num_logprobs
,
get_sampled_token_logprobs
,
nvtx_range
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
split_batch_by_proposal_len
)
...
@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
speculative_config
.
disable_logprobs
)
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_log_stats
=
speculative_config
.
disable_log_stats
,
)
return
spec_decode_worker
return
spec_decode_worker
...
@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
)
->
"SpecDecodeWorker"
:
)
->
"SpecDecodeWorker"
:
allow_zero_draft_token_step
=
True
allow_zero_draft_token_step
=
True
...
@@ -171,6 +174,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -171,6 +174,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker
,
proposer_worker
,
scorer_worker
,
scorer_worker
,
disable_logprobs
=
disable_logprobs
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
)
allow_zero_draft_token_step
=
allow_zero_draft_token_step
)
...
@@ -180,7 +184,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -180,7 +184,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker
:
ProposerWorkerBase
,
proposer_worker
:
ProposerWorkerBase
,
scorer_worker
:
WorkerBase
,
scorer_worker
:
WorkerBase
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
disable_logprobs
:
bool
,
disable_logprobs
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
...
@@ -203,6 +208,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -203,6 +208,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_logprobs: If set to True, token log probabilities will
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
If set to False, log probabilities will be output by both.
disable_log_stats: If set to True, disable periodic printing of
speculative stage times.
disable_by_batch_size: If the batch size is larger than this,
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
metrics_collector: Helper class for collecting metrics; can be set
...
@@ -240,6 +247,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -240,6 +247,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# in the subsequent step.
# in the subsequent step.
self
.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
self
.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_log_stats
=
disable_log_stats
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
"""Initialize both scorer and proposer models.
"""Initialize both scorer and proposer models.
...
@@ -525,28 +533,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -525,28 +533,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req
.
previous_hidden_states
=
self
.
previous_hidden_states
execute_model_req
.
previous_hidden_states
=
self
.
previous_hidden_states
self
.
previous_hidden_states
=
None
self
.
previous_hidden_states
=
None
# Generate proposals using draft worker.
with
Timer
()
as
proposal_timer
:
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
# Generate proposals using draft worker.
execute_model_req
,
self
.
_seq_with_bonus_token_in_last_step
)
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
,
self
.
_seq_with_bonus_token_in_last_step
)
if
not
self
.
_allow_zero_draft_token_step
and
proposals
.
no_proposals
:
if
not
self
.
_allow_zero_draft_token_step
and
proposals
.
no_proposals
:
#TODO: Fix it #5814
#TODO: Fix it #5814
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
"workers generate no tokens"
)
"workers generate no tokens"
)
proposal_scores
=
self
.
scorer
.
score_proposals
(
with
Timer
()
as
scoring_timer
:
execute_model_req
,
proposal_scores
=
self
.
scorer
.
score_proposals
(
proposals
,
execute_model_req
,
)
proposals
,
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
)
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
with
Timer
()
as
verification_timer
:
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
stage_times
=
(
proposal_timer
.
elapsed_time_ms
/
num_lookahead_slots
,
scoring_timer
.
elapsed_time_ms
,
verification_timer
.
elapsed_time_ms
)
return
self
.
_create_output_sampler_list
(
return
self
.
_create_output_sampler_list
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
seq_group_metadata_list
,
accepted_token_ids
,
accepted_token_ids
,
target_logprobs
=
target_logprobs
,
target_logprobs
=
target_logprobs
,
k
=
execute_model_req
.
num_lookahead_slots
)
k
=
execute_model_req
.
num_lookahead_slots
,
stage_times
=
stage_times
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
def
_verify_tokens
(
...
@@ -645,6 +662,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -645,6 +662,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids
:
torch
.
Tensor
,
# shape: [batch_size, k+1]
accepted_token_ids
:
torch
.
Tensor
,
# shape: [batch_size, k+1]
target_logprobs
:
torch
.
Tensor
,
# shape: [batch_size, k+1, vocab_size]
target_logprobs
:
torch
.
Tensor
,
# shape: [batch_size, k+1, vocab_size]
k
:
int
,
k
:
int
,
stage_times
:
Tuple
[
float
,
float
,
float
],
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
"""Given the accepted token ids, create a list of SamplerOutput.
"""Given the accepted token ids, create a list of SamplerOutput.
...
@@ -722,8 +740,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -722,8 +740,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
maybe_rejsample_metrics
is
not
None
:
if
maybe_rejsample_metrics
is
not
None
:
sampler_output_list
[
sampler_output_list
[
0
].
spec_decode_worker_metrics
=
maybe_rejsample_metrics
0
].
spec_decode_worker_metrics
=
maybe_rejsample_metrics
# Log time spent in each stage periodically.
# This is periodic because the rejection sampler emits metrics
# periodically.
self
.
_maybe_log_stage_times
(
*
stage_times
)
return
sampler_output_list
return
sampler_output_list
def
_maybe_log_stage_times
(
self
,
average_time_per_proposal_tok_ms
:
float
,
scoring_time_ms
:
float
,
verification_time_ms
:
float
)
->
None
:
"""Log the speculative stage times. If stat logging is disabled, do
nothing.
"""
if
self
.
_disable_log_stats
:
return
logger
.
info
(
"SpecDecodeWorker stage times: "
"average_time_per_proposal_tok_ms=%.02f "
"scoring_time_ms=%.02f verification_time_ms=%.02f"
,
average_time_per_proposal_tok_ms
,
scoring_time_ms
,
verification_time_ms
)
def
_create_dummy_logprob_lists
(
def
_create_dummy_logprob_lists
(
self
,
self
,
batch_size
:
int
,
batch_size
:
int
,
...
...
vllm/spec_decode/util.py
View file @
82a1b1a8
import
time
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs):
...
@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs):
yield
yield
finally
:
finally
:
torch
.
cuda
.
nvtx
.
range_pop
()
torch
.
cuda
.
nvtx
.
range_pop
()
class
Timer
:
"""Basic timer context manager for measuring CPU time.
"""
def
__enter__
(
self
):
self
.
start_time
=
time
.
time
()
return
self
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
self
.
end_time
=
time
.
time
()
self
.
elapsed_time_s
=
self
.
end_time
-
self
.
start_time
self
.
elapsed_time_ms
=
self
.
elapsed_time_s
*
1000
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment