# sampler.py
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import numpy as np
5
6
import torch

7
import vllm.envs as envs
8
from vllm.config.model import LogprobsMode
9
from vllm.sampling_params import SamplingParams
10
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
11
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
12
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
13
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
14
from vllm.v1.worker.gpu.sample.output import SamplerOutput
15
16
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
17
18
19
20
21


class Sampler:
    """Sample next tokens from logits and optionally gather top-k logprobs.

    The sampler owns three per-request state containers (``SamplingStates``,
    ``PenaltiesState`` and ``LogitBiasState``). Requests are registered through
    ``add_request``, staged values are flushed with ``apply_staged_writes``,
    and calling the instance runs the full sampling pipeline.
    """

    def __init__(
        self,
        max_num_reqs: int,
        vocab_size: int,
        device: torch.device,
        logprobs_mode: LogprobsMode = "raw_logprobs",
        num_speculative_tokens: int = 1,
    ) -> None:
        """Create the sampler and its per-request state containers.

        Args:
            max_num_reqs: Forwarded to all three state containers.
            vocab_size: Forwarded to ``SamplingStates`` and ``PenaltiesState``.
            device: Forwarded to ``PenaltiesState`` and ``LogitBiasState``.
            logprobs_mode: ``"raw_logprobs"`` computes logprobs from the logits
                passed to ``__call__``; ``"processed_logprobs"`` computes them
                from the logits returned by ``sample``.
            num_speculative_tokens: Forwarded to
                ``PenaltiesState.apply_penalties`` on every ``sample`` call.

        Raises:
            NotImplementedError: If ``logprobs_mode`` is not one of the two
                supported modes.
        """
        if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode
        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS  # False by default.

        self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
        self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
        self.logit_bias_state = LogitBiasState(max_num_reqs, device)
        self.num_speculative_tokens = num_speculative_tokens

    def add_request(
        self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
    ) -> None:
        """Register a request's sampling parameters with every state container.

        Only the logit-bias state receives ``prompt_len``; the other two
        containers are given ``req_idx`` and ``sampling_params`` only.
        """
        self.sampling_states.add_request(req_idx, sampling_params)
        self.penalties_state.add_request(req_idx, sampling_params)
        self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        """Ask every state container to apply its staged writes.

        The three arguments are consumed only by the penalties state; the
        sampling and logit-bias states take no arguments here.
        """
        self.sampling_states.apply_staged_writes()
        self.penalties_state.apply_staged_writes(
            prefill_token_ids, prefill_lens, prompt_lens
        )
        self.logit_bias_state.apply_staged_writes()

    def __call__(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        cu_num_logits_np: np.ndarray,
        pos: torch.Tensor,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
    ) -> SamplerOutput:
        """Sample tokens and, when requested, compute top-k logprobs.

        Args:
            logits: Logits to sample from. ``sample`` works on an FP32 copy.
            idx_mapping: Forwarded to the state containers and the sampling
                kernel.
                NOTE(review): presumably maps each logits row to a request
                slot -- confirm against the caller.
            idx_mapping_np: NumPy counterpart of ``idx_mapping``; its length is
                compared with ``logits.shape[0]`` to detect expanded logits.
            cu_num_logits_np: Only used when the number of logits rows differs
                from ``len(idx_mapping_np)``; it is then converted to a list
                and handed to ``compute_topk_logprobs``.
            pos: Forwarded to the logit-bias step and to ``gumbel_sample``.
            input_ids: Forwarded to the penalties step.
            expanded_local_pos: Forwarded to the penalties step.

        Returns:
            A ``SamplerOutput`` holding the sampled token ids reshaped to a
            single column, the logprobs tensors (``None`` when no request asks
            for logprobs) and the NaN count (``None`` unless NaN counting is
            enabled).
        """
        # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
        # that num_nans is computed before applying penalties and temperature.
        num_nans = get_num_nans(logits) if self.compute_nans else None
        sampled, processed_logits = self.sample(
            logits,
            idx_mapping,
            idx_mapping_np,
            pos,
            input_ids,
            expanded_local_pos,
        )

        max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
        if max_num_logprobs != NO_LOGPROBS:
            if self.logprobs_mode == "processed_logprobs":
                # Report logprobs of the distribution that was actually
                # sampled from rather than of the raw model output.
                logits = processed_logits
            # More logits rows than mapping entries means the logits are
            # "expanded"; only then are the cumulative counts passed along.
            expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
            cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
            logprobs_tensors = compute_topk_logprobs(
                logits, max_num_logprobs, sampled, cu_num_logits
            )
        else:
            logprobs_tensors = None

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are reshaped into a 2D tensor with a single
            # column, i.e. one sampled token per row.
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
            num_nans=num_nans,
        )
        return sampler_output

    def sample(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the logits-processing pipeline and draw one sample per row.

        The steps run in this order on an FP32 copy of ``logits``: logit bias,
        penalties, temperature, min_p, top_k/top_p, then Gumbel sampling.

        Returns:
            A ``(sampled, processed_logits)`` tuple, where ``processed_logits``
            is the FP32 tensor after all processing steps.
        """
        # Copy logits to a new FP32 tensor so that the in-place steps below
        # operate on the copy rather than on the tensor passed by the caller.
        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)

        # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
        self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)

        # Apply penalties in place.
        self.penalties_state.apply_penalties(
            logits,
            idx_mapping,
            idx_mapping_np,
            input_ids,
            expanded_local_pos,
            self.num_speculative_tokens,
        )

        # Apply temperature in place.
        self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)

        # Apply min_p in place.
        self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)

        # Apply top_k and/or top_p. This might or might not return a new tensor.
        logits = self.sampling_states.apply_top_k_top_p(
            logits, idx_mapping, idx_mapping_np
        )

        # Sample the next token. Temperature has already been applied above,
        # hence apply_temperature=False.
        sampled = gumbel_sample(
            logits,
            idx_mapping,
            self.sampling_states.temperature.gpu,
            self.sampling_states.seeds.gpu,
            pos,
            apply_temperature=False,
        )
        return sampled, logits