import dataclasses
import logging
from typing import Union

import torch
from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
)
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.logits_processor import LogitsProcessorOutput

# TODO: move this dict to another place
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class SampleOutput:
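    """Output of one sampling step: per-request success flags, the post-processed
    probability distribution, and the sampled next-token ids."""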
    success: torch.Tensor
    probs: torch.Tensor
    batch_next_token_ids: torch.Tensor


class Sampler(CustomOp):
    def __init__(self):
        super().__init__()

    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
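        """Apply the batch's additive and repetition penalties to the logits."""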
        # min-new-token, presence, and frequency penalties (already combined into one additive tensor)
        if sampling_info.linear_penalties is not None:
            logits += sampling_info.linear_penalties

        # Repetition penalty: divide positive logits by the penalty and multiply
        # negative logits by it (the standard scaling form of the repetition penalty)
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        return logits

    def _get_probs(
        self,
        logits: torch.Tensor,
        sampling_info: SamplingBatchInfo,
        is_torch_compile: bool = False,
    ):
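        """Convert raw logits into a probability distribution: apply temperature
        scaling, logit bias, vocabulary masking, and penalties, then softmax."""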
        # Post-process the logits
        logits = logits.contiguous()
        logits.div_(sampling_info.temperatures)
        if is_torch_compile:
            # FIXME: Temporary workaround for unknown bugs in torch.compile
            logits.add_(0)

        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        if sampling_info.vocab_mask is not None:
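            # vocab_mask marks the tokens that remain allowed for each request
            # (e.g. under constrained decoding); everything else is pushed to -inf.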
            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))

        logits = self._apply_penalties(logits, sampling_info)

        return torch.softmax(logits, dim=-1)

    def forward_cuda(
        self,
        logits: Union[torch.Tensor, LogitsProcessorOutput],
        sampling_info: SamplingBatchInfo,
    ):
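        """Sample next tokens on CUDA, preferring FlashInfer's fused kernels unless
        they are disabled via the server arguments."""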
        if isinstance(logits, LogitsProcessorOutput):
            logits = logits.next_token_logits

        probs = self._get_probs(logits, sampling_info)

        if not global_server_args_dict["disable_flashinfer_sampling"]:
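            # FlashInfer's fused top-k/top-p kernel samples by rejection sampling;
            # pre-draw uniform noise for up to 32 rounds per request.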
            max_top_k_round, batch_size = 32, probs.shape[0]
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
            )
            if sampling_info.need_min_p_sampling:
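                # There is no fused kernel covering top-k, top-p, and min-p at once,
                # so renormalize after top-k and top-p first, then sample with min-p.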
                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                batch_next_token_ids, success = min_p_sampling_from_probs(
                    probs, uniform_samples, sampling_info.min_ps
                )
            else:
                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                )
        else:
            # Here we provide a slower fallback implementation.
            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
            )

        return SampleOutput(success, probs, batch_next_token_ids)

    def forward_native(
        self,
        logits: Union[torch.Tensor, LogitsProcessorOutput],
        sampling_info: SamplingBatchInfo,
    ):
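        """Pure-PyTorch sampling path (used, e.g., when running under torch.compile)."""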
        if isinstance(logits, LogitsProcessorOutput):
            logits = logits.next_token_logits

        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)

        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
        )

        return SampleOutput(success, probs, batch_next_token_ids)


def top_k_top_p_min_p_sampling_from_probs_torch(
    probs: torch.Tensor,
    top_ks: torch.Tensor,
    top_ps: torch.Tensor,
    min_ps: torch.Tensor,
):
    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
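    # Min-p threshold per row: a fraction min_p of the largest probability.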
    min_p_thresholds = probs_sort[:, 0] * min_ps
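    # Top-p: zero out tokens whose preceding cumulative mass already exceeds top_p.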
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
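    # Top-k: zero out everything beyond the top_k highest-probability tokens.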
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0
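    # Min-p: drop tokens whose probability falls below the per-row threshold.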
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
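    # Rescale so the largest surviving weight is 1; torch.multinomial accepts
    # unnormalized non-negative weights.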
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    try:
        # FIXME: torch.multinomial misbehaves with num_samples=1 here, so draw two
        # samples with replacement and keep only the first.
        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
            :, :1
        ]
    except RuntimeError as e:
        logger.warning(f"Sampling error: {e}")
        batch_next_token_ids = torch.zeros(
            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success

    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return batch_next_token_ids, success
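

# Minimal sanity-check sketch for the pure-PyTorch fallback sampler. The batch size,
# vocabulary size, and top-k/top-p/min-p values below are arbitrary illustrative
# choices; running it still requires this module's imports (flashinfer, vllm, sglang)
# to be importable.
if __name__ == "__main__":
    batch_size, vocab_size = 4, 128
    dummy_probs = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
    token_ids, ok = top_k_top_p_min_p_sampling_from_probs_torch(
        dummy_probs,
        top_ks=torch.full((batch_size,), 40),
        top_ps=torch.full((batch_size,), 0.9),
        min_ps=torch.full((batch_size,), 0.05),
    )
    print("sampled token ids:", token_ids.tolist())
    print("success flags:", ok.tolist())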