beam_search.py 5.01 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import dataclass
5

6
7
8
9
10
11
12
13
from vllm.inputs import (
    DecoderOnlyEngineInput,
    EncoderDecoderInput,
    MultiModalInput,
    TokensInput,
    mm_input,
    tokens_input,
)
14
from vllm.logprobs import Logprob
15
from vllm.lora.request import LoRARequest
16

17
18
19
20
21
22
23
24

@dataclass
class BeamSearchSequence:
    """A sequence for beam search.
    It keeps track of the tokens and the log probability of the sequence.
    The text field is optional and will only be filled when the sequence is
    about to be returned to the user.
    """

    # Prompt the search started from. Everything except its (decoder) token
    # ids is reused verbatim when ``get_prompt`` rebuilds it for this beam.
    orig_prompt: TokensInput | MultiModalInput | EncoderDecoderInput

    # NOTE: Tokens represents decoder tokens in the encoder / decoder case
    tokens: list[int]
    logprobs: list[dict[int, Logprob]]
    lora_request: LoRARequest | None = None
    # Cumulative log probability; passed as ``cumulative_logprob`` when the
    # beam is scored by ``create_sort_beams_key_function``.
    cum_logprob: float = 0.0
    text: str | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None

    def get_prompt(self) -> DecoderOnlyEngineInput | EncoderDecoderInput:
        """Return ``orig_prompt`` rebuilt around this beam's ``tokens``.

        The (decoder) token ids are replaced by ``self.tokens``. All other
        fields are carried over unchanged: prompt text, cache salt,
        multimodal data and, for encoder-decoder prompts, the encoder prompt.

        Raises:
            ValueError: If the (decoder) prompt type is neither ``"token"``
                nor ``"multimodal"``.
        """
        prompt = self.orig_prompt

        if prompt["type"] == "enc_dec":
            return self._build_encoder_decoder_inputs(prompt)

        # Handle decoder-only inputs
        return self._rebuild_decoder_prompt(prompt)

    def _rebuild_decoder_prompt(
        self, prompt: DecoderOnlyEngineInput
    ) -> DecoderOnlyEngineInput:
        """Copy a decoder-only prompt, swapping in ``self.tokens``.

        Shared by the decoder-only and encoder-decoder paths so both handle
        prompt types identically.

        Raises:
            ValueError: If ``prompt["type"]`` is neither ``"token"`` nor
                ``"multimodal"``. Silently rebuilding an unknown type would
                drop whatever type-specific fields it carries.
        """
        prompt_type = prompt["type"]
        prompt_text = prompt.get("prompt")
        cache_salt = prompt.get("cache_salt")

        if prompt_type == "token":
            return tokens_input(
                self.tokens,
                prompt=prompt_text,
                cache_salt=cache_salt,
            )

        if prompt_type == "multimodal":
            return mm_input(
                self.tokens,
                mm_kwargs=prompt["mm_kwargs"],
                mm_hashes=prompt["mm_hashes"],
                mm_placeholders=prompt["mm_placeholders"],
                prompt=prompt_text,
                cache_salt=cache_salt,
            )

        raise ValueError(
            f"Unsupported decoder prompt type for beam search: {prompt_type!r}"
        )

    def _build_encoder_decoder_inputs(
        self, prompt: EncoderDecoderInput
    ) -> EncoderDecoderInput:
        """Rebuild the encoder-decoder inputs with the current beam search
        sequence's tokens.

        FIXME (alex) - the encoder multimodal cache is not properly wired up
        yet, which means that currently we are running the encoder on every
        new beam because num_computed_tokens is 0 on each new request. This
        will be fixed once the cache is correctly implemented.
        """
        # Rebuild decoder prompt with updated tokens,
        # but keep everything else the same.
        new_dec_prompt = self._rebuild_decoder_prompt(prompt["decoder_prompt"])

        return EncoderDecoderInput(
            type="enc_dec",
            encoder_prompt=prompt["encoder_prompt"],
            decoder_prompt=new_dec_prompt,
        )

@dataclass
class BeamSearchOutput:
    """Result of a beam search: the best sequences that were found.

    Callers are expected to supply one sequence per beam, so the list length
    matches the beam width. This class does not enforce that.
    """

    # Final beams, in whatever order the producer of this output chose.
    sequences: list[BeamSearchSequence]


class BeamSearchInstance:
    """Per-prompt beam search state.

    ``beams`` starts with a single beam built from the prompt.
    ``completed`` starts empty (presumably filled with finished beams by the
    caller driving the search).
    """

    def __init__(
        self,
        prompt: TokensInput | MultiModalInput | EncoderDecoderInput,
        lora_request: LoRARequest | None = None,
        logprobs: list[dict[int, Logprob]] | None = None,
        **kwargs,
    ):
        """Create the initial beam for ``prompt``.

        Args:
            prompt: Prompt to search from. For encoder-decoder prompts the
                beam tracks the tokens of ``prompt["decoder_prompt"]``.
            lora_request: Forwarded to the initial beam.
            logprobs: Initial per-token logprobs. Copied, so the caller's
                list is not shared with the beam.
            **kwargs: Extra ``BeamSearchSequence`` fields for the initial
                beam (e.g. ``cum_logprob``).
        """
        # Beams only ever hold decoder tokens; the encoder prompt is kept
        # inside ``orig_prompt`` and reattached by ``get_prompt``.
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        # Copy (as done for ``logprobs``) so the beam does not alias the
        # caller's ``prompt_token_ids`` list.
        initial_tokens = list(decoder_prompt["prompt_token_ids"])

        self.beams: list[BeamSearchSequence] = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=initial_tokens,
                logprobs=[] if logprobs is None else list(logprobs),
                lora_request=lora_request,
                **kwargs,
            )
        ]
        self.completed: list[BeamSearchSequence] = []


def get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    The score is ``cumulative_logprob / seq_len ** length_penalty``. Here
    ``seq_len`` is ``len(tokens)``, minus one when the last token is
    ``eos_token_id`` (a trailing EOS does not count toward the length).

    Adapted from

    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

    Args:
        tokens: Token ids of the sequence being scored.
        cumulative_logprob: Summed log probability of the sequence.
        eos_token_id: End-of-sequence id. A trailing occurrence is excluded
            from the length.
        length_penalty: Exponent applied to the length before dividing.

    Returns:
        The length-normalized score. A zero effective length (empty
        ``tokens`` or a lone EOS token) is treated as length one, so
        ``cumulative_logprob`` is returned unchanged.
    """
    seq_len = len(tokens)
    # Guard the index so empty input does not raise IndexError.
    if seq_len and tokens[-1] == eos_token_id:
        seq_len -= 1

    # A zero length would divide by 0 ** length_penalty (ZeroDivisionError
    # for any non-zero penalty). Clamping to 1 matches the existing
    # length_penalty == 0 result, since 0 ** 0 == 1.
    seq_len = max(seq_len, 1)

    return cumulative_logprob / (seq_len**length_penalty)


def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    """Build a ``key=`` callable that ranks beams by length-penalized score.

    Args:
        eos_token_id: Forwarded to ``get_beam_search_score``.
        length_penalty: Forwarded to ``get_beam_search_score``.

    Returns:
        A one-argument function mapping a ``BeamSearchSequence`` to
        ``get_beam_search_score(beam.tokens, beam.cum_logprob, ...)``.
    """

    def _beam_score(beam: BeamSearchSequence) -> float:
        return get_beam_search_score(
            tokens=beam.tokens,
            cumulative_logprob=beam.cum_logprob,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
        )

    return _beam_score