sampling_metadata.py 22.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from array import array
5
from dataclasses import dataclass
6
from typing import Optional
7
8
9

import torch

10
from vllm.sampling_params import SamplingParams, SamplingType
11
12
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
                           SequenceGroupMetadata)
13
from vllm.utils import (PyObjectCache, async_tensor_h2d,
14
                        is_pin_memory_available, make_tensor_with_pad)
15
16

_SAMPLING_EPS = 1e-5
17
18


19
20
@dataclass
class SequenceGroupToSample:
    """Sampling-related state for one sequence group in the current step.

    Length bookkeeping across two iterations, where ``newTokens`` are the
    tokens computed in iteration N::

        |---------- N-1 iteration --------|
        |---------------- N iteration ---------------------|
        |- tokenA -|......................|-- newTokens ---|
        |---------- context_len ----------|
        |-------------------- seq_len ----------------------|
                                          |-- query_len ---|
    """

    # Sequence ids for the sequence group in a previous step.
    seq_ids: list[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: dict[int, SequenceData]
    # Length of the sequence (all tokens seen so far plus the new tokens to
    # compute attention for). Required for prompt groups (checked in
    # __post_init__); may be None for groups in the decode stage.
    seq_len: Optional[int]
    # Number of new query tokens computed in this step. Required for prompt
    # groups. query_len <= seq_len when chunked prefill is enabled.
    query_len: Optional[int]
    # Random number generator used for seeded sampling, if any.
    generator: Optional[torch.Generator]
    # True while the group is in the prefill (prompt) stage, False in decode.
    is_prompt: bool
    # Row indices into the pruned logits used to compute prompt logprobs.
    # Empty if prompt logprobs are not required.
    prompt_logprob_indices: list[int]
    # Row indices into the pruned logits to sample from. Empty if sampling is
    # not required.
    sample_indices: list[int]

    @property
    def do_sample(self):
        """Whether this group has at least one row to sample."""
        return bool(self.sample_indices)

    def __post_init__(self):
        # Prompt-logprob rows can only exist if prompt logprobs were requested.
        if self.prompt_logprob_indices:
            assert self.sampling_params.prompt_logprobs is not None
        # A prefill group must know both its sequence and query lengths.
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None


def gen_seq_group_to_sample_builder(num_seqs: int):
    """Return a zero-argument factory of placeholder SequenceGroupToSample.

    Each call of the factory builds a new object with ``num_seqs`` zeroed
    sequence-id slots, empty index lists and placeholder values for the
    remaining fields; code that reuses the object fills those in later.
    """

    def _build_placeholder():
        return SequenceGroupToSample(
            seq_ids=[0] * num_seqs,
            sampling_params=None,
            seq_data=None,  # type: ignore
            seq_len=0,
            query_len=0,
            generator=None,
            is_prompt=True,
            prompt_logprob_indices=[],
            sample_indices=[],
        )

    return _build_placeholder


class SamplingMetadataCache:
    """Caches reusable SequenceGroupToSample objects between scheduler
    iterations, keeping one PyObjectCache per sequence count."""

    def __init__(self):
        # num_seqs -> PyObjectCache whose objects have num_seqs seq-id slots.
        self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        """Return a cached SequenceGroupToSample sized for ``num_seqs``.

        The PyObjectCache for a given ``num_seqs`` is created lazily on the
        first request for that size.
        """
        object_cache = self._seq_group_to_sample_cache.get(num_seqs)
        if object_cache is None:
            object_cache = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))
            self._seq_group_to_sample_cache[num_seqs] = object_cache
        return object_cache.get_object()

    def reset(self):
        """Call ``reset()`` on every per-size PyObjectCache."""
        for object_cache in self._seq_group_to_sample_cache.values():
            object_cache.reset()


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follows:
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)

    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each value is an integer tensor (built by ``prepare`` from a flat
            list of indices) of row indices into the pruned logits, i.e. the
            model output after indexing with selected_token_indices.
            For example, if the model output is [1, 2, 3] and
            selected_token_indices is [1, 2], the pruned logits are [2, 3],
            and sampling from both rows uses indices [0, 1].
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.
    """

    def __init__(
        self,
        seq_groups: list[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: list[SequenceGroupMetadata],
        seq_lens: list[int],
        query_lens: list[int],
        device: str,
        pin_memory: bool,
        generators: Optional[dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None,
    ) -> "SamplingMetadata":
        """Build sampling metadata for a batch of sequence groups.

        The index lists are computed on the host by ``_prepare_seq_groups``
        and copied to ``device`` with ``async_tensor_h2d``: the selected
        token indices as ``torch.long`` and each per-sampling-type list of
        sample indices as ``torch.int``.

        Args:
            seq_group_metadata_list: Sequence groups in this batch.
            seq_lens: Per-group sequence lengths.
            query_lens: Per-group query lengths.
            device: Target device for the index tensors; also used to create
                generators for seeded prompt groups.
            pin_memory: Passed through to ``async_tensor_h2d``.
            generators: Per-request generators for seeded requests.
            cache: Optional cache of reusable ``SequenceGroupToSample``
                objects.

        Returns:
            A ``SamplingMetadata``; ``skip_sampler_cpu_output`` and
            ``reuse_sampling_tensors`` keep their default values.
        """
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators, cache)
        selected_token_indices = async_tensor_h2d(
            selected_token_indices,
            dtype=torch.long,
            target_device=device,
            pin_memory=pin_memory,
        )
        # Values are indices into the pruned logits, not sequence ids.
        categorized_sample_indices = {
            sampling_type:
            async_tensor_h2d(
                sample_indices,
                dtype=torch.int,
                target_device=device,
                pin_memory=pin_memory,
            )
            for sampling_type, sample_indices in
            categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices})")


def _prepare_seq_groups(
    seq_group_metadata_list: list[SequenceGroupMetadata],
    seq_lens: list[int],
    query_lens: list[int],
    device: str,
    generators: Optional[dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> tuple[
        list[SequenceGroupToSample],
        list[int],
        dict[SamplingType, list[int]],
        int,
]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: The sequence groups in this batch.
        seq_lens: Per-group sequence lengths; entry ``i`` belongs to
            ``seq_group_metadata_list[i]``. Only read for prompt groups.
        query_lens: Per-group query lengths, indexed like ``seq_lens``. For a
            prompt group this may be shorter than the full prompt (chunked
            prefill). For a decode group, a None or empty list means a query
            length of 1.
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests, keyed by request id. Seeded prompt groups
            write a freshly seeded generator into it; seeded decode groups
            read from it.
        cache: Optional ``SamplingMetadataCache``. When given, the returned
            ``SequenceGroupToSample`` objects come from this cache and are
            updated in place, and ``cache.reset()`` is called before
            returning (presumably so they can be reused by a later call —
            depends on PyObjectCache semantics).

    Returns:
        seq_groups: A list of sequence group to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: list[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the outcome logits from the model for the performance.
    selected_token_indices: list[int] = []
    # Running position within the (unpruned) model output; used for
    # selected_token_indices.
    model_output_idx = 0

    # Sampling type -> indices to sample, expressed as row indices into the
    # pruned logits (the model output after indexing with
    # selected_token_indices).
    categorized_sample_indices: dict[SamplingType, list[int]] = {
        t: []
        for t in SamplingType
    }
    # Running position within the pruned logits, which hold both prompt
    # logprob rows and sample rows.
    logit_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        prompt_logprob_indices: list[int]
        sample_indices: list[int]
        if cache is not None:
            # Reuse a cached object sized for this group; its index lists are
            # cleared and then filled in place below.
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()
            prompt_logprob_indices = sample_obj.prompt_logprob_indices
            sample_indices = sample_obj.sample_indices
        else:
            prompt_logprob_indices = []
            sample_indices = []

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # seq_len stays None for decode groups; query_len is assigned in both
        # branches below.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            query_len = query_lens[i] if query_lens is not None and len(
                query_lens) > 0 else 1
            sample_len = len(seq_ids) * query_len if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Rows of the model output to keep:
        #     hidden_states = model(...)
        #     logits = hidden_states[selected_token_indices]
        # Prompt-logprob rows are kept only if prompt logprobs were requested,
        # but the model output position advances past them either way.
        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # Row indices within the pruned logits:
        #     prompt_logprob_indices -> rows used for prompt logprobs,
        #     sample_indices -> rows to sample from (also grouped by
        #     sampling type in categorized_sample_indices).
        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_range = range(logit_idx, logit_idx + sample_len)
            sample_indices.extend(sample_range)
            categorized_sample_indices[sampling_params.sampling_type].extend(
                sample_range)
            logit_idx += sample_len

        if cache is not None:
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            # The index lists were created fresh for this iteration, so they
            # can be handed over without a defensive copy.
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=prompt_logprob_indices,
                sample_indices=sample_indices,
            )

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)


@dataclass
class SamplingTensors:
    """Device tensors holding per-row sampling parameters.

    Each parameter tensor has one entry per row of the pruned logits: for
    every sequence group, its prompt-logprob rows (if any) followed by its
    sample rows. ``prompt_tokens`` / ``output_tokens`` hold token ids padded
    with ``vocab_size`` for the penalty computation, or are empty tensors
    when no group needs penalties.
    """

    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple["SamplingTensors", bool, bool, bool]:
        """Collect per-row sampling parameters from ``sampling_metadata``.

        Returns:
            ``(sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)``;
            each flag is True if at least one sequence group needs that
            feature.
        """
        temperatures: list[float] = []
        top_ps: list[float] = []
        top_ks: list[int] = []
        min_ps: list[float] = []
        presence_penalties: list[float] = []
        frequency_penalties: list[float] = []
        repetition_penalties: list[float] = []
        prompt_tokens: list[array] = []
        output_tokens: list[array] = []

        def append_rows(count, temperature, top_p, top_k, min_p, presence,
                        frequency, repetition):
            # Repeat one group's parameters for `count` logit rows.
            temperatures.extend([temperature] * count)
            top_ps.extend([top_p] * count)
            top_ks.extend([top_k] * count)
            min_ps.extend([min_p] * count)
            presence_penalties.extend([presence] * count)
            frequency_penalties.extend([frequency] * count)
            repetition_penalties.extend([repetition] * count)

        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False

        assert sampling_metadata.seq_groups is not None
        for group in sampling_metadata.seq_groups:
            params = group.sampling_params
            temperature = params.temperature
            p = params.presence_penalty
            f = params.frequency_penalty
            r = params.repetition_penalty
            top_p = params.top_p
            min_p = params.min_p

            # k is capped at the vocab size; k < 1 means "whole vocabulary".
            top_k = min(params.top_k, vocab_size)
            if top_k < 1:
                top_k = vocab_size

            # NOTE: Zero temperature means deterministic sampling (i.e.,
            # greedy sampling or beam search). Use 1 to avoid dividing by 0.
            if temperature < _SAMPLING_EPS:
                temperature = 1.0

            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True

            if group.is_prompt and params.prompt_logprobs is not None:
                # Rows that only need prompt logprobs get neutral penalties.
                assert group.query_len is not None
                append_rows(len(group.prompt_logprob_indices), temperature,
                            top_p, top_k, min_p, 0, 0, 1)

            if group.do_sample:
                num_sample_rows = len(group.sample_indices)
                assert num_sample_rows >= len(group.seq_ids)
                append_rows(num_sample_rows, temperature, top_p, top_k, min_p,
                            p, f, r)

        if do_penalties:
            for group in sampling_metadata.seq_groups:
                params = group.sampling_params
                if group.is_prompt and params.prompt_logprobs is not None:
                    # Prompt-logprob rows get empty token histories.
                    num_prompt_rows = len(group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(num_prompt_rows))
                    output_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(num_prompt_rows))
                if group.do_sample:
                    # NOTE(review): one token row per seq_id, while the
                    # parameter lists above get len(sample_indices) rows
                    # (asserted >= len(seq_ids)); TODO confirm the two never
                    # differ when penalties are active.
                    for seq_id in group.seq_ids:
                        data = group.seq_data[seq_id]
                        prompt_tokens.append(data.prompt_token_ids_array)
                        output_tokens.append(data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures,
            top_ps,
            top_ks,
            min_ps,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            prompt_tokens,
            output_tokens,
            vocab_size,
            device,
            dtype,
        )
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)

    @classmethod
    def from_lists(
        cls,
        temperatures: list[float],
        top_ps: list[float],
        top_ks: list[int],
        min_ps: list[float],
        presence_penalties: list[float],
        frequency_penalties: list[float],
        repetition_penalties: list[float],
        prompt_tokens: list[array],
        output_tokens: list[array],
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "SamplingTensors":
        """Build host tensors from Python lists and copy them to ``device``.

        Parameter lists become float32 tensors (``top_ks`` becomes int).
        Token-id arrays are padded with ``vocab_size`` into int64 tensors
        when any are given; otherwise empty long tensors are used. ``dtype``
        is accepted for interface compatibility but not used here.
        """
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()
        f32 = torch.float32

        def host_tensor(values, tensor_dtype: torch.dtype) -> torch.Tensor:
            # Staged in (possibly pinned) CPU memory before the device copy.
            return torch.tensor(
                values,
                device="cpu",
                dtype=tensor_dtype,
                pin_memory=pin_memory,
            )

        def padded_tokens(token_arrays: list[array]) -> torch.Tensor:
            # vocab_size is not a valid token id, so it serves as padding.
            return make_tensor_with_pad(
                token_arrays,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )

        def to_device(host: torch.Tensor) -> torch.Tensor:
            # Because the memory is pinned, the transfer can be non-blocking.
            return host.to(device=device, non_blocking=True)

        if prompt_tokens or output_tokens:
            prompt_t = padded_tokens(prompt_tokens)
            output_t = padded_tokens(output_tokens)
        else:
            prompt_t = output_t = torch.empty(0,
                                              device=device,
                                              dtype=torch.long)

        temperatures_t = host_tensor(temperatures, f32)
        top_ps_t = host_tensor(top_ps, f32)
        min_ps_t = host_tensor(min_ps, f32)
        presence_penalties_t = host_tensor(presence_penalties, f32)
        frequency_penalties_t = host_tensor(frequency_penalties, f32)
        repetition_penalties_t = host_tensor(repetition_penalties, f32)
        top_ks_t = host_tensor(top_ks, torch.int)

        return cls(
            temperatures=to_device(temperatures_t),
            top_ps=to_device(top_ps_t),
            top_ks=to_device(top_ks_t),
            min_ps=to_device(min_ps_t),
            presence_penalties=to_device(presence_penalties_t),
            frequency_penalties=to_device(frequency_penalties_t),
            repetition_penalties=to_device(repetition_penalties_t),
            prompt_tokens=to_device(prompt_t),
            output_tokens=to_device(output_t),
        )