target_model_runner.py 2.04 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import List, Optional
lizhigong's avatar
lizhigong committed
4
import torch
5
from vllm.sequence import SequenceGroupMetadata
6
7
8
from vllm.worker.model_runner_base import (ModelRunnerBase,
                                           ModelRunnerInputBase,
                                           ModelRunnerWrapperBase)
9
10


11
class TargetModelRunner(ModelRunnerWrapperBase):
    """Specialized model runner for the speculative decoding target model.

    In speculative decoding, the log probabilities that are finally reported
    may not be the same ones produced by the target model's own sampling,
    because log probabilities are calculated after deciding which tokens are
    accepted. Time spent computing log probabilities inside the target model
    is therefore wasted, so disabling them in the target model makes decoding
    faster. This runner sets the ``SamplingMetadata`` parameters according to
    whether log probabilities are requested or not.
    """

    def __init__(self, model_runner: ModelRunnerBase):
        """Wrap ``model_runner``.

        Args:
            model_runner: The underlying runner whose ``prepare_model_input``
                this wrapper delegates to (presumably stored as
                ``self.model_runner`` by the base class — it is read under
                that name below).
        """
        super().__init__(model_runner)
        # Internal flag indicating whether token log probabilities are NOT
        # needed from the target model's sampler. Defaults to True
        # (log probabilities disabled).
        self.disable_logprobs = True

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelRunnerInputBase:
        """Prepare model input via the wrapped runner and configure sampling.

        Args:
            seq_group_metadata_list: Sequence-group metadata for the batch;
                forwarded unchanged to the wrapped runner.
            virtual_engine: Virtual engine index; forwarded unchanged.
            finished_requests_ids: IDs of finished requests, if any;
                forwarded unchanged.

        Returns:
            The model input produced by the wrapped runner. When it carries
            sampling metadata, ``skip_sampler_cpu_output`` on that metadata is
            set to ``self.disable_logprobs``.
        """
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(seq_group_metadata_list,
                                                  virtual_engine,
                                                  finished_requests_ids))

        # If token log probabilities are disabled, skip generating the
        # sampler CPU output; the GPU sampled_token_id tensors are serialized
        # directly as needed. If log probabilities are enabled, synchronize
        # all sampling-related tensors, including the logprobs tensors.
        #
        # The wrapped runner may return no sampling metadata at all (vLLM's
        # GPU ModelRunner sets it to None on non-last pipeline-parallel
        # ranks — NOTE(review): confirm for every runner wrapped here). There
        # is nothing to configure in that case, so avoid dereferencing None.
        sampling_metadata = model_input.sampling_metadata
        if sampling_metadata is not None:
            sampling_metadata.skip_sampler_cpu_output = self.disable_logprobs
        return model_input