# File: proposer_worker_base.py (2.11 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from typing import List, Optional, Set, Tuple
6

7
8
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
9
from vllm.spec_decode.interfaces import SpeculativeProposer
10
from vllm.worker.worker_base import LoRANotSupportedWorkerBase
11
12


13
class ProposerWorkerBase(LoRANotSupportedWorkerBase, SpeculativeProposer):
    """Interface for proposer workers.

    Subclasses are expected to implement ``sampler_output``. The two
    ``set_*`` hooks are optional: the versions defined here do nothing.
    """

    @abstractmethod
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # A set containing all sequence IDs that were assigned bonus tokens
        # in their last forward pass. This set is used to backfill the KV cache
        # with the key-value pairs of the penultimate token in the sequences.
        # This parameter is only used by the MultiStepWorker, which relies on
        # the KV cache for token generation. It is not used by workers that
        # do not utilize the KV cache.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """Produce sampler outputs for ``execute_model_req``.

        Args:
            execute_model_req: The request to generate proposals for.
            sample_len: NOTE(review): presumably the number of speculative
                steps/tokens to sample -- confirm against implementations.
            seq_ids_with_bonus_token_in_last_step: See the comment on the
                parameter in the signature.

        Returns:
            A 2-tuple of an optional list of ``SamplerOutput`` and a bool.
            NOTE(review): the meaning of the bool is not visible from this
            interface -- confirm against the concrete workers.

        Raises:
            NotImplementedError: Always; this base declaration has no
                implementation.
        """
        raise NotImplementedError

    def set_include_gpu_probs_tensor(self) -> None:
        """Implementation optional; this base version is a no-op."""

    def set_should_modify_greedy_probs_inplace(self) -> None:
        """Implementation optional; this base version is a no-op."""

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
    """Proposer worker which does not use a model with kvcache.

    The worker-interface methods below are stubbed out with trivial
    results.
    """

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Return an empty list; the request argument is ignored.

        get_spec_proposals is used to get the proposals.
        """
        no_outputs: List[SamplerOutput] = []
        return no_outputs

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Always raise ``NotImplementedError``.

        This is never called on the proposer, only the target model.
        """
        raise NotImplementedError

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Do nothing; both block counts are ignored."""
        return None

    def get_cache_block_size_bytes(self) -> int:
        """Return 0 as the cache block size in bytes."""
        return 0