sampler.py 4.91 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
2
from typing import Tuple
3
4
5
6
7
8

import torch
import torch.nn as nn

from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                          apply_min_token_penalties)
11
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
12
13
14
15
16
17

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """Sample the next token for each row of a batch of logits.

    ``forward`` runs, in order: optional top-k logprob extraction from the
    raw logits, a cast to float32, penalties, temperature scaling, and then
    greedy and/or top-k/top-p random sampling.
    """

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per row and optionally gather top-k logprobs.

        Args:
            logits: Model output logits. Every op in this class treats
                ``dim=-1`` as the vocabulary axis (presumably the shape is
                ``(num_requests, vocab_size)`` -- confirm against caller).
                May be modified in place (see NOTE below).
            sampling_metadata: Per-batch sampling parameters.

        Returns:
            A ``SamplerOutput`` with int32 ``sampled_token_ids``. The
            ``logprob_token_ids``/``logprobs`` fields are ``None`` unless
            ``max_num_logprobs > 0``; prompt-logprob fields are always
            ``None`` here.
        """
        needs_logprobs = sampling_metadata.max_num_logprobs > 0
        if needs_logprobs:
            # NOTE(woosuk): Use the original logits (before any penalties or
            # temperature scaling) for the top-k logprobs.
            # This is different from the V0 sampler, which uses the logits that
            # is used for sampling (after penalties and temperature scaling).
            # NOTE: We compute logprobs first because the below ops may
            # modify the logits tensor in-place (and we don't want to clone
            # the logits tensor for memory efficiency).
            topk_logprobs, topk_indices = self.get_topk_logprobs(
                logits, sampling_metadata)
        else:
            topk_logprobs = None
            topk_indices = None

        # Use float32 for the logits. If they already are float32, ``.to``
        # returns the same tensor, so the in-place ops below then act on
        # the caller's tensor.
        logits = logits.to(torch.float32)
        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)
        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)
        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        sampler_output = SamplerOutput(
            sampled_token_ids=sampled,
            logprob_token_ids=topk_indices,
            logprobs=topk_logprobs,
            prompt_logprob_token_ids=None,
            prompt_logprobs=None,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Divide each row of ``logits`` in place by its temperature.

        Temperatures below ``_SAMPLING_EPS`` are replaced by 1.0, which
        leaves those rows unscaled. ``sample`` treats such rows as greedy.

        Args:
            logits: Float logits, modified in place.
            temp: Per-row temperatures; ``unsqueeze(dim=1)`` makes it
                broadcast across the last axis of ``logits`` (assumes a
                1-D tensor with one entry per row -- confirm).

        Returns:
            The same ``logits`` tensor object, after division.
        """
        # Avoid division by zero.
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        logits.div_(temp.unsqueeze(dim=1))
        return logits

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax over the last axis, flattened to 1-D."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Pick one token id per row, either greedily or by random sampling.

        - If ``all_greedy`` is set, every row is an argmax.
        - If ``all_random`` is set, every row comes from the top-k/top-p
          sampler.
        - Otherwise, rows whose temperature is below ``_SAMPLING_EPS`` take
          the greedy result and the rest take the random result.

        Returns:
            A 1-D tensor of sampled token ids.
        """
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_greedy:
            return self.greedy_sample(logits)

        # Take the greedy argmax *before* calling the top-k/top-p sampler.
        # That sampler is opaque from this file and may update ``logits`` in
        # place (forward() already warns that ops here can do so); an argmax
        # computed afterwards could then read altered values.
        greedy_sampled = (None if sampling_metadata.all_random else
                          self.greedy_sample(logits))

        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.no_top_k,
            sampling_metadata.top_k,
            sampling_metadata.no_top_p,
            sampling_metadata.top_p,
        )
        if greedy_sampled is None:
            # all_random: no row needs the greedy result.
            return random_sampled

        # Mixed batch: near-zero-temperature rows are greedy.
        return torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
        )

    def get_topk_logprobs(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``max_num_logprobs`` largest float32 log-probs per row.

        Returns:
            ``(topk_logprobs, topk_indices)``, where the indices are
            converted to int32.
        """
        logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
        # FIXME: Mask the sampled token_id, get topk logprobs,
        # and concatenate the topk with the sampled token_id.
        topk_logprobs, topk_indices = torch.topk(
            logprobs, sampling_metadata.max_num_logprobs, dim=-1)
        # Use int32 to reduce the tensor size.
        topk_indices = topk_indices.to(torch.int32)
        return topk_logprobs, topk_indices

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply the min-token penalty and, if enabled, the other penalties.

        ``apply_min_token_penalties`` is called for its side effect (its
        return value is ignored), so it presumably edits ``logits`` in
        place -- confirm in ``vllm.v1.sample.ops.penalties``. The
        presence/frequency/repetition penalties run only when
        ``no_penalties`` is false, and their result replaces ``logits``.
        """
        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
                                  sampling_metadata.stop_token_ids,
                                  sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits, sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids)
        return logits