Unverified Commit 82a1b1a8 authored by Cade Daniel's avatar Cade Daniel Committed by GitHub
Browse files

[Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963)

parent c0d8f163
...@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, ...@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
target_worker = mock_worker() target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(
draft_worker, target_worker, draft_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop' exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
...@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int, ...@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(
draft_worker, target_worker, draft_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
vocab_size = 32_000 vocab_size = 32_000
...@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, ...@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, worker = SpecDecodeWorker(draft_worker,
metrics_collector) target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
proposal_token_ids = torch.randint(low=0, proposal_token_ids = torch.randint(low=0,
...@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int, ...@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, worker = SpecDecodeWorker(draft_worker,
metrics_collector) target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
proposal_token_ids = torch.randint(low=0, proposal_token_ids = torch.randint(low=0,
...@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int, ...@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(
draft_worker, target_worker, proposer_worker=draft_worker,
mock_spec_decode_sampler(acceptance_sampler_method), False, scorer_worker=target_worker,
metrics_collector) spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size, seq_group_metadata_list, _, _ = create_batch(batch_size,
k, k,
...@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int, ...@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(
draft_worker, target_worker, proposer_worker=draft_worker,
mock_spec_decode_sampler(acceptance_sampler_method), False, scorer_worker=target_worker,
metrics_collector) spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size, seq_group_metadata_list, _, _ = create_batch(batch_size,
k, k,
...@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str): ...@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, worker = SpecDecodeWorker(
False, metrics_collector) proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector,
)
worker.init_device() worker.init_device()
draft_worker.init_device.assert_called_once() draft_worker.init_device.assert_called_once()
...@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method): ...@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
target_worker = mock_worker() target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(proposer_worker=draft_worker,
draft_worker, target_worker, scorer_worker=target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector)
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs) worker.initialize_cache(**kwargs)
...@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens(): ...@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
accepted_token_ids=accepted_token_ids, accepted_token_ids=accepted_token_ids,
target_logprobs=target_token_logprobs, target_logprobs=target_token_logprobs,
k=k) k=k,
stage_times=(0, 0, 0))
# Verify that _seq_with_bonus_token_in_last_step contains the following: # Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in # 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current # _seq_with_bonus_token_in_last_step but were not part of the current
......
...@@ -907,6 +907,7 @@ class SpeculativeConfig: ...@@ -907,6 +907,7 @@ class SpeculativeConfig:
speculative_max_model_len: Optional[int], speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
use_v2_block_manager: bool, use_v2_block_manager: bool,
disable_log_stats: bool,
speculative_disable_by_batch_size: Optional[int], speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int], ngram_prompt_lookup_min: Optional[int],
...@@ -1095,7 +1096,8 @@ class SpeculativeConfig: ...@@ -1095,7 +1096,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\ typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
) )
@staticmethod @staticmethod
...@@ -1189,6 +1191,7 @@ class SpeculativeConfig: ...@@ -1189,6 +1191,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float, typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool, disable_logprobs: bool,
disable_log_stats: bool,
): ):
"""Create a SpeculativeConfig object. """Create a SpeculativeConfig object.
...@@ -1221,6 +1224,8 @@ class SpeculativeConfig: ...@@ -1221,6 +1224,8 @@ class SpeculativeConfig:
sampling, target sampling, and after accepted tokens are sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be determined. If set to False, log probabilities will be
returned. returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
""" """
self.draft_model_config = draft_model_config self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config self.draft_parallel_config = draft_parallel_config
...@@ -1235,6 +1240,7 @@ class SpeculativeConfig: ...@@ -1235,6 +1240,7 @@ class SpeculativeConfig:
self.typical_acceptance_sampler_posterior_alpha = \ self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats
self._verify_args() self._verify_args()
......
...@@ -792,6 +792,7 @@ class EngineArgs: ...@@ -792,6 +792,7 @@ class EngineArgs:
speculative_max_model_len=self.speculative_max_model_len, speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager, use_v2_block_manager=self.use_v2_block_manager,
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\ draft_token_acceptance_method=\
......
...@@ -27,7 +27,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker ...@@ -27,7 +27,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (create_sequence_group_output, from vllm.spec_decode.util import (Timer, create_sequence_group_output,
get_all_num_logprobs, get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range, get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len) split_batch_by_proposal_len)
...@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config. typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs) disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
)
return spec_decode_worker return spec_decode_worker
...@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float, typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool, disable_logprobs: bool,
disable_log_stats: bool,
) -> "SpecDecodeWorker": ) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True allow_zero_draft_token_step = True
...@@ -171,6 +174,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -171,6 +174,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker, proposer_worker,
scorer_worker, scorer_worker,
disable_logprobs=disable_logprobs, disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size, disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler, spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step) allow_zero_draft_token_step=allow_zero_draft_token_step)
...@@ -180,7 +184,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -180,7 +184,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker: ProposerWorkerBase, proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase, scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler, spec_decode_sampler: SpecDecodeBaseSampler,
disable_logprobs: bool, disable_logprobs: bool = False,
disable_log_stats: bool = False,
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None, disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True, allow_zero_draft_token_step: Optional[bool] = True,
...@@ -203,6 +208,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -203,6 +208,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_logprobs: If set to True, token log probabilities will disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker. not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both. If set to False, log probabilities will be output by both.
disable_log_stats: If set to True, disable periodic printing of
speculative stage times.
disable_by_batch_size: If the batch size is larger than this, disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests. disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set metrics_collector: Helper class for collecting metrics; can be set
...@@ -240,6 +247,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -240,6 +247,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# in the subsequent step. # in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None self.previous_hidden_states: Optional[HiddenStates] = None
self._disable_logprobs = disable_logprobs self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize both scorer and proposer models. """Initialize both scorer and proposer models.
...@@ -525,28 +533,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -525,28 +533,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req.previous_hidden_states = self.previous_hidden_states execute_model_req.previous_hidden_states = self.previous_hidden_states
self.previous_hidden_states = None self.previous_hidden_states = None
# Generate proposals using draft worker. with Timer() as proposal_timer:
proposals = self.proposer_worker.get_spec_proposals( # Generate proposals using draft worker.
execute_model_req, self._seq_with_bonus_token_in_last_step) proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)
if not self._allow_zero_draft_token_step and proposals.no_proposals: if not self._allow_zero_draft_token_step and proposals.no_proposals:
#TODO: Fix it #5814 #TODO: Fix it #5814
raise RuntimeError("Cannot handle cases where distributed draft " raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens") "workers generate no tokens")
proposal_scores = self.scorer.score_proposals( with Timer() as scoring_timer:
execute_model_req, proposal_scores = self.scorer.score_proposals(
proposals, execute_model_req,
) proposals,
accepted_token_ids, target_logprobs = self._verify_tokens( )
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots) with Timer() as verification_timer:
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
scoring_timer.elapsed_time_ms,
verification_timer.elapsed_time_ms)
return self._create_output_sampler_list( return self._create_output_sampler_list(
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
accepted_token_ids, accepted_token_ids,
target_logprobs=target_logprobs, target_logprobs=target_logprobs,
k=execute_model_req.num_lookahead_slots) k=execute_model_req.num_lookahead_slots,
stage_times=stage_times)
@nvtx_range("spec_decode_worker._verify_tokens") @nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens( def _verify_tokens(
...@@ -645,6 +662,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -645,6 +662,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
k: int, k: int,
stage_times: Tuple[float, float, float],
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput. """Given the accepted token ids, create a list of SamplerOutput.
...@@ -722,8 +740,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -722,8 +740,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if maybe_rejsample_metrics is not None: if maybe_rejsample_metrics is not None:
sampler_output_list[ sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics 0].spec_decode_worker_metrics = maybe_rejsample_metrics
# Log time spent in each stage periodically.
# This is periodic because the rejection sampler emits metrics
# periodically.
self._maybe_log_stage_times(*stage_times)
return sampler_output_list return sampler_output_list
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
scoring_time_ms: float,
verification_time_ms: float) -> None:
"""Log the speculative stage times. If stat logging is disabled, do
nothing.
"""
if self._disable_log_stats:
return
logger.info(
"SpecDecodeWorker stage times: "
"average_time_per_proposal_tok_ms=%.02f "
"scoring_time_ms=%.02f verification_time_ms=%.02f",
average_time_per_proposal_tok_ms, scoring_time_ms,
verification_time_ms)
def _create_dummy_logprob_lists( def _create_dummy_logprob_lists(
self, self,
batch_size: int, batch_size: int,
......
import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs): ...@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs):
yield yield
finally: finally:
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
class Timer:
"""Basic timer context manager for measuring CPU time.
"""
def __enter__(self):
self.start_time = time.time()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.time()
self.elapsed_time_s = self.end_time - self.start_time
self.elapsed_time_ms = self.elapsed_time_s * 1000
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment