# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import numpy as np
5
6
import torch

7
import vllm.envs as envs
8
from vllm.config.model import LogprobsMode
9
from vllm.sampling_params import SamplingParams
10
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
11
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
12
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
13
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
14
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
15
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
16
from vllm.v1.worker.gpu.sample.output import SamplerOutput
17
18
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
19
20
21
22
23


class Sampler:
    """Turn a batch of logits into sampled token ids and optional logprobs.

    The sampler owns three state containers (sampling parameters, penalties
    and logit bias). Requests are registered through ``add_request``, staged
    updates are flushed by ``apply_staged_writes``, and calling the instance
    runs the sampling pipeline implemented in ``sample``.
    """

    def __init__(
        self,
        max_num_reqs: int,
        vocab_size: int,
        device: torch.device,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        """Validate the logprobs mode and build the state containers.

        Args:
            max_num_reqs: Forwarded to every state container; presumably the
                maximum number of request slots -- TODO confirm with callers.
            vocab_size: Forwarded to the sampling and penalties states.
            device: Forwarded to the penalties and logit-bias states.
            logprobs_mode: Selects the logits that logprobs are computed from.
                "raw_logprobs" uses the logits passed to ``__call__``;
                "processed_logprobs" uses the logits returned by ``sample``.

        Raises:
            NotImplementedError: If ``logprobs_mode`` is neither
                "processed_logprobs" nor "raw_logprobs".
        """
        if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode
        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS  # False by default.

        self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
        self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
        self.logit_bias_state = LogitBiasState(max_num_reqs, device)

    def add_request(
        self,
        req_idx: int,
        prompt_len: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Register a request with each of the three state containers.

        Args:
            req_idx: Index identifying the request in the state containers.
            prompt_len: Passed only to the logit-bias state.
            sampling_params: Sampling parameters of the request.
        """
        self.sampling_states.add_request(req_idx, sampling_params)
        self.penalties_state.add_request(req_idx, sampling_params)
        self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        """Call ``apply_staged_writes`` on each of the three state containers.

        The three arguments are consumed only by the penalties state; the
        sampling and logit-bias states take no arguments.
        """
        self.sampling_states.apply_staged_writes()
        self.penalties_state.apply_staged_writes(
            prefill_token_ids, prefill_lens, prompt_lens
        )
        self.logit_bias_state.apply_staged_writes()

    def __call__(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> SamplerOutput:
        """Sample one token per request and optionally compute logprobs.

        Args:
            logits: Logits to sample from. This tensor is not modified here;
                ``sample`` works on an FP32 copy.
            idx_mapping: Tensor used to index the per-request GPU state.
            idx_mapping_np: Array passed to the host-side state queries;
                presumably the NumPy counterpart of ``idx_mapping`` -- TODO
                confirm with callers.
            pos: Passed through to the logit-bias step and the sampling kernel.

        Returns:
            A ``SamplerOutput`` holding the sampled token ids viewed as
            ``[num_requests, 1]``, the logprobs tensors (``None`` when no
            request asks for logprobs) and the NaN count (``None`` unless NaN
            counting is enabled).
        """
        # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
        # that num_nans is computed before applying penalties and temperature.
        num_nans = get_num_nans(logits) if self.compute_nans else None

        sampled, processed_logits = self.sample(
            logits, idx_mapping, idx_mapping_np, pos
        )

        max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
        if max_num_logprobs != NO_LOGPROBS:
            # Choose between the post-processing logits and the raw input
            # logits according to the configured logprobs mode.
            logprob_logits = (
                processed_logits
                if self.logprobs_mode == "processed_logprobs"
                else logits
            )
            logprobs_tensors = compute_topk_logprobs(
                logprob_logits, max_num_logprobs, sampled
            )
        else:
            logprobs_tensors = None

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
            num_nans=num_nans,
        )
        return sampler_output

    def sample(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the logit-processing pipeline and draw the next tokens.

        The steps run in this order on an FP32 copy of ``logits``: logit bias,
        penalties and temperature, min_p (only if some request uses it),
        top_k/top_p (only if some request uses them), then Gumbel sampling.

        Returns:
            A tuple ``(sampled, processed_logits)`` where ``processed_logits``
            is the FP32 tensor after all processing steps.
        """
        # Copy logits to a new FP32 tensor.
        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)

        # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
        self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos)

        # Apply penalties and temperature in place.
        self.penalties_state.apply_penalties_and_temperature(
            logits, idx_mapping, self.sampling_states.temperature.gpu
        )

        # Apply min_p in place if any request has a non-zero min_p.
        do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
        if do_min_p:
            apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)

        # Apply top_k and/or top_p. This might return a new tensor.
        do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
        top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
        do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
        top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
        if do_top_k or do_top_p:
            logits = apply_top_k_top_p(logits, top_k, top_p)

        # Sample the next token. Temperature was already applied above, so the
        # sampling call is told not to apply it again.
        sampled = gumbel_sample(
            logits,
            idx_mapping,
            self.sampling_states.temperature.gpu,
            self.sampling_states.seeds.gpu,
            pos,
            apply_temperature=False,
        )
        return sampled, logits