sampler.py 3.28 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

6
import vllm.envs as envs
7
8
from vllm.config.model import LogprobsMode
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
9
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
10
11
12
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
13
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
14
from vllm.v1.worker.gpu.sample.output import SamplerOutput
15
from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature
16
17
18
19
20
21
22
23
24
25


class Sampler:
    """Turn logits into sampled token ids and, optionally, top-k logprobs.

    Sampling runs on an FP32 copy of the logits: penalties and temperature,
    then optional min_p, then top_k/top_p, then Gumbel sampling. Logprobs are
    produced only when ``sampling_metadata.max_num_logprobs`` is set, taken
    either from the caller's unmodified logits ("raw_logprobs") or from the
    processed logits that were sampled from ("processed_logprobs").
    """

    # Logprobs modes this sampler can produce; other LogprobsMode values are
    # rejected in __init__.
    _SUPPORTED_LOGPROBS_MODES = ("processed_logprobs", "raw_logprobs")

    def __init__(
        self,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        """Configure the sampler.

        Args:
            logprobs_mode: Which logits the returned logprobs are derived
                from. Only "raw_logprobs" and "processed_logprobs" are
                accepted.

        Raises:
            NotImplementedError: If ``logprobs_mode`` is any other value.
        """
        if logprobs_mode in self._SUPPORTED_LOGPROBS_MODES:
            self.logprobs_mode = logprobs_mode
        else:
            raise NotImplementedError(
                f"Unsupported logprobs_mode: {logprobs_mode}"
            )
        # NaN counting is opt-in through an environment flag (False by default).
        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS

    def __call__(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per request and package the results.

        Args:
            logits: Logits to sample from. They are not modified here, because
                ``sample`` works on its own FP32 copy.
            sampling_metadata: Sampling parameters for the batch.

        Returns:
            A ``SamplerOutput``. Its ``sampled_token_ids`` has shape
            ``[num_requests, 1]``. Its ``logprobs_tensors`` is ``None`` unless
            ``max_num_logprobs`` is set, and its ``num_nans`` is ``None``
            unless NaN counting is enabled.
        """
        # NOTE(woosuk): num_nans is computed before sampling on purpose, to make
        # it clear that it reflects the logits before penalties and temperature.
        num_nans = None
        if self.compute_nans:
            num_nans = get_num_nans(logits)

        sampled, processed_logits = self.sample(logits, sampling_metadata)

        logprobs_tensors = None
        max_num_logprobs = sampling_metadata.max_num_logprobs
        if max_num_logprobs is not None:
            # "processed_logprobs" uses the exact tensor the tokens were drawn
            # from. Otherwise the caller's logits are used; sample() leaves
            # them untouched.
            if self.logprobs_mode == "processed_logprobs":
                logprob_source = processed_logits
            else:
                logprob_source = logits
            logprobs_tensors = compute_topk_logprobs(
                logprob_source,
                max_num_logprobs,
                sampled,
            )

        # These are GPU tensors. Each request yields exactly one generated
        # token, so the sampled ids are reshaped into a [num_requests, 1]
        # column.
        return SamplerOutput(
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
            num_nans=num_nans,
        )

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the sampling transforms to a copy of ``logits`` and draw tokens.

        Args:
            logits: Input logits. They are left unmodified; an FP32 copy is
                processed instead.
            sampling_metadata: Supplies the penalty/temperature settings,
                min_p, top_k, top_p, seeds and positions.

        Returns:
            ``(sampled, processed_logits)``, where ``processed_logits`` is the
            FP32 tensor that was passed to ``gumbel_sample``.
        """
        # Fresh FP32 buffer, so the caller's logits stay intact for the
        # raw_logprobs path in __call__.
        processed = torch.empty_like(logits, dtype=torch.float32)
        processed.copy_(logits)

        # Penalties and temperature are applied in place.
        apply_penalties_and_temperature(processed, sampling_metadata)

        # min_p is optional and also applied in place.
        min_p = sampling_metadata.min_p
        if min_p is not None:
            apply_min_p(processed, min_p)

        # top_k and/or top_p may return a different tensor than the input.
        processed = apply_top_k_top_p(
            processed, sampling_metadata.top_k, sampling_metadata.top_p
        )

        # Temperature was already applied above, so gumbel_sample is told
        # not to apply it again.
        sampled = gumbel_sample(
            processed,
            sampling_metadata.temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
            apply_temperature=False,
        )
        return sampled, processed