# tokens.py
import re
import torch

from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
    PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional

from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor


class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, -1)
        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
        return next_tokens


class Greedy:
    def __call__(self, logits):
        return logits.argmax()
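
# Example (illustrative sketch, not part of the server code): both choosers
# take a 1-D logits tensor over the vocabulary. `Sampling` draws from the
# softmax distribution with a seeded generator, so results are reproducible
# for a given seed; `Greedy` is a plain argmax. The vocab size is made up.
#
#   logits = torch.randn(50257)
#   Sampling(seed=42)(logits)  # shape (1,) tensor, stochastic but seeded
#   Greedy()(logits)           # 0-dim tensor holding the argmax index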


class NextTokenChooser:
    def __init__(
        self,
        watermark=False,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        sampling = do_sample

        if watermark:
            warpers.append(WatermarkLogitsProcessor(device=device))
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        # Warp logits
        if scores.shape[0] > 1:
            # only warp the last token logits
            scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
        else:
            scores = self.warpers(input_ids, scores)

        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
        next_id = self.choice(scores[-1])

        return next_id.view(1, 1), logprobs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )
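
# Example (illustrative sketch): building a chooser directly rather than via
# `from_pb`. Any active warper (temperature/top_k/top_p here) flips
# `sampling` to True, so `Sampling` is used instead of `Greedy`. Shapes
# assume a single sequence and a made-up vocab size.
#
#   chooser = NextTokenChooser(temperature=0.7, top_k=50, top_p=0.95, seed=0)
#   input_ids = torch.tensor([[1, 2, 3]])  # (batch=1, seq_len)
#   scores = torch.randn(1, 50257)         # logits for the last position
#   next_id, logprobs = chooser(input_ids, scores)
#   next_id.shape                          # torch.Size([1, 1])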


class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        # Escape the stop sequence so it is matched literally, anchored to
        # the end of the generated text
        stop_sequence = re.escape(stop_sequence)
        self.regex = re.compile(f"{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False
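
# Example (illustrative sketch): the pattern is anchored with `$`, so a stop
# sequence only fires once it appears at the very end of the text generated
# so far.
#
#   criteria = StopSequenceCriteria("\nUser:")
#   criteria("Bot: hi\nUser:")   # True: output ends with the stop sequence
#   criteria("\nUser: hello")    # False: stop sequence not at the end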


class StoppingCriteria:
    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""
        self.ignore_eos_token = ignore_eos_token

    def __call__(
        self, last_token: int, last_output: str
    ) -> Tuple[bool, Optional["FinishReason"]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if not self.ignore_eos_token and last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return StoppingCriteria(
            tokenizer.eos_token_id,
            stop_sequence_criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
        )
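
# Example (illustrative sketch of how a caller could wire the two classes
# together; the real decode loop lives in the model serving code, and
# `model_forward`/`tokenizer` below are assumed stand-ins):
#
#   chooser = NextTokenChooser(do_sample=True, seed=0)
#   stopping = StoppingCriteria(
#       eos_token_id=2,
#       stop_sequence_criterias=[StopSequenceCriteria("\nUser:")],
#       max_new_tokens=64,
#   )
#   while True:
#       logits = model_forward(input_ids)  # (1, vocab) logits, omitted here
#       next_id, logprobs = chooser(input_ids, logits)
#       input_ids = torch.cat([input_ids, next_id], dim=1)
#       stop, reason = stopping(next_id.item(), tokenizer.decode(next_id[0]))
#       if stop:
#           break  # `reason` is a FinishReason value (LENGTH / EOS / STOP_SEQUENCE)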