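"""Token selection utilities for text_generation_server.

Covers logits warping (temperature, top-k, top-p, typical-p), watermarking and
repetition penalties, greedy or sampled next-token choice, and stopping criteria.
"""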
import re
import torch

from functools import lru_cache
from transformers import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
    PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional

from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor


class Sampling:
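    """Sample a token id from softmax(logits) using a seeded generator."""
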
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, -1)
        # Avoid GPU<->CPU sync done by torch multinomial
        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
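        # argmax(probs / q) with q ~ Exp(1) is an exact multinomial draw
        # (the exponential-races form of the Gumbel-max trick).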
        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
        return probs.div_(q).argmax()


class Greedy:
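    """Deterministic choice: return the index of the highest logit."""
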
    def __call__(self, logits):
        return logits.argmax()


class StaticWarper:
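    """Applies a fixed chain of logits warpers through a captured CUDA graph.

    The warper chain and the final log-softmax are recorded once on the first
    call; every later call copies the new scores into the static input buffer
    and replays the graph.
    """
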
    def __init__(
        self,
        temperature=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
    ):
        self.warpers = []

        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            self.warpers.append(TemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            self.warpers.append(TopKLogitsWarper(top_k=top_k))
        if top_p is not None and top_p < 1.0:
            self.warpers.append(TopPLogitsWarper(top_p=top_p))
        if typical_p is not None and typical_p < 1.0:
            self.warpers.append(TypicalLogitsWarper(mass=typical_p))

        self.cuda_graph = None
        self.static_scores = None
        self.static_warped_scores = None
        self.static_next_logprob = None

    def __call__(self, scores):
        if self.cuda_graph is None:
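            # First call: record the warper chain and log-softmax as a CUDA graph.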
            self.static_scores = scores
            self.cuda_graph = torch.cuda.CUDAGraph()

            with torch.cuda.graph(self.cuda_graph):
                for warper in self.warpers:
                    self.static_warped_scores = warper(None, self.static_scores)

                # Compute logprobs
                self.static_next_logprob = torch.log_softmax(
                    self.static_warped_scores, -1
                )

        self.static_scores.copy_(scores)
        self.cuda_graph.replay()

        return self.static_warped_scores, self.static_next_logprob


@lru_cache(10)
def static_warper(
    temperature: Optional[float],
    top_k: Optional[int],
    top_p: Optional[float],
    typical_p: Optional[float],
) -> StaticWarper:
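    """Build (and cache) a StaticWarper for a given combination of parameters."""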
    return StaticWarper(
        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
    )


class NextTokenChooser:
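    """Chooses the next token id from raw logits.

    Optionally applies watermarking and a repetition penalty, then either warps
    the scores (temperature / top-k / top-p / typical-p) and samples, or takes
    the greedy argmax when neither `do_sample` nor any warper is active.
    """
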
    def __init__(
        self,
        watermark=False,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        self.watermark_processor = (
            WatermarkLogitsProcessor(device=device) if watermark else None
        )
        self.repetition_processor = (
            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
            if repetition_penalty
            else None
        )

        has_warpers = (
            (temperature is not None and temperature != 1.0)
            or (top_k is not None and top_k != 0)
            or (top_p is not None and top_p < 1.0)
            or (typical_p is not None and typical_p < 1.0)
        )
        if has_warpers:
            self.static_warper = static_warper(
                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
            )
        else:
            self.static_warper = None

        sampling = do_sample or has_warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        if self.watermark_processor:
            scores = self.watermark_processor(input_ids, scores)
        if self.repetition_processor:
            scores = self.repetition_processor(input_ids, scores)

        if self.static_warper is None:
            next_logprob = torch.log_softmax(scores, -1)
        else:
            scores, next_logprob = self.static_warper(scores)

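        # Only the scores of the last position are used to pick the next token.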
        next_id = self.choice(scores[-1]).view(1, 1)

        return next_id, next_logprob

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )


class StopSequenceCriteria:
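    """Checks whether the generated text ends with a given stop sequence."""
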
    def __init__(self, stop_sequence: str):
        stop_sequence = re.escape(stop_sequence)
        self.regex = re.compile(f".*{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        return bool(self.regex.findall(output))


class StoppingCriteria:
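    """Decides when generation should stop and with which FinishReason.

    Generation stops once the new-token budget is exhausted, when the EOS token
    is produced (unless ignored), or when the accumulated output ends with one
    of the stop sequences.
    """
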
    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""
        self.ignore_eos_token = ignore_eos_token

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
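        # Checked in order: token budget, EOS token (unless ignored), stop sequences.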
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if not self.ignore_eos_token and last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return StoppingCriteria(
            tokenizer.eos_token_id,
            stop_sequence_criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
        )
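

# Illustrative usage sketch (not part of the server). It assumes a CUDA device,
# since StaticWarper captures a CUDA graph; `tokenizer`, `input_ids`, and
# `logits` are placeholders for objects produced elsewhere:
#
#     chooser = NextTokenChooser(
#         temperature=0.7, top_p=0.9, do_sample=True, seed=42, device="cuda"
#     )
#     stopping = StoppingCriteria(tokenizer.eos_token_id, [], max_new_tokens=16)
#     next_id, logprobs = chooser(input_ids, logits)
#     finished, reason = stopping(next_id.item(), tokenizer.decode(next_id[0]))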