import re
from typing import List, Optional, Tuple, Set, Union

import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import (
    FrequencyPenaltyLogitsProcessor,
    GrammarLogitProcessor,
    HeterogeneousProcessorWrapper,
    HeterogeneousRepetitionPenaltyLogitsProcessor,
    HeterogeneousFrequencyPenaltyLogitsProcessor,
    HeterogeneousTemperatureLogitsWarper,
    HeterogeneousTopKLogitsWarper,
    HeterogeneousTopPLogitsWarper,
    HeterogeneousTypicalLogitsWarper,
    HeterogeneousGrammarLogitProcessor,
    static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor


class NextTokenChooser:
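    """Chooser for the next token of a single sequence.

    Scores are run through the configured logits processors (watermark, repetition
    penalty, frequency penalty, grammar) and warpers (temperature, top-k, top-p,
    typical-p) before a token id is picked, either greedily or with seeded sampling.

    Minimal illustrative sketch (argument values and tensor names are examples only):

        chooser = NextTokenChooser(temperature=0.7, top_k=50, do_sample=True)
        next_id, logprobs = chooser(input_ids, logits)
        chooser = chooser.advance_grammar(next_id.item())
    """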
    def __init__(
        self,
        watermark: bool = False,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        frequency_penalty: float = 0.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        do_sample: bool = False,
        seed: int = 0,
        device: str = "cpu",
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        grammar: str = "",
        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
        fsm_grammar_state: int = 0,
    ):
        self.watermark_processor = (
            WatermarkLogitsProcessor(device=device) if watermark else None
        )
        self.repetition_processor = (
            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
            if repetition_penalty and repetition_penalty != 1.0
            else None
        )
        self.frequency_processor = (
            FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
            if frequency_penalty and frequency_penalty != 0.0
            else None
        )
        self.grammar_processor = (
            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
            if grammar != ""
            else None
        )
        self.tokenizer = tokenizer

        has_warpers = (
            (temperature is not None and temperature != 1.0)
            or (top_k is not None and top_k != 0)
            or (top_p is not None and top_p < 1.0)
            or (typical_p is not None and typical_p < 1.0)
        )
        if has_warpers:
            self.static_warper = static_warper(
                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
            )
        else:
            self.static_warper = None

        sampling = do_sample or has_warpers

        self.choice = Sampling(seed, device) if sampling else Greedy()
        self.fsm_grammar_state = fsm_grammar_state
        self.grammar = grammar

    def __call__(self, input_ids, scores):
        if self.watermark_processor is not None:
            scores = self.watermark_processor(input_ids, scores)
        if self.repetition_processor is not None:
            scores = self.repetition_processor(input_ids, scores)
        if self.frequency_processor is not None:
            scores = self.frequency_processor(input_ids, scores)
        if self.grammar_processor is not None:
            scores = self.grammar_processor(scores, self.fsm_grammar_state)

        if self.static_warper is None:
            next_logprob = torch.log_softmax(scores, -1)
        else:
            scores, next_logprob = self.static_warper(scores)

        next_id = self.choice(scores[-1]).view(1, 1)

        return next_id, next_logprob

    def advance_grammar(self, next_id: int):
        if self.grammar_processor is not None:
            self.fsm_grammar_state = self.grammar_processor.advance(
                next_id, self.fsm_grammar_state
            )
        return self

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            frequency_penalty=pb.frequency_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
            tokenizer=tokenizer,
            grammar=pb.grammar,
            grammar_type=pb.grammar_type,
        )


class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        stop_sequence = re.escape(stop_sequence)
        self.regex = re.compile(f"{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False


class StoppingCriteria:
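    """Decides when generation should stop for a single sequence.

    Generation ends when `max_new_tokens` is reached, when an EOS token id is
    produced (unless `ignore_eos_token` is set), or when the accumulated output
    ends with one of the configured stop sequences.
    """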
    def __init__(
        self,
        eos_token_ids: Optional[Union[Set[int], int]],
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        if eos_token_ids is None:
            eos_token_ids = set()
        elif isinstance(eos_token_ids, int):
            eos_token_ids = set([eos_token_ids])
        elif isinstance(eos_token_ids, set):
            eos_token_ids = eos_token_ids
        else:
            raise RuntimeError(
                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
            )
        self.eos_token_ids = eos_token_ids
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""
        self.ignore_eos_token = ignore_eos_token

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if isinstance(last_token, torch.Tensor):
            last_token = last_token.item()

        if not self.ignore_eos_token and last_token in self.eos_token_ids:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        if self.stop_sequence_criterias:
            self.current_output += last_output
            # There is no need to keep an output that is too long
            if len(self.current_output) > 300:
                # Slice to -200 to avoid doing it all the time
                self.current_output = self.current_output[-200:]
            for stop_sequence_criteria in self.stop_sequence_criterias:
                if stop_sequence_criteria(self.current_output):
                    return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        # TODO Hack because eos_token_id cannot be what we want.
        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
        return StoppingCriteria(
            eos_token_id,
            stop_sequence_criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
        )


def create_n_gram_speculation(
    input_ids: torch.Tensor,
    next_ids: torch.Tensor,
    accepted_ids: torch.Tensor,
    speculate: int,
    verbose: bool,
):
    # Very trivial approach, find first match in the string.
    # This is much less refined than actual n-gram but seems to work
    # relatively OK in grounded mode and is by far much faster with
    # much less worst case complexity as everything happens on device.
    B = accepted_ids.shape[0]
    device = input_ids.device
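    # For each sequence, the "seed" is the last accepted token; look up an earlier
    # occurrence of that token in input_ids and reuse the `speculate` tokens that
    # followed it as the speculative continuation.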
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
        speculate, device=device
    )
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)

    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
    return speculative_ids


class HeterogeneousNextTokenChooser:
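    """Batched counterpart of `NextTokenChooser` with per-request parameters.

    Every argument is a list with one entry per request, so a single batch can mix
    different sampling settings; the heterogeneous processors and warpers are only
    instantiated when at least one request actually needs them. `__call__` also
    validates previously speculated tokens (Medusa or n-gram) when provided.
    """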
    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
        watermark: List[bool],
        temperature: List[float],
        repetition_penalty: List[float],
        frequency_penalty: List[float],
        top_k: List[int],
        top_p: List[float],
        typical_p: List[float],
        do_sample: List[bool],
        seeds: List[int],
        tokenizer: PreTrainedTokenizerBase,
        grammars: List[str],
        grammar_types: List[int],
        fsm_grammar_states: List[int],
    ):
        warpers = []

        self.watermark_processor = (
            HeterogeneousProcessorWrapper(
                {
                    i: WatermarkLogitsProcessor(device=device)
                    for i, do_watermark in enumerate(watermark)
                    if do_watermark
                }
            )
            if any(watermark)
            else None
        )

        self.repetition_processor = (
            HeterogeneousRepetitionPenaltyLogitsProcessor(
                repetition_penalty, dtype, device
            )
            if any([x != 1.0 for x in repetition_penalty])
            else None
        )

        self.frequency_processor = (
            HeterogeneousFrequencyPenaltyLogitsProcessor(
                frequency_penalty, dtype, device
            )
            if any([x != 0.0 for x in frequency_penalty])
            else None
        )

        self.grammar_processor = (
            HeterogeneousGrammarLogitProcessor(
                tokenizer, device, grammars, grammar_types
            )
            if any([grammar != "" for grammar in grammars])
            else None
        )

        if any(x != 1.0 for x in temperature):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
            ]
            warpers.append(
                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
            )

        if any(x != 0 for x in top_k):
            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

        if any(x < 1.0 for x in top_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
            warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))

        if any(x < 1.0 for x in typical_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))

        self.warpers = warpers

        if any(do_sample):
            self.choice = HeterogeneousSampling(do_sample, seeds, device)
        else:
            self.choice = Greedy()

        self.seeds = seeds
        self.do_sample = do_sample
        self.dtype = dtype
        self.device = device
        self.tokenizer = tokenizer
        self.fsm_grammar_states = fsm_grammar_states
        self.grammars = grammars
        self.grammar_types = grammar_types

    def __call__(
        self,
        input_ids: torch.Tensor,
        scores: torch.Tensor,
        speculate: int,
        speculated_ids: Optional[torch.Tensor] = None,
        speculative_scores: Optional[torch.Tensor] = None,
        verbose=False,
    ):
        if speculated_ids is not None:
            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
            scores = scores.view(B, S, -1)
        else:
            B = scores.shape[0]
            S = 1
            scores = scores.view(B, S, -1)

        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)

        for j in range(S):
            _scores = scores[:, j]
            if self.watermark_processor is not None:
                _scores = self.watermark_processor(input_ids, _scores)
            if self.repetition_processor is not None:
                _scores = self.repetition_processor(input_ids, _scores)
            if self.frequency_processor is not None:
                _scores = self.frequency_processor(input_ids, _scores)
            if self.grammar_processor is not None:
                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
            for warper in self.warpers:
                _scores = warper(input_ids, _scores)
            _next_ids = self.choice(_scores)
            scores[:, j] = _scores
            next_ids[:, j] = _next_ids
        next_ids = next_ids.view(B * S)
        allscores = scores.view(B * S, -1)
        alllogprobs = torch.log_softmax(allscores, -1)

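        # Speculative decoding acceptance: compare the ids just chosen at each position
        # with the ids that were speculated last step and keep the longest matching
        # prefix per sequence (the first position is always accepted).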
        if speculated_ids is not None:
            accepted_ids = []
            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
            indices = []
            for i in range(B):
                _next_ids = next_ids[i * S : (i + 1) * S]
                _speculated_ids = speculated_ids[i]
                validate_speculative = _next_ids[:-1] == _speculated_ids
                index = i * S
                accepted = 1
                # First is always valid
                indices.append(index)
                for valid in validate_speculative.tolist():
                    if valid:
                        index += 1
                        accepted += 1
                        indices.append(index)
                    else:
                        break
                accepted_ids.append(accepted)

            accepted_ids = torch.tensor(
                accepted_ids, device=input_ids.device, dtype=input_ids.dtype
            )
            next_ids = next_ids[indices]
            logprobs = alllogprobs[indices]
            indices = torch.arange(B, device=input_ids.device) * S
            if speculative_scores is not None:
                speculative_scores = speculative_scores[indices + accepted_ids - 1]
        else:
            accepted_ids = torch.ones_like(next_ids)
            logprobs = alllogprobs

        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

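        # Build speculative ids for the next step: reuse the Medusa head scores when
        # available, otherwise fall back to n-gram matching over the prompt.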
        if speculate > 0:
            if speculative_scores is not None:
                # Medusa provided some scores
                speculative_ids = Greedy()(speculative_scores)
            else:
                # n-gram
                speculative_ids = create_n_gram_speculation(
                    input_ids, next_ids, accepted_ids, speculate, verbose
                )
        else:
            speculative_ids = None

        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

    def advance_grammar(self, next_ids: List[int]):
        if self.grammar_processor is not None:
            other_new_states = self.grammar_processor.advance_batch(
                next_ids, self.fsm_grammar_states
            )
            self.fsm_grammar_states = other_new_states
        return self

    def advance_grammar_single(self, grammar_state_index: int, next_id: int):
        if self.grammar_processor is not None:
            self.fsm_grammar_states[grammar_state_index] = (
                self.grammar_processor.advance_at_index(
                    next_id,
                    self.fsm_grammar_states[grammar_state_index],
                    grammar_state_index,
                )
            )
        return self

    def filter(self, indices):
        if self.watermark_processor is not None:
            self.watermark_processor = self.watermark_processor.filter(indices)

        if self.repetition_processor is not None:
            self.repetition_processor = self.repetition_processor.filter(indices)

        if self.frequency_processor is not None:
            self.frequency_processor = self.frequency_processor.filter(indices)

        if self.grammar_processor is not None:
            self.grammar_processor = self.grammar_processor.filter(indices)

        filtered_warpers = []
        for warper in self.warpers:
            filtered_warper = warper.filter(indices)
            if filtered_warper is not None:
                filtered_warpers.append(filtered_warper)
        self.warpers = filtered_warpers

        self.seeds = [self.seeds[i] for i in indices]
        self.do_sample = [self.do_sample[i] for i in indices]

        new_grammars = []
        new_fsm_grammar_states = []
        new_grammar_types = []
        for i in indices:
            new_grammars.append(self.grammars[i])
            new_fsm_grammar_states.append(self.fsm_grammar_states[i])
            new_grammar_types.append(self.grammar_types[i])

        self.grammars = new_grammars
        self.fsm_grammar_states = new_fsm_grammar_states
        self.grammar_types = new_grammar_types

        if any(self.do_sample):
            self.choice.filter(indices)
        else:
            self.choice = Greedy()

        return self

    @classmethod
    def from_pb(
        cls,
        pb: List[generate_pb2.NextTokenChooserParameters],
        dtype: torch.dtype,
        device: torch.device,
        tokenizer: PreTrainedTokenizerBase,
        fsm_grammar_states: Optional[List[int]] = None,
    ) -> "HeterogeneousNextTokenChooser":
        return HeterogeneousNextTokenChooser(
            watermark=[pb_.watermark for pb_ in pb],
            temperature=[pb_.temperature for pb_ in pb],
            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
            top_k=[pb_.top_k for pb_ in pb],
            top_p=[pb_.top_p for pb_ in pb],
            typical_p=[pb_.typical_p for pb_ in pb],
            do_sample=[pb_.do_sample for pb_ in pb],
            seeds=[pb_.seed for pb_ in pb],
            device=device,
            dtype=dtype,
            tokenizer=tokenizer,
            grammars=[pb_.grammar for pb_ in pb],
            grammar_types=[pb_.grammar_type for pb_ in pb],
            fsm_grammar_states=(
                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
            ),
        )


class Sampling:
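    """Seeded multinomial sampling over the softmax of the logits."""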
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, -1)
        # Avoid GPU<->CPU sync done by torch multinomial
        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
        return probs.div_(q).argmax()


class Greedy:
    def __call__(self, logits):
        return logits.argmax(dim=-1)


class HeterogeneousSampling:
    r"""
    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
    """

    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
        self.seeds = seeds

        self.greedy_indices = []
        self.sampling_mapping = {}
        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
            if sample:
                self.sampling_mapping[i] = Sampling(seed, device)
            else:
                self.greedy_indices.append(i)

        self.greedy = Greedy()

    def __call__(self, logits):
        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
        if self.greedy_indices:
            # Computing for all indices is faster than slicing
            torch.argmax(logits, -1, out=out)

        for i, sampling in self.sampling_mapping.items():
            out[i] = sampling(logits[i])
        return out

    def filter(self, indices):
        new_greedy_indices = []
        new_sampling_mapping = {}
        for i, idx in enumerate(indices):
            if idx in self.sampling_mapping:
                new_sampling_mapping[i] = self.sampling_mapping[idx]
            else:
                new_greedy_indices.append(i)

        self.greedy_indices = new_greedy_indices
        self.sampling_mapping = new_sampling_mapping
        return self


def batch_top_tokens(
    top_n_tokens: List[int],
    top_n_tokens_tensor: torch.Tensor,
    logprobs: torch.Tensor,
    accepted_ids: torch.Tensor,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
    """Find the top n most likely tokens for a batch of generations.

    When multiple tokens have equal probabilities and they don't all fit, the
    remaining tokens are also returned.
    """
    max_top_n = max(top_n_tokens)
    # Early exit when top_n_tokens is not used
    if max_top_n == 0:
        return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)

    batch_size = accepted_ids.shape[0]
    speculate_size = logprobs.shape[0] // batch_size
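    # `logprobs` holds one row per (sequence, speculative position), so the per-request
    # top-n budget is repeated for every speculative position.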
    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
    # Ensure top_n doesn't exceed vocab size
    top_n_tokens = [
        min(tok, logprobs.size(-1))
        for tok in top_n_tokens
        for _ in range(speculate_size)
    ]

    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
    # Sorted topk is faster than torch.sort() since we only need a small subset
    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values

    nth_highest = torch.gather(
        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
    )
    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min

    # Find the new "fuzzy" top n values
    top_n_indices = (logprobs >= nth_highest).nonzero()
    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)

    k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
    # Take a new topk for these new max n values
    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)

    top_n_ishes = top_n_ishes.tolist()
    top_indices = top_k.indices.tolist()
    top_values = top_k.values.tolist()

    batch_top_token_ids = []
    batch_top_token_logprobs = []
    accepted_ids_list = accepted_ids.tolist()
    for i, n_accepted_ids in enumerate(accepted_ids_list):
        start = speculate_size * i
        stop = speculate_size * (i + 1)
        _top_indices = top_indices[start:stop]
        _top_values = top_values[start:stop]
        _top_n_ishes = top_n_ishes[start:stop]
        _top_n_tokens = top_n_tokens[start:stop]

        _top_indices = _top_indices[:n_accepted_ids]
        _top_values = _top_values[:n_accepted_ids]
        _top_n_ishes = _top_n_ishes[:n_accepted_ids]
        _top_n_tokens = _top_n_tokens[:n_accepted_ids]

        row_top_token_ids = []
        row_top_token_logprobs = []

        for idxs, vals, n, req_n in zip(
            _top_indices, _top_values, _top_n_ishes, _top_n_tokens
        ):
            indices = idxs[:n] if req_n > 0 else []
            values = vals[:n] if req_n > 0 else []

            row_top_token_ids.append(indices)
            row_top_token_logprobs.append(values)

        batch_top_token_ids.append(row_top_token_ids)
        batch_top_token_logprobs.append(row_top_token_logprobs)

    return batch_top_token_ids, batch_top_token_logprobs