beam_search.py 5.03 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import dataclass
5

6
7
from vllm.inputs import EncoderDecoderInputs, TokenInputs, token_inputs
from vllm.inputs.data import DecoderInputs
8
from vllm.logprobs import Logprob
9
from vllm.lora.request import LoRARequest
10
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
11

12
13
14
15
16
17
18
19

@dataclass
class BeamSearchSequence:
    """A single candidate sequence tracked during beam search.

    Stores the tokens accumulated so far together with their log
    probabilities. The ``text`` field is optional and is only filled in
    when the sequence is about to be returned to the user.
    """

    orig_prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs

    # NOTE: For encoder/decoder prompts, `tokens` holds the decoder tokens.
    tokens: list[int]
    logprobs: list[dict[int, Logprob]]
    lora_request: LoRARequest | None = None
    cum_logprob: float = 0.0
    text: str | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None

    def get_prompt(self):
        """Rebuild the original prompt around this sequence's tokens.

        Encoder/decoder prompts are delegated to
        ``_build_encoder_decoder_inputs``. Otherwise a ``"token"`` prompt
        is rebuilt with ``token_inputs`` and any other prompt type with
        ``mm_inputs``, carrying over the original ``prompt`` text and
        ``cache_salt`` when present.
        """
        source = self.orig_prompt
        kind = source["type"]

        if kind == "enc_dec":
            return self._build_encoder_decoder_inputs(source)

        # Decoder-only inputs: these two values are carried over unchanged.
        carried = {
            "prompt": source.get("prompt"),
            "cache_salt": source.get("cache_salt"),
        }

        if kind != "token":
            return mm_inputs(
                prompt_token_ids=self.tokens,
                mm_kwargs=source["mm_kwargs"],
                mm_hashes=source["mm_hashes"],
                mm_placeholders=source["mm_placeholders"],
                **carried,
            )

        return token_inputs(self.tokens, **carried)

    def _build_encoder_decoder_inputs(
        self, prompt: EncoderDecoderInputs
    ) -> EncoderDecoderInputs:
        """Rebuild encoder-decoder inputs around this sequence's tokens.

        Only the decoder prompt is replaced; the encoder prompt is passed
        through unchanged.

        FIXME (alex) - the encoder multimodal cache is not properly wired up
        yet, which means that currently we are running the encoder on every
        new beam because num_computed_tokens is 0 on each new request. This
        will be fixed once the cache is correctly implemented.
        """
        old_decoder = prompt["decoder_prompt"]
        decoder_kind = old_decoder["type"]

        # Everything except the token ids is copied from the old decoder
        # prompt.
        carried = {
            "prompt": old_decoder.get("prompt"),
            "cache_salt": old_decoder.get("cache_salt"),
        }

        rebuilt: DecoderInputs
        if decoder_kind != "multimodal":
            rebuilt = token_inputs(self.tokens, **carried)
        else:
            rebuilt = mm_inputs(
                self.tokens,
                mm_kwargs=old_decoder["mm_kwargs"],
                mm_hashes=old_decoder["mm_hashes"],
                mm_placeholders=old_decoder["mm_placeholders"],
                **carried,
            )

        return EncoderDecoderInputs(
            type="enc_dec",
            encoder_prompt=prompt["encoder_prompt"],
            decoder_prompt=rebuilt,
        )

96
97
98
99
100
101
102

@dataclass
class BeamSearchOutput:
    """Result container for a beam search run.

    Holds the best beam search sequences; the list has as many entries
    as the beam width.
    """

    sequences: list[BeamSearchSequence]
105
106
107


class BeamSearchInstance:
    """Beam search state for one prompt: current beams and completed ones."""

    def __init__(
        self,
        prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs,
        lora_request: LoRARequest | None = None,
        logprobs: list[dict[int, Logprob]] | None = None,
        **kwargs,
    ):
        """Seed the search with a single beam built from ``prompt``.

        Args:
            prompt: The original inputs. For an ``"enc_dec"`` prompt the
                starting tokens come from its ``decoder_prompt``;
                otherwise from the prompt itself.
            lora_request: Forwarded to the initial beam.
            logprobs: Optional starting log probabilities; copied into a
                new list. ``None`` yields an empty list.
            **kwargs: Extra fields forwarded to ``BeamSearchSequence``.
        """
        if prompt["type"] == "enc_dec":
            decoder_side = prompt["decoder_prompt"]
        else:
            decoder_side = prompt
        start_tokens = decoder_side["prompt_token_ids"]

        if logprobs is not None:
            start_logprobs = list(logprobs)
        else:
            start_logprobs = []

        root_beam = BeamSearchSequence(
            orig_prompt=prompt,
            tokens=start_tokens,
            logprobs=start_logprobs,
            lora_request=lora_request,
            **kwargs,
        )

        self.beams: list[BeamSearchSequence] = [root_beam]
        self.completed: list[BeamSearchSequence] = []
130
131
132


def get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    The score is ``cumulative_logprob / (seq_len ** length_penalty)``,
    where ``seq_len`` is the number of tokens, not counting a trailing
    EOS token.

    Adapted from

    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

    Args:
        tokens: Token ids of the sequence being scored.
        cumulative_logprob: Sum of the log probabilities of the sequence.
        eos_token_id: Token id that marks end-of-sequence.
        length_penalty: Exponent applied to the sequence length.

    Returns:
        The length-normalised score. Empty and EOS-only sequences are
        scored as if they had length 1 instead of raising.
    """
    seq_len = len(tokens)
    # Check for emptiness first so `tokens[-1]` cannot raise IndexError.
    if tokens and tokens[-1] == eos_token_id:
        seq_len -= 1

    # An empty or EOS-only sequence would give seq_len == 0 and a
    # ZeroDivisionError below; treat it as length 1.
    seq_len = max(seq_len, 1)

    return cumulative_logprob / (seq_len**length_penalty)


def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    """Build a sort key that scores beams with ``get_beam_search_score``.

    The returned callable captures ``eos_token_id`` and ``length_penalty``
    and maps a ``BeamSearchSequence`` to its score, computed from the
    beam's ``tokens`` and ``cum_logprob``.
    """

    def sort_beams_key(x: BeamSearchSequence) -> float:
        score = get_beam_search_score(
            tokens=x.tokens,
            cumulative_logprob=x.cum_logprob,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
        )
        return score

    return sort_beams_key