sample.py
from typing import Dict, List, Tuple

import torch
import torch.nn as nn

from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceOutputs
from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region


class Sampler(nn.Module):

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> Dict[int, SequenceOutputs]:
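        # Project the hidden states onto the vocabulary, apply the
        # per-sequence sampling parameters (temperature, top-p), and
        # sample the next token for every sequence.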
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        logits = gather_from_tensor_model_parallel_region(logits)
        # Remove the padding entries from the vocab dimension.
        logits = logits[:, :self.vocab_size]

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(
                temperatures, dtype=logits.dtype, device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities (before applying top-p).
        logprobs = torch.log(probs)

        # Apply top-p truncation.
        top_ps = _get_top_ps(input_metadata)
        assert len(top_ps) == probs.shape[0]
        if any(p < 1.0 for p in top_ps):
            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
            probs = _apply_top_p(probs, p)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
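    # Only the hidden states that feed a next-token distribution are needed:
    # the last token of each prompt and every generation token.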
    start_idx = 0
    last_token_indices: List[int] = []
    for prompt_len in input_metadata.prompt_lens:
        last_token_indices.append(start_idx + prompt_len - 1)
        start_idx += prompt_len
    last_token_indices.extend(
        range(start_idx, start_idx + input_metadata.num_generation_tokens))
    return hidden_states[last_token_indices]


def _get_temperatures(
    input_metadata: InputMetadata,
) -> List[float]:
    # Collect the temperatures for the logits.
    temperatures: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature == 0.0:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0

        if i < input_metadata.num_prompts:
            # A prompt input.
            temperatures.append(temperature)
        else:
            # A generation token.
            temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_ps(
    input_metadata: InputMetadata,
) -> List[float]:
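    # Collect one top-p value per row of probs: one per prompt and one per
    # generation token in each sequence group.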
    top_ps: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            # A prompt input.
            top_ps.append(sampling_params.top_p)
        else:
            # A generation token.
            top_ps += [sampling_params.top_p] * len(seq_ids)
    return top_ps


def _apply_top_p(
    probs: torch.Tensor,
    p: torch.Tensor,
) -> torch.Tensor:
    # TODO(woosuk): Optimize.
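    # Sort each row in descending order, zero out the tokens whose preceding
    # cumulative mass already exceeds p, renormalize, and scatter the
    # probabilities back to the original token order.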
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    probs = torch.gather(
        probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
    return probs


def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: int,
) -> Dict[int, float]:
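    # Return a mapping from token id to log probability for the
    # num_logprobs most likely tokens.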
    if num_logprobs == 0:
        return {}

    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
    if num_logprobs == 1:
        topk_logprobs = [topk_logprobs.item()]
        topk_ids = [topk_ids.item()]
    else:
        topk_logprobs = topk_logprobs.tolist()
        topk_ids = topk_ids.tolist()

    token_to_logprob: Dict[int, float] = {}
    for token_id, logprob in zip(topk_ids, topk_logprobs):
        token_to_logprob[token_id] = logprob
    return token_to_logprob


def _sample_from_prompt(
    prob: torch.Tensor,
    sampling_params: SamplingParams,
) -> List[int]:
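    # Beam search returns the beam_width most likely tokens, greedy sampling
    # returns the argmax, and nucleus sampling draws n tokens from the
    # (top-p truncated) distribution.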
    if sampling_params.use_beam_search:
        # Beam search.
        beam_width = sampling_params.n
        _, next_token_ids = torch.topk(prob, beam_width)
        next_token_ids = next_token_ids.tolist()
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert sampling_params.n == 1
        next_token_id = torch.argmax(prob)
        next_token_ids = [next_token_id.item()]
    else:
        # Nucleus sampling.
        # Sample n tokens for the prompt.
        n = sampling_params.n
        next_token_ids = torch.multinomial(
            prob, num_samples=n, replacement=True)
        next_token_ids = next_token_ids.tolist()
    return next_token_ids


def _sample_from_generation_tokens(
    seq_ids: List[int],
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    seq_logprobs: List[float],
    sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
    # NOTE(woosuk): sampling_params.n can be greater than
    # len(seq_ids) because some sequences in the group might have
    # been already terminated.
    if sampling_params.use_beam_search:
        # Beam search.
        # Add cumulative logprobs for the sequences in the group.
        seq_logprobs = torch.tensor(
            seq_logprobs, dtype=torch.float, device=logprobs.device)
        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

        vocab_size = logprobs.size(-1)
        beam_width = len(seq_ids)
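        # Pick the top beam_width (sequence, token) candidates over the
        # flattened (beam, vocab) score matrix.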
        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
        topk_ids = topk_ids.tolist()
        seq_idx = [i // vocab_size for i in topk_ids]
        beam_seq_ids = [seq_ids[i] for i in seq_idx]
        token_ids = [i % vocab_size for i in topk_ids]

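        # Map each sequence in the group to its (parent sequence, token):
        # beams that appear among the top candidates continue as themselves,
        # while discarded beams are re-forked from the surplus candidates so
        # the beam width stays constant.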
        beam_outputs: Dict[int, Tuple[int, int]] = {}
        outstanding_beams: List[Tuple[int, int]] = []
        # If a beam survives, continue with it.
        for seq_id, token_id in zip(beam_seq_ids, token_ids):
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = (seq_id, token_id)
            else:
                outstanding_beams.append((seq_id, token_id))

        # If a beam is discarded, fork another beam.
        for seq_id in seq_ids:
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = outstanding_beams.pop()
        assert not outstanding_beams

        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert len(seq_ids) == 1
        next_token_id = torch.argmax(probs, dim=-1)
        next_token_ids = [next_token_id.item()]
        parent_seq_ids = seq_ids
    else:
        # Nucleus sampling.
        # Sample 1 token for each sequence in the group.
        next_token_ids = torch.multinomial(
            probs, num_samples=1, replacement=True)
        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
        parent_seq_ids = seq_ids
    return parent_seq_ids, next_token_ids


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
    seq_outputs: Dict[int, SequenceOutputs] = {}

    # TODO(woosuk): Optimize.
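    # probs/logprobs have one row per prompt (its last token) followed by one
    # row per generation token, in the order of the sequence groups.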
    idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            # Generate the next tokens for a prompt input.
            assert len(seq_ids) == sampling_params.n
            prob = probs[idx]
            logprob = logprobs[idx]
            idx += 1

            # Sample the next tokens.
            next_token_ids = _sample_from_prompt(prob, sampling_params)
            # Get top-k log probabilities for the next tokens.
            next_logprobs = _get_topk_logprobs(
                logprob, sampling_params.num_logprobs)

            # Build the output.
            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                output_logprobs = next_logprobs.copy()
                output_logprobs[next_token_id] = logprob[next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id, seq_id, next_token_id, output_logprobs)
        else:
            # Generate the next tokens for generation tokens.
            prob = probs[idx:idx + len(seq_ids)]
            logprob = logprobs[idx:idx + len(seq_ids)]
            idx += len(seq_ids)

            # Sample the next tokens.
            seq_logprobs = [
                input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                seq_ids, prob, logprob, seq_logprobs, sampling_params)

            # Get top-k log probabilities for the next tokens.
            next_logprobs: Dict[int, Dict[int, float]] = {}
            for j, seq_id in enumerate(seq_ids):
                next_logprobs[seq_id] = _get_topk_logprobs(
                    logprob[j], sampling_params.num_logprobs)

            # Build the output.
            for seq_id, parent_seq_id, next_token_id in zip(
                seq_ids, parent_seq_ids, next_token_ids):
                parent_idx = seq_ids.index(parent_seq_id)
                output_logprobs = next_logprobs[parent_seq_id].copy()
                output_logprobs[next_token_id] = logprob[
                    parent_idx, next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id,
                    parent_seq_id,
                    next_token_id,
                    output_logprobs,
                )

    return seq_outputs