rejection_sampler.py 8.92 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2

3
4
5
6
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

7
from vllm import envs
8
from vllm.logger import init_logger
9
from vllm.platforms import current_platform
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata

try:
    import flashinfer.sampling as fs
except ImportError:
    # FlashInfer is optional: `fs` is only needed by the FlashInfer-based
    # rejection sampling path, so a missing package just disables that path.
    is_flashinfer_available = False
else:
    is_flashinfer_available = True

logger = init_logger(__name__)
# Sentinel token ID: pads ragged per-request token lists into rectangular
# tensors and marks output positions that carry no sampled token.
INVALID_TOKEN_ID = -1


class RejectionSampler(nn.Module):
    """Rejection sampler for speculative decoding (greedy sampling only).

    For each request, the draft tokens are compared position by position with
    the target model's argmax tokens. Draft tokens are accepted up to the first
    mismatch. At the first mismatching position the target's token is emitted
    instead; if every draft token is accepted, the target's extra ("bonus")
    token is emitted as well.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda():
            if is_flashinfer_available:
                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # FIXME(woosuk): Currently, we have errors when using
                    # FlashInfer for rejection sampling. As a workaround, we
                    # disable FlashInfer for rejection sampling by default.
                    logger.info("Currently, FlashInfer rejection sampler is "
                                "disabled because of a bug. Falling back to "
                                "the PyTorch-native implementation of "
                                "rejection sampling.")
                    self.forward_method = self.forward_native

                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    # logger.info("Using FlashInfer for rejection sampling.")
                    # self.forward_method = self.flashinfer_sample
                else:
                    # NOTE(review): while the FIXME above is in effect,
                    # following this message's advice (setting
                    # VLLM_USE_FLASHINFER_SAMPLER=1) still selects the native
                    # implementation.
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "rejection sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward_method = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of rejection sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward_method = self.forward_native
        else:
            self.forward_method = self.forward_native

    def forward(self, draft_token_ids: list[list[int]],
                target_probs: torch.Tensor,
                sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Run rejection sampling with the implementation chosen in __init__.

        Args:
            draft_token_ids: Proposed (draft) token IDs, one list per request.
            target_probs: Target-model probabilities whose last dimension is
                the vocabulary. The argmax over the last dimension is
                flattened and split into ``len(ids) + 1`` positions per
                request, so the number of positions must equal
                ``sum(len(ids) + 1 for ids in draft_token_ids)`` —
                presumably a 2-D (num_positions, vocab_size) tensor.
            sampling_metadata: Only ``all_greedy`` is consulted.

        Returns:
            SamplerOutput whose ``sampled_token_ids`` holds the emitted tokens.

        Raises:
            NotImplementedError: If not every request uses greedy sampling.
        """
        if not sampling_metadata.all_greedy:
            raise NotImplementedError(
                "Currently, only greedy sampling is supported by "
                "rejection sampler.")
        return self.forward_method(draft_token_ids, target_probs,
                                   sampling_metadata)

    @staticmethod
    def _prepare_token_ids(
        draft_token_ids: list[list[int]],
        target_probs: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad draft tokens and greedy target tokens into 2-D tensors.

        Returns:
            draft_token_ids_tensor: int64 tensor of shape
                (batch_size, max_num_draft_tokens) on ``target_probs.device``,
                padded with INVALID_TOKEN_ID.
            target_token_ids: tensor of shape
                (batch_size, max_num_draft_tokens + 1) holding the per-position
                argmax token IDs (draft positions plus the bonus position),
                padded with INVALID_TOKEN_ID.
        """
        # Each request has one target position per draft token plus one
        # extra position for the bonus token.
        sample_lens = [len(ids) + 1 for ids in draft_token_ids]
        draft_token_ids_tensor = pad_sequence(
            [
                torch.tensor(ids, dtype=torch.long, device="cpu")
                for ids in draft_token_ids
            ],
            batch_first=True,
            padding_value=INVALID_TOKEN_ID,
        )
        # NOTE: CPU <-> GPU synchronization happens here.
        draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)

        target_token_ids = target_probs.argmax(dim=-1).view(-1).split(
            sample_lens)
        target_token_ids = pad_sequence(target_token_ids,
                                        batch_first=True,
                                        padding_value=INVALID_TOKEN_ID)
        return draft_token_ids_tensor, target_token_ids

    def flashinfer_sample(
        self,
        draft_token_ids: list[list[int]],
        target_probs: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Greedy rejection sampling via FlashInfer's
        ``chain_speculative_sampling``.

        Not currently selected by ``__init__`` (see the FIXME there).

        Raises:
            NotImplementedError: If not every request uses greedy sampling.
        """
        if not sampling_metadata.all_greedy:
            raise NotImplementedError(
                "Currently, only greedy sampling is supported by "
                "rejection sampler.")

        # NOTE: The following input preparation can be moved
        # to the model runner with a persistent manner for better
        # performance.
        draft_token_ids_tensor, target_token_ids = self._prepare_token_ids(
            draft_token_ids, target_probs)
        vocab_size = target_probs.size(-1)
        device = target_probs.device
        # Greedy sampling is expressed as one-hot distributions over the
        # vocabulary for both the draft and the target tokens.
        draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
                                                 vocab_size, device)
        target_token_probs = _create_greedy_token_probs(
            target_token_ids, vocab_size, device)
        # NOTE(review): all-zero uniform samples presumably make acceptance
        # deterministic for one-hot probabilities — confirm against the
        # FlashInfer documentation.
        uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
                                      draft_token_ids_tensor.size(1) + 1,
                                      device=device)

        sampled_token_ids, _, _ = fs.chain_speculative_sampling(
            draft_probs,
            draft_token_ids_tensor,
            uniform_samples,
            target_token_probs,
        )
        return SamplerOutput(sampled_token_ids=sampled_token_ids,
                             logprobs_tensors=None)

    # TODO: The following method can be optimized for better performance.
    def forward_native(
        self,
        draft_token_ids: list[list[int]],
        target_probs: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """PyTorch-native greedy rejection sampling.

        Returns:
            SamplerOutput whose ``sampled_token_ids`` has shape
            (batch_size, max_num_draft_tokens + 1). Row i holds the target
            tokens at the accepted draft positions of request i, followed by
            exactly one more target token. That token is the correction at the
            first mismatch, or the bonus token when all drafts were accepted.
            All remaining entries are INVALID_TOKEN_ID.

        Raises:
            NotImplementedError: If not every request uses greedy sampling.
        """
        if not sampling_metadata.all_greedy:
            raise NotImplementedError(
                "Currently, only greedy sampling is supported by "
                "rejection sampler.")

        draft_token_ids_tensor, output_token_ids = self._prepare_token_ids(
            draft_token_ids, target_probs)

        # accept_mask is 1 up to the first draft/target mismatch in each row
        # and 0 from there on (cumprod turns 0 after a mismatch). Padded draft
        # slots hold INVALID_TOKEN_ID, which never equals an argmax result, so
        # padding is never counted as accepted.
        accept_mask = (output_token_ids[:, :-1] ==
                       draft_token_ids_tensor).cumprod(dim=1)
        num_accepted = accept_mask.sum(dim=1)

        # Emit positions 0..num_accepted (inclusive): the accepted tokens plus
        # exactly one token from the target model. num_accepted never exceeds
        # the row's draft length, so these positions are never padding.
        positions = torch.arange(output_token_ids.size(1),
                                 device=output_token_ids.device)
        generate_mask = positions.unsqueeze(0) <= num_accepted.unsqueeze(1)

        output_token_ids[~generate_mask] = INVALID_TOKEN_ID
        return SamplerOutput(sampled_token_ids=output_token_ids,
                             logprobs_tensors=None)


179
180
181
182
183
184
185
186
187
188
189
190
def _create_greedy_token_probs(
    token_ids: torch.Tensor,
    vocab_size: int,
    out_device: torch.device,
) -> torch.Tensor:
    """Turn a padded batch of token IDs into one-hot probability rows.

    Args:
        token_ids: 2-D integer tensor of shape (batch_size, num_tokens).
            Entries equal to INVALID_TOKEN_ID are padding.
        vocab_size: Size of the vocabulary (last dimension of the result).
        out_device: Device on which the result is allocated.

    Returns:
        float32 tensor of shape (batch_size, num_tokens, vocab_size) with 1.0
        at ``token_ids[b, t]`` for every valid position and an all-zero row
        for every padding position.
    """
    batch_size, num_tokens = token_ids.shape

    token_probs = torch.zeros(batch_size,
                              num_tokens,
                              vocab_size,
                              dtype=torch.float,
                              device=out_device)

    # scatter_ needs in-range indices, so padding entries are redirected to
    # index 0 and receive 0.0 there, leaving their rows all zeros.
    valid_mask = token_ids != INVALID_TOKEN_ID
    safe_indices = token_ids.masked_fill(~valid_mask, 0)

    token_probs.scatter_(dim=2,
                         index=safe_indices.unsqueeze(-1),
                         src=valid_mask.unsqueeze(-1).float())
    return token_probs