Unverified Commit 57b7be0e authored by William Lin's avatar William Lin Committed by GitHub
Browse files

[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971)

parent 99b4cf5f
import itertools import itertools
import random import random
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from unittest.mock import patch from unittest.mock import Mock, patch
import pytest import pytest
import torch import torch
...@@ -703,3 +703,28 @@ def test_sampler_repetition_penalty_mixed(device: str): ...@@ -703,3 +703,28 @@ def test_sampler_repetition_penalty_mixed(device: str):
assert tokens1[0] == tokens2[1] assert tokens1[0] == tokens2[1]
assert tokens1[1] == tokens2[0] assert tokens1[1] == tokens2[0]
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
set_random_seed(42)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
sampler.include_gpu_probs_tensor = True
sampler.should_modify_greedy_probs_inplace = False
sampling_params = SamplingParams(temperature=0)
mock_inplace = Mock()
with patch(
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
mock_inplace):
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
mock_inplace.assert_not_called()
assert sampler_output.sampled_token_probs is not None
assert sampler_output.logprobs is not None
assert sampler_output.sampled_token_ids is not None
...@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def include_gpu_probs_tensor(self): def include_gpu_probs_tensor(self):
return self.base_layer.include_gpu_probs_tensor return self.base_layer.include_gpu_probs_tensor
@property
def should_modify_greedy_probs_inplace(self):
return self.base_layer.should_modify_greedy_probs_inplace
def create_lora_weights( def create_lora_weights(
self, self,
max_loras: int, max_loras: int,
......
...@@ -51,6 +51,7 @@ class Sampler(nn.Module): ...@@ -51,6 +51,7 @@ class Sampler(nn.Module):
# containing the sampled token ids and probabilities. This is used by # containing the sampled token ids and probabilities. This is used by
# speculative decoding. # speculative decoding.
self.include_gpu_probs_tensor = False self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
def _init_sampling_tensors( def _init_sampling_tensors(
self, self,
...@@ -177,8 +178,7 @@ class Sampler(nn.Module): ...@@ -177,8 +178,7 @@ class Sampler(nn.Module):
This is used by speculative decoding, which requires that the sampling This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution. method be encoded into the probability distribution.
""" """
# Modify greedy probs if include_gpu_probs_tensor is set. return self.should_modify_greedy_probs_inplace
return self.include_gpu_probs_tensor
def _get_bin_counts_and_mask( def _get_bin_counts_and_mask(
......
...@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker): ...@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
def set_include_gpu_probs_tensor(self): def set_include_gpu_probs_tensor(self):
pass pass
def set_should_modify_greedy_probs_inplace(self):
pass
@torch.inference_mode() @torch.inference_mode()
def sampler_output( def sampler_output(
self, self,
......
...@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase): ...@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Need include_gpu_probs_tensor for MultiStepWorker # Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.model.sampler.include_gpu_probs_tensor = True self.model_runner.model.sampler.include_gpu_probs_tensor = True
def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
True)
@torch.inference_mode() @torch.inference_mode()
def sampler_output( def sampler_output(
self, self,
......
...@@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer): ...@@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
"""Implementation optional""" """Implementation optional"""
pass pass
def set_should_modify_greedy_probs_inplace(self) -> None:
"""Implementation optional"""
pass
class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC): class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
"""Proposer worker which does not use a model with kvcache""" """Proposer worker which does not use a model with kvcache"""
......
...@@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase): ...@@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
# Need include_gpu_probs_tensor for multi_step_worker # Need include_gpu_probs_tensor for multi_step_worker
self._worker.set_include_gpu_probs_tensor() self._worker.set_include_gpu_probs_tensor()
def set_should_modify_greedy_probs_inplace(self) -> None:
if self._is_dummy:
return
self._worker.set_should_modify_greedy_probs_inplace()
def load_model(self) -> None: def load_model(self) -> None:
if self._is_dummy: if self._is_dummy:
return return
......
...@@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
""" """
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True ) = True
(self.scorer_worker.model_runner.model.sampler.
should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor() self.proposer_worker.set_include_gpu_probs_tensor()
self.proposer_worker.set_should_modify_greedy_probs_inplace()
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use. """Determine the number of cache blocks to use.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment