tokens.py 23.1 KB
Newer Older
1
import re
2
from typing import List, Optional, Tuple, Set, Union
3

drbh's avatar
drbh committed
4
import math
Nicolas Patry's avatar
Nicolas Patry committed
5
import torch
6
from text_generation_server.pb import generate_pb2
drbh's avatar
drbh committed
7
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
8
from text_generation_server.utils.logits_process import (
9
    FrequencyPenaltyLogitsProcessor,
drbh's avatar
drbh committed
10
    GrammarLogitProcessor,
Nicolas Patry's avatar
Nicolas Patry committed
11
    HeterogeneousProcessorWrapper,
12
    HeterogeneousRepetitionPenaltyLogitsProcessor,
13
    HeterogeneousFrequencyPenaltyLogitsProcessor,
14
15
16
17
    HeterogeneousTemperatureLogitsWarper,
    HeterogeneousTopKLogitsWarper,
    HeterogeneousTopPLogitsWarper,
    HeterogeneousTypicalLogitsWarper,
drbh's avatar
drbh committed
18
    HeterogeneousGrammarLogitProcessor,
Nicolas Patry's avatar
Nicolas Patry committed
19
    static_warper,
20
)
Nicolas Patry's avatar
Nicolas Patry committed
21
22
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
23

OlivierDehaene's avatar
OlivierDehaene committed
24

25
26
27
class NextTokenChooser:
    """Pick the next token for a single request.

    The logits are passed through the optional watermark, repetition-penalty,
    frequency-penalty and grammar processors, then through the optional static
    warper, and finally handed to either a seeded sampler or a greedy chooser.
    """

    def __init__(
        self,
        watermark: bool = False,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        frequency_penalty: float = 0.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        do_sample: bool = False,
        seed: int = 0,
        device: str = "cpu",
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        grammar: str = "",
        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
        fsm_grammar_state: int = 0,
    ):
        # Each processor is only instantiated when its parameter departs from
        # the neutral value; otherwise the slot stays None and is skipped.
        if watermark:
            self.watermark_processor = WatermarkLogitsProcessor(device=device)
        else:
            self.watermark_processor = None

        if repetition_penalty and repetition_penalty != 1.0:
            self.repetition_processor = RepetitionPenaltyLogitsProcessor(
                penalty=repetition_penalty
            )
        else:
            self.repetition_processor = None

        if frequency_penalty and frequency_penalty != 0.0:
            self.frequency_processor = FrequencyPenaltyLogitsProcessor(
                penalty=frequency_penalty
            )
        else:
            self.frequency_processor = None

        if grammar != "":
            self.grammar_processor = GrammarLogitProcessor(
                tokenizer, device, grammar, grammar_type
            )
        else:
            self.grammar_processor = None
        self.tokenizer = tokenizer

        # A warper is needed as soon as any warping parameter is non-neutral.
        uses_temperature = temperature is not None and temperature != 1.0
        uses_top_k = top_k is not None and top_k != 0
        uses_top_p = top_p is not None and top_p < 1.0
        uses_typical_p = typical_p is not None and typical_p < 1.0
        has_warpers = uses_temperature or uses_top_k or uses_top_p or uses_typical_p

        self.static_warper = (
            static_warper(
                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
            )
            if has_warpers
            else None
        )

        # Warping only makes sense with sampling, so warpers imply sampling.
        sampling = do_sample or has_warpers
        if sampling:
            self.choice = Sampling(seed, device)
        else:
            self.choice = Greedy()

        self.fsm_grammar_state = fsm_grammar_state
        self.grammar = grammar

    def __call__(self, input_ids, scores):
        """Process ``scores`` and return ``(next_id, next_logprob)``.

        ``next_id`` is the chosen token reshaped to ``(1, 1)``; it is chosen
        from the last row of the processed scores. ``next_logprob`` is the
        log-softmax of the processed scores, or the log-probabilities returned
        by the static warper when one is configured.
        """
        history_processors = (
            self.watermark_processor,
            self.repetition_processor,
            self.frequency_processor,
        )
        for processor in history_processors:
            if processor is not None:
                scores = processor(input_ids, scores)

        # The grammar processor is driven by the FSM state, not by input_ids.
        if self.grammar_processor is not None:
            scores = self.grammar_processor(scores, self.fsm_grammar_state)

        if self.static_warper is not None:
            scores, next_logprob = self.static_warper(scores)
        else:
            next_logprob = torch.log_softmax(scores, -1)

        next_id = self.choice(scores[-1]).view(1, 1)
        return next_id, next_logprob

    def advance_grammar(self, next_id: int):
        """Advance the grammar FSM state with ``next_id`` and return ``self``.

        Does nothing when no grammar processor is configured.
        """
        if self.grammar_processor is None:
            return self
        self.fsm_grammar_state = self.grammar_processor.advance(
            next_id, self.fsm_grammar_state
        )
        return self

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "NextTokenChooser":
        """Build a chooser from the protobuf request parameters."""
        params = dict(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            frequency_penalty=pb.frequency_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
            tokenizer=tokenizer,
            grammar=pb.grammar,
            grammar_type=pb.grammar_type,
        )
        return NextTokenChooser(**params)


class StopSequenceCriteria:
    """Callable that reports whether a text ends with a given stop sequence."""

    def __init__(self, stop_sequence: str):
        # The stop sequence is matched literally, anchored with `$`.
        pattern = re.escape(stop_sequence) + "$"
        self.regex = re.compile(pattern)

    def __call__(self, output: str) -> bool:
        """Return True when the anchored pattern matches somewhere in ``output``."""
        return self.regex.search(output) is not None


class StoppingCriteria:
    """Decide, token by token, whether generation for one request must stop.

    Generation stops when the token budget is reached, when an EOS token is
    produced (unless EOS is ignored), or when one of the stop-sequence
    criteria matches the recently generated text.
    """

    def __init__(
        self,
        eos_token_ids: Optional[Union[Set[int], int]],
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        # Normalise the EOS specification to a set of token ids.
        if eos_token_ids is None:
            normalized_eos = set()
        elif isinstance(eos_token_ids, int):
            normalized_eos = {eos_token_ids}
        elif isinstance(eos_token_ids, set):
            normalized_eos = eos_token_ids
        else:
            raise RuntimeError(
                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
            )

        self.eos_token_ids = normalized_eos
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.ignore_eos_token = ignore_eos_token
        self.current_tokens = 0
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        """Register one generated token and return ``(stop, finish_reason)``.

        The length check runs first, then the EOS check, then the stop
        sequences. ``finish_reason`` is ``None`` when generation continues.
        """
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        # Accept a tensor as well as a plain int for the token id.
        token = last_token.item() if isinstance(last_token, torch.Tensor) else last_token

        if not self.ignore_eos_token and token in self.eos_token_ids:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        # Text is only accumulated when there is something to match against.
        if self.stop_sequence_criterias:
            buffered = self.current_output + last_output
            # Bound the buffer: once past 300 characters keep only the last
            # 200, so the trimming does not happen on every single token.
            if len(buffered) > 300:
                buffered = buffered[-200:]
            self.current_output = buffered

            if any(criteria(buffered) for criteria in self.stop_sequence_criterias):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        """Build the criteria from protobuf parameters and a tokenizer."""
        criterias = []
        for sequence in pb.stop_sequences:
            criterias.append(StopSequenceCriteria(sequence))

        # TODO Hack because eos_token_id cannot be what we want.
        # A `_eos_token_ids` attribute on the tokenizer takes precedence.
        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
        return StoppingCriteria(
            eos_token_id,
            criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
        )
208

OlivierDehaene's avatar
OlivierDehaene committed
209
210
211
212
213
214
215
216

def create_n_gram_speculation(
    input_ids: torch.Tensor,
    next_ids: torch.Tensor,
    accepted_ids: torch.Tensor,
    speculate: int,
    verbose: bool,
):
    """Propose ``speculate`` follow-up tokens per row by copying from the prompt.

    For every row, the last accepted token is looked up in ``input_ids`` and
    the tokens that follow the match are returned as the speculation. This is
    deliberately cruder than a real n-gram lookup, but everything is done with
    tensor operations on ``input_ids``' device. ``verbose`` is unused.
    """
    batch_size = accepted_ids.shape[0]
    target_device = input_ids.device

    # Position (in the flat next_ids) of the last accepted token of each row.
    last_accepted = accepted_ids.cumsum(dim=-1) - 1
    seeds = next_ids[last_accepted]

    # max() over the boolean match mask yields a matching column per row;
    # the speculation starts right after it.
    match_mask = input_ids == seeds.unsqueeze(-1)
    start_positions = match_mask.max(dim=1).indices + 1

    offsets = torch.arange(speculate, device=target_device)
    positions = start_positions.unsqueeze(-1).expand(batch_size, speculate) + offsets
    # Never read past the last column of input_ids.
    positions = torch.clamp(positions, max=input_ids.shape[1] - 1)

    return input_ids.gather(dim=-1, index=positions)
232

OlivierDehaene's avatar
OlivierDehaene committed
233

234
235
236
237
238
239
240
241
class HeterogeneousNextTokenChooser:
    """Pick next tokens for a batch whose requests use different parameters.

    Every constructor list holds one entry per request. Processors and warpers
    are only created when at least one request needs them. ``filter`` narrows
    the chooser to a subset of the requests.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
        watermark: List[bool],
        temperature: List[float],
        repetition_penalty: List[float],
        frequency_penalty: List[float],
        top_k: List[int],
        top_p: List[float],
        typical_p: List[float],
        do_sample: List[bool],
        seeds: List[int],
        tokenizer: PreTrainedTokenizerBase,
        grammars: List[str],
        grammar_types: List[int],
        fsm_grammar_states: Optional[List[int]] = None,
    ):
        # `fsm_grammar_states` used to default to the typing object
        # `List[int]` (an `=` where a `:` was intended). When it is omitted,
        # start every request in FSM state 0, as `from_pb` does.
        if fsm_grammar_states is None:
            fsm_grammar_states = [0] * len(grammars)

        warpers = []

        self.watermark_processor = (
            HeterogeneousProcessorWrapper(
                {
                    i: WatermarkLogitsProcessor(device=device)
                    for i, do_watermark in enumerate(watermark)
                    if do_watermark
                }
            )
            if any(watermark)
            else None
        )

        self.repetition_processor = (
            HeterogeneousRepetitionPenaltyLogitsProcessor(
                repetition_penalty, dtype, device
            )
            if any(x != 1.0 for x in repetition_penalty)
            else None
        )

        self.frequency_processor = (
            HeterogeneousFrequencyPenaltyLogitsProcessor(
                frequency_penalty, dtype, device
            )
            if any(x != 0.0 for x in frequency_penalty)
            else None
        )

        self.grammar_processor = (
            HeterogeneousGrammarLogitProcessor(
                tokenizer, device, grammars, grammar_types
            )
            if any(grammar != "" for grammar in grammars)
            else None
        )

        # A request that sets any warping parameter is switched to sampling,
        # even if it did not ask for `do_sample` explicitly.
        if any(x != 1.0 for x in temperature):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
            ]
            warpers.append(
                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
            )

        if any(x != 0 for x in top_k):
            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

        if any(x < 1.0 for x in top_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
            warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))

        if any(x < 1.0 for x in typical_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))

        self.warpers = warpers

        if any(do_sample):
            self.choice = HeterogeneousSampling(do_sample, seeds, device)
        else:
            self.choice = Greedy()

        self.seeds = seeds
        self.do_sample = do_sample
        self.dtype = dtype
        self.device = device
        self.tokenizer = tokenizer
        self.fsm_grammar_states = fsm_grammar_states
        self.grammars = grammars
        self.grammar_types = grammar_types

    def __call__(
        self,
        input_ids: torch.Tensor,
        scores: torch.Tensor,
        speculate: int,
        speculated_ids: Optional[torch.Tensor] = None,
        speculative_scores: Optional[torch.Tensor] = None,
        verbose=False,
    ):
        """Choose next tokens and validate previously speculated ones.

        Args:
            input_ids: token history handed to the processors and warpers.
            scores: logits with ``B * S`` rows, where ``S`` is
                ``speculated_ids.shape[1] + 1`` (or 1 without speculation).
                The processed scores are written back into this tensor.
            speculate: number of tokens to speculate for the next step;
                0 disables speculation.
            speculated_ids: tokens speculated at the previous step, if any.
            speculative_scores: scores provided by a speculation head, if any;
                when absent, n-gram speculation is used.
            verbose: forwarded to the n-gram speculation helper.

        Returns:
            ``(next_ids, next_logprobs, alllogprobs, accepted_ids,
            speculative_ids)``. ``next_ids`` only holds the accepted tokens,
            ``accepted_ids`` holds the number of accepted tokens per request,
            and ``speculative_ids`` is ``None`` when ``speculate`` is 0.
        """
        if speculated_ids is not None:
            S = speculated_ids.shape[1] + 1
            B = scores.shape[0] // S
        else:
            B = scores.shape[0]
            S = 1
        scores = scores.view(B, S, -1)

        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)

        # Process each speculative position independently.
        for j in range(S):
            _scores = scores[:, j]
            if self.watermark_processor is not None:
                _scores = self.watermark_processor(input_ids, _scores)
            if self.repetition_processor is not None:
                _scores = self.repetition_processor(input_ids, _scores)
            if self.frequency_processor is not None:
                _scores = self.frequency_processor(input_ids, _scores)
            if self.grammar_processor is not None:
                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
            for warper in self.warpers:
                _scores = warper(input_ids, _scores)
            _next_ids = self.choice(_scores)
            # Written back in place so the log-probabilities below reflect
            # the processed scores.
            scores[:, j] = _scores
            next_ids[:, j] = _next_ids
        next_ids = next_ids.view(B * S)
        allscores = scores.view(B * S, -1)
        alllogprobs = torch.log_softmax(allscores, -1)

        if speculated_ids is not None:
            accepted_ids = []
            indices = []
            for i in range(B):
                _next_ids = next_ids[i * S : (i + 1) * S]
                _speculated_ids = speculated_ids[i]
                # A speculated token is valid when it equals the token chosen
                # at the position before it.
                validate_speculative = _next_ids[:-1] == _speculated_ids
                index = i * S
                accepted = 1
                # First is always valid
                indices.append(index)
                for valid in validate_speculative.tolist():
                    if valid:
                        index += 1
                        accepted += 1
                        indices.append(index)
                    else:
                        # Everything after the first mismatch is discarded.
                        break
                accepted_ids.append(accepted)

            accepted_ids = torch.tensor(
                accepted_ids, device=input_ids.device, dtype=input_ids.dtype
            )
            next_ids = next_ids[indices]
            logprobs = alllogprobs[indices]
            indices = torch.arange(B, device=input_ids.device) * S
            if speculative_scores is not None:
                # Keep the speculative scores of the last accepted position.
                speculative_scores = speculative_scores[indices + accepted_ids - 1]
        else:
            accepted_ids = torch.ones_like(next_ids)
            logprobs = alllogprobs

        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

        if speculate > 0:
            if speculative_scores is not None:
                # Medusa provided some scores
                speculative_ids = Greedy()(speculative_scores)
            else:
                # n-gram
                speculative_ids = create_n_gram_speculation(
                    input_ids, next_ids, accepted_ids, speculate, verbose
                )
        else:
            speculative_ids = None

        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

    def advance_grammar(self, next_ids: List[int]):
        """Advance every request's grammar FSM state and return ``self``."""
        if self.grammar_processor is not None:
            other_new_states = self.grammar_processor.advance_batch(
                next_ids, self.fsm_grammar_states
            )
            self.fsm_grammar_states = other_new_states
        return self

    def advance_grammar_single(self, grammar_state_index: int, next_id: int):
        """Advance the grammar FSM state of one request and return ``self``."""
        if self.grammar_processor is not None:
            self.fsm_grammar_states[grammar_state_index] = (
                self.grammar_processor.advance_at_index(
                    next_id,
                    self.fsm_grammar_states[grammar_state_index],
                    grammar_state_index,
                )
            )
        return self

    def filter(self, indices):
        """Keep only the requests at ``indices`` (in that order); return ``self``."""
        if self.watermark_processor is not None:
            self.watermark_processor = self.watermark_processor.filter(indices)

        if self.repetition_processor is not None:
            self.repetition_processor = self.repetition_processor.filter(indices)

        if self.frequency_processor is not None:
            self.frequency_processor = self.frequency_processor.filter(indices)

        if self.grammar_processor is not None:
            self.grammar_processor = self.grammar_processor.filter(indices)

        # A warper's filter() may return None; such warpers are dropped.
        filtered_warpers = []
        for warper in self.warpers:
            filtered_warper = warper.filter(indices)
            if filtered_warper is not None:
                filtered_warpers.append(filtered_warper)
        self.warpers = filtered_warpers

        self.seeds = [self.seeds[i] for i in indices]
        self.do_sample = [self.do_sample[i] for i in indices]

        self.grammars = [self.grammars[i] for i in indices]
        self.fsm_grammar_states = [self.fsm_grammar_states[i] for i in indices]
        self.grammar_types = [self.grammar_types[i] for i in indices]

        # Fall back to plain greedy once no remaining request samples.
        if any(self.do_sample):
            self.choice.filter(indices)
        else:
            self.choice = Greedy()

        return self

    @classmethod
    def from_pb(
        cls,
        pb: List[generate_pb2.NextTokenChooserParameters],
        dtype: torch.dtype,
        device: torch.device,
        tokenizer: PreTrainedTokenizerBase,
        fsm_grammar_states: Optional[List[int]] = None,
    ) -> "HeterogeneousNextTokenChooser":
        """Build a chooser from one protobuf parameter message per request.

        When ``fsm_grammar_states`` is missing or empty, every request starts
        in FSM state 0.
        """
        return HeterogeneousNextTokenChooser(
            watermark=[pb_.watermark for pb_ in pb],
            temperature=[pb_.temperature for pb_ in pb],
            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
            top_k=[pb_.top_k for pb_ in pb],
            top_p=[pb_.top_p for pb_ in pb],
            typical_p=[pb_.typical_p for pb_ in pb],
            do_sample=[pb_.do_sample for pb_ in pb],
            seeds=[pb_.seed for pb_ in pb],
            device=device,
            dtype=dtype,
            tokenizer=tokenizer,
            grammars=[pb_.grammar for pb_ in pb],
            grammar_types=[pb_.grammar_type for pb_ in pb],
            fsm_grammar_states=(
                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
            ),
        )


class Sampling:
    """Seeded sampler drawing one token from a single row of logits."""

    def __init__(self, seed: int, device: str = "cpu"):
        generator = torch.Generator(device)
        generator.manual_seed(seed)
        self.generator = generator
        self.seed = seed

    def __call__(self, logits):
        """Return the index of the sampled token for ``logits``."""
        probabilities = torch.softmax(logits, -1)
        # Avoid GPU<->CPU sync done by torch multinomial
        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
        exponential_noise = torch.empty_like(probabilities).exponential_(
            1, generator=self.generator
        )
        # Dividing by exponential noise and taking the argmax stands in for
        # the multinomial draw.
        return probabilities.div_(exponential_noise).argmax()


class Greedy:
    """Chooser that always takes the highest-scoring token."""

    def __call__(self, logits):
        """Return the argmax of ``logits`` along the last dimension."""
        return torch.argmax(logits, dim=-1)


class HeterogeneousSampling:
    r"""
    Batch chooser mixing greedy and probabilistic rows: rows flagged for
    sampling get their own seeded sampler, all other rows use the argmax.
    """

    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
        self.seeds = seeds

        per_row = list(enumerate(zip(do_sample, seeds)))
        # One dedicated sampler per sampling row, keyed by its batch position.
        self.sampling_mapping = {
            row: Sampling(seed, device) for row, (sample, seed) in per_row if sample
        }
        self.greedy_indices = [row for row, (sample, _) in per_row if not sample]

        self.greedy = Greedy()

    def __call__(self, logits):
        """Return one chosen token id per row of ``logits``."""
        chosen = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
        if self.greedy_indices:
            # Computing for all indices is faster than slicing
            torch.argmax(logits, -1, out=chosen)

        # Sampling rows overwrite whatever the argmax put at their position.
        for row, sampler in self.sampling_mapping.items():
            chosen[row] = sampler(logits[row])
        return chosen

    def filter(self, indices):
        """Keep the rows at ``indices``, renumbered by their new position."""
        kept_samplers = {}
        kept_greedy = []
        for new_row, old_row in enumerate(indices):
            if old_row not in self.sampling_mapping:
                kept_greedy.append(new_row)
            else:
                kept_samplers[new_row] = self.sampling_mapping[old_row]

        self.greedy_indices = kept_greedy
        self.sampling_mapping = kept_samplers
        return self
Nicolas Patry's avatar
Nicolas Patry committed
566
567
568


def batch_top_tokens(
    top_n_tokens: List[int],
    top_n_tokens_tensor: torch.Tensor,
    logprobs: torch.Tensor,
    accepted_ids: torch.Tensor,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
    """Find the top n most likely tokens for a batch of generations.

    When multiple tokens have equal probabilities and they don't all fit, the
    remaining tokens are also returned.

    Args:
        top_n_tokens: requested number of top tokens, one entry per request.
        top_n_tokens_tensor: the same counts as a tensor (used as a gather
            index, so it must be an integer tensor).
        logprobs: log-probabilities with ``batch * speculate_size`` rows.
        accepted_ids: number of accepted positions, one entry per request.

    Returns:
        ``(token_ids, token_logprobs)``. Both are indexed as
        ``[request][accepted position][rank]``. A request that asked for 0
        tokens gets an empty list at each of its positions. A request asking
        for more tokens than the vocabulary holds gets the whole vocabulary.
    """
    max_top_n = max(top_n_tokens, default=0)
    # Early exit when top_n_tokens is not used (or the batch is empty).
    # Fresh lists are built per request so that callers mutating one entry
    # do not silently modify the others.
    if max_top_n == 0:
        return [[[]] for _ in top_n_tokens], [[[]] for _ in top_n_tokens]

    # Ensure top_n doesn't exceed vocab size: torch.topk rejects a k larger
    # than the dimension it reduces.
    vocab_size = logprobs.size(-1)
    max_top_n = min(max_top_n, vocab_size)

    batch_size = accepted_ids.shape[0]
    speculate_size = logprobs.shape[0] // batch_size
    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
    top_n_tokens = [
        min(tok, vocab_size) for tok in top_n_tokens for _ in range(speculate_size)
    ]

    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
    # Sorted topk is faster than torch.sort() since we only need a small subset
    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values

    # Column holding each row's n-th highest value. Rows asking for 0 tokens
    # are clipped to column 0; oversized requests to the last column.
    nth_column = (top_n_tokens_tensor - 1).clip(min=0, max=max_top_n - 1)
    nth_highest = torch.gather(sorted_top_k, 1, nth_column.unsqueeze(1))
    # Replace a -inf threshold with the smallest finite value so that -inf
    # entries do not pass the `>=` test below.
    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min

    # Find the new "fuzzy" top n values
    top_n_indices = (logprobs >= nth_highest).nonzero()
    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)

    k = 1 if top_n_ishes.numel() == 0 else int(top_n_ishes.max())
    # Take a new topk for these new max n values
    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)

    top_n_ishes = top_n_ishes.tolist()
    top_indices = top_k.indices.tolist()
    top_values = top_k.values.tolist()

    batch_top_token_ids = []
    batch_top_token_logprobs = []
    accepted_ids_list = accepted_ids.tolist()
    for i, n_accepted_ids in enumerate(accepted_ids_list):
        start = speculate_size * i
        stop = speculate_size * (i + 1)
        # Rows of this request, restricted to the accepted positions.
        _top_indices = top_indices[start:stop][:n_accepted_ids]
        _top_values = top_values[start:stop][:n_accepted_ids]
        _top_n_ishes = top_n_ishes[start:stop][:n_accepted_ids]
        _top_n_tokens = top_n_tokens[start:stop][:n_accepted_ids]

        row_top_token_ids = []
        row_top_token_logprobs = []

        for idxs, vals, n, req_n in zip(
            _top_indices, _top_values, _top_n_ishes, _top_n_tokens
        ):
            row_top_token_ids.append(idxs[:n] if req_n > 0 else [])
            row_top_token_logprobs.append(vals[:n] if req_n > 0 else [])

        batch_top_token_ids.append(row_top_token_ids)
        batch_top_token_logprobs.append(row_top_token_logprobs)

    return batch_top_token_ids, batch_top_token_logprobs