# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature, gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates


class Sampler:
    """Turn a batch of logits into one sampled token per request.

    Owns three per-request state containers:

    * ``SamplingStates``: temperature, min-p, top-k, top-p, seeds and the
      requested number of logprobs.
    * ``PenaltiesState``: penalty state, applied in place to the logits.
    * ``LogitBiasState``: logit bias (e.g. allowed_token_ids, min_tokens),
      applied in place to the logits.

    Optionally returns top-k logprobs and the NaN count of the raw logits.
    """

    # Logprobs modes this sampler can produce; any other ``LogprobsMode``
    # value is rejected in ``__init__``.
    _SUPPORTED_LOGPROBS_MODES = ("processed_logprobs", "raw_logprobs")

    def __init__(
        self,
        max_num_reqs: int,
        vocab_size: int,
        device: torch.device,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        """Create the sampler and its per-request state containers.

        Args:
            max_num_reqs: Maximum number of requests; forwarded to every
                state container to size its storage.
            vocab_size: Vocabulary size; forwarded to the sampling and
                penalty state.
            device: Device forwarded to the penalty and logit-bias state.
            logprobs_mode: Which logits the logprobs are computed from.
                ``"raw_logprobs"`` uses the logits exactly as passed to
                ``__call__``. ``"processed_logprobs"`` uses the logits after
                bias, penalties, temperature, min-p and top-k/top-p.

        Raises:
            NotImplementedError: If ``logprobs_mode`` is not one of the two
                supported modes.
        """
        if logprobs_mode not in self._SUPPORTED_LOGPROBS_MODES:
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode
        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS  # False by default.

        self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
        self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
        self.logit_bias_state = LogitBiasState(max_num_reqs, device)

    def add_request(
        self,
        req_idx: int,
        prompt_len: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Register ``sampling_params`` for the request stored in slot ``req_idx``.

        ``prompt_len`` is only needed by the logit-bias state. The state
        containers may stage these writes; they are flushed by
        ``apply_staged_writes``.
        """
        self.sampling_states.add_request(req_idx, sampling_params)
        self.penalties_state.add_request(req_idx, sampling_params)
        self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        """Flush pending writes of all state containers.

        The prefill token ids and lengths are consumed only by the penalty
        state (presumably to seed its token counts from the prompt; TODO
        confirm).
        """
        self.sampling_states.apply_staged_writes()
        self.penalties_state.apply_staged_writes(
            prefill_token_ids, prefill_lens, prompt_lens
        )
        self.logit_bias_state.apply_staged_writes()

    def __call__(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> SamplerOutput:
        """Sample one token per row of ``logits`` and optionally compute logprobs.

        Args:
            logits: Model logits, one row per request in the batch
                (presumably ``[num_reqs, vocab_size]``; TODO confirm). Not
                modified: ``sample`` works on an FP32 copy.
            idx_mapping: Device tensor used to gather each row's per-request
                sampling state (presumably row -> request slot index).
            idx_mapping_np: Host-side counterpart of ``idx_mapping``, used
                for batch-level decisions (min-p / top-k / top-p / logprobs).
            pos: Per-row positions, forwarded to logit bias and sampling.

        Returns:
            A ``SamplerOutput`` with sampled ids of shape ``[num_rows, 1]``,
            the top-k logprobs (or ``None`` if no request asked for them), and
            the NaN count (or ``None`` unless NaN counting is enabled).
        """
        # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
        # that num_nans is computed before applying penalties and temperature.
        num_nans = get_num_nans(logits) if self.compute_nans else None

        sampled, processed_logits = self.sample(
            logits, idx_mapping, idx_mapping_np, pos
        )

        max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
        if max_num_logprobs == NO_LOGPROBS:
            logprobs_tensors = None
        else:
            # Pick the logits the logprobs are derived from without rebinding
            # the ``logits`` argument.
            if self.logprobs_mode == "processed_logprobs":
                logprobs_source = processed_logits
            else:
                logprobs_source = logits
            logprobs_tensors = compute_topk_logprobs(
                logprobs_source, max_num_logprobs, sampled
            )

        # These are GPU tensors.
        return SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
            num_nans=num_nans,
        )

    def sample(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the logits-processing pipeline and draw one token per row.

        The input ``logits`` is left untouched. It is first copied into a new
        FP32 tensor, and every step below works on that copy, in this order:

        1. logit bias
        2. penalties
        3. temperature
        4. min-p, only if some request in the batch uses it
        5. top-k / top-p, only if some request uses either
        6. ``gumbel_sample``

        Returns:
            ``(sampled, processed_logits)``: the token ids returned by
            ``gumbel_sample`` and the processed FP32 logits it sampled from.
        """
        # Copy logits to a new FP32 tensor.
        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)

        # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
        self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos)

        # Apply penalties in place.
        self.penalties_state.apply_penalties(logits, idx_mapping)

        # Apply temperature in place.
        apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)

        # Apply min_p in place if any request has a non-zero min_p.
        if self.sampling_states.do_min_p(idx_mapping_np):
            apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)

        # Apply top_k and/or top_p. This might return a new tensor.
        do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
        do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
        if do_top_k or do_top_p:
            top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
            top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
            logits = apply_top_k_top_p(logits, top_k, top_p)

        # Sample the next token. Temperature was already applied above, so
        # gumbel_sample must not apply it again. The temperature tensor is
        # still passed, presumably so zero-temperature (greedy) requests can
        # be recognised; TODO confirm against gumbel_sample.
        sampled = gumbel_sample(
            logits,
            idx_mapping,
            self.sampling_states.temperature.gpu,
            self.sampling_states.seeds.gpu,
            pos,
            apply_temperature=False,
        )
        return sampled, logits