# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature, gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates


class Sampler:
    """Turn a batch of logits into sampled token ids and optional logprobs.

    Owns three per-request state containers that are indexed by request slot
    (via ``idx_mapping``):

    * ``sampling_states``: temperature, min_p, top_k, top_p, seeds and the
      per-request number of logprobs to return.
    * ``penalties_state``: state consumed by ``apply_penalties``.
    * ``logit_bias_state``: state consumed by ``apply_logit_bias`` (e.g.
      allowed_token_ids, min_tokens).
    """

    def __init__(
        self,
        max_num_reqs: int,
        vocab_size: int,
        device: torch.device,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        """Create the sampler and its per-request state containers.

        Args:
            max_num_reqs: Number of request slots; forwarded to every state
                container.
            vocab_size: Vocabulary size; forwarded to the sampling and
                penalty states.
            device: Torch device; forwarded to the penalty and logit-bias
                states.
            logprobs_mode: Which logits the top-k logprobs are computed from.
                ``"raw_logprobs"`` uses the logits passed to ``__call__``.
                ``"processed_logprobs"`` uses the logits after logit bias,
                penalties, temperature, min_p and top-k/top-p.

        Raises:
            NotImplementedError: If ``logprobs_mode`` is any other value.
        """
        # Only these two modes are implemented here. The explicit check
        # suggests LogprobsMode can name other modes (confirm in vllm.config).
        if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode
        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS  # False by default.

        self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
        self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
        self.logit_bias_state = LogitBiasState(max_num_reqs, device)

    def add_request(
        self,
        req_idx: int,
        prompt_len: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Register ``sampling_params`` for request slot ``req_idx``.

        Every state container records the request. ``prompt_len`` is only
        needed by the logit-bias state (presumably for min_tokens handling;
        confirm in LogitBiasState).
        """
        self.sampling_states.add_request(req_idx, sampling_params)
        self.penalties_state.add_request(req_idx, sampling_params)
        self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        """Flush writes staged by ``add_request`` in every state container.

        The prefill token ids and lengths are only consumed by the penalty
        state.
        """
        self.sampling_states.apply_staged_writes()
        self.penalties_state.apply_staged_writes(
            prefill_token_ids, prefill_lens, prompt_lens
        )
        self.logit_bias_state.apply_staged_writes()

    def __call__(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        cu_num_logits_np: np.ndarray,
        pos: torch.Tensor,
    ) -> SamplerOutput:
        """Sample the next tokens and, if requested, their top-k logprobs.

        Args:
            logits: Model logits, presumably [num_logits, vocab_size] (TODO
                confirm). ``sample`` works on an FP32 copy, so this tensor is
                not changed by the sampling pipeline.
            idx_mapping: GPU tensor used to gather per-request values from
                the state containers (maps batch entries to request slots).
            idx_mapping_np: Host (NumPy) counterpart of ``idx_mapping``.
            cu_num_logits_np: Presumably cumulative logits-row offsets per
                request. Only used when ``logits`` has a different number of
                rows than ``idx_mapping_np`` has entries.
            pos: Positions; forwarded to the logit-bias state and the
                sampler.

        Returns:
            A ``SamplerOutput`` of GPU tensors containing:

            * ``sampled_token_ids``: shape [-1, 1].
            * ``logprobs_tensors``: ``None`` when no request asked for
              logprobs.
            * ``num_nans``: ``None`` unless NaN counting is enabled.
        """
        # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
        # that num_nans is computed before applying penalties and temperature.
        num_nans = get_num_nans(logits) if self.compute_nans else None

        sampled, processed_logits = self.sample(
            logits, idx_mapping, idx_mapping_np, pos
        )

        logprobs_tensors = None
        max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
        if max_num_logprobs != NO_LOGPROBS:
            logprobs_src = (
                processed_logits
                if self.logprobs_mode == "processed_logprobs"
                else logits
            )
            # A row count different from the number of requests presumably
            # means some requests own several logits rows. In that case the
            # cumulative offsets are passed so rows can be attributed.
            expanded_logits = logprobs_src.shape[0] != idx_mapping_np.shape[0]
            cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
            logprobs_tensors = compute_topk_logprobs(
                logprobs_src, max_num_logprobs, sampled, cu_num_logits
            )

        # These are GPU tensors.
        return SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
            num_nans=num_nans,
        )

    def sample(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the logits-processing pipeline on an FP32 copy, then sample.

        The steps run in this order:

        1. Logit bias.
        2. Penalties.
        3. Temperature.
        4. min_p, only if any request uses it.
        5. top-k / top-p, only if any request uses them.
        6. Gumbel sampling.

        Returns:
            ``(sampled, processed_logits)``. ``processed_logits`` is the FP32
            tensor after all processing steps. It may be a different tensor
            object than the initial copy if top-k/top-p ran.
        """
        # Copy logits to a new FP32 tensor so the in-place steps below never
        # touch the caller's logits (needed for "raw_logprobs").
        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)

        # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
        self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)

        # Apply penalties in place.
        self.penalties_state.apply_penalties(logits, idx_mapping, idx_mapping_np)

        # Apply temperature in place.
        apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)

        # Apply min_p in place if any request has a non-zero min_p.
        if self.sampling_states.do_min_p(idx_mapping_np):
            apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)

        # Apply top_k and/or top_p. This might return a new tensor.
        do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
        do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
        top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
        top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
        if do_top_k or do_top_p:
            logits = apply_top_k_top_p(logits, top_k, top_p)

        # Sample the next token. Temperature was already applied in place
        # above, so gumbel_sample must not apply it again. The temperature
        # tensor is still passed; presumably it is used to detect greedy
        # requests (confirm in gumbel_sample).
        sampled = gumbel_sample(
            logits,
            idx_mapping,
            self.sampling_states.temperature.gpu,
            self.sampling_states.seeds.gpu,
            pos,
            apply_temperature=False,
        )
        return sampled, logits