# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import Optional, Sequence, Union

import torch


class SentencePieceTokenizer:
    """Tokenizer of sentencepiece.

    Args:
        model_file (str): the path of the tokenizer model
    """

    def __init__(self, model_file: str):
        from sentencepiece import SentencePieceProcessor
        self.model = SentencePieceProcessor(model_file=model_file)
        self._prefix_space_tokens = None

    @property
    def vocab_size(self):
        """vocabulary size."""
        return self.model.vocab_size()

    @property
    def bos_token_id(self):
        """begine of the sentence token id."""
        return self.model.bos_id()

    @property
    def eos_token_id(self):
        """end of the sentence token id."""
        return self.model.eos_id()

    @property
    def prefix_space_tokens(self):
        """tokens without prefix space."""
        if self._prefix_space_tokens is None:
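            # '▁' (U+2581) is sentencepiece's word-boundary marker; pieces
            # starting with it decode to text with a leading space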
            vocab = self.model.IdToPiece(list(range(self.vocab_size)))
            self._prefix_space_tokens = {
                i
                for i, tok in enumerate(vocab) if tok.startswith('▁')
            }
        return self._prefix_space_tokens

    def _maybe_add_prefix_space(self, tokens, decoded):
        """maybe add prefix space for incremental decoding."""
        if len(tokens) and not decoded.startswith(' ') and\
                tokens[0] in self.prefix_space_tokens:
            return ' ' + decoded
        else:
            return decoded

    def encode(self, s: str):
        """Tokenize a prompt.

        Args:
            s (str): a prompt
        Returns:
            list[int]: token ids
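
        Examples:
            A minimal sketch; './tokenizer.model' stands in for a real
            sentencepiece model file.

            >>> tokenizer = SentencePieceTokenizer('./tokenizer.model')
            >>> tokenizer.encode('<BOS>hi')  # bos id followed by ids of 'hi'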
        """
        add_bos = False
        add_eos = False
        if s.find('<BOS>') != -1:
            s = s.replace('<BOS>', '')
            add_bos = True
        if s == '<EOS>':
            s = ''
            add_eos = True
        return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)

    def decode(self, t: Sequence[int], offset: Optional[int] = None):
        """De-tokenize.

        Args:
            t (List[int]): a list of token ids
            offset (int): for incremental decoding. Defaults to None, which
                means it is not applied.
        Returns:
            str: text of the decoded tokens
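
        Examples:
            A sketch of incremental decoding; assumes `tokenizer` is a
            SentencePieceTokenizer built from a real model file.

            >>> ids = tokenizer.encode('hello world')
            >>> head = tokenizer.decode(ids[:1])
            >>> # decode only the new tokens; the stripped prefix space is
            >>> # restored via prefix_space_tokens
            >>> tail = tokenizer.decode(ids, offset=1)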
        """
        if isinstance(t, torch.Tensor):
            t = t.tolist()
        t = t[offset:]
        out_string = self.model.Decode(t)
        if offset:
            out_string = self._maybe_add_prefix_space(t, out_string)
        return out_string

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.

        Args:
            s (str or Sequence[str]): prompts
        Returns:
            addict.Dict: tokenization result with an 'input_ids' field
        """
        import addict
        add_bos = False
        add_eos = False

        input_ids = self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
        return addict.Dict(input_ids=input_ids)


class HuggingFaceTokenizer:
    """Tokenizer of sentencepiece.

    Args:
        model_dir (str): the directory of the tokenizer model
    """

    def __init__(self, model_dir: str):
        from transformers import AutoTokenizer
        model_file = osp.join(model_dir, 'tokenizer.model')
        backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
        model_file_exists = osp.exists(model_file)
        if not osp.exists(backend_tokenizer_file) and model_file_exists:
            print('WARNING: Cannot find tokenizer.json. '
                  'It may take a long time to initialize the tokenizer.')
        self.model = AutoTokenizer.from_pretrained(model_dir,
                                                   trust_remote_code=True)
        self._prefix_space_tokens = None
        # save tokenizer.json to reuse
        if not osp.exists(backend_tokenizer_file) and model_file_exists:
            if hasattr(self.model, 'backend_tokenizer'):
                self.model.backend_tokenizer.save(backend_tokenizer_file)

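        # some tokenizers don't expose eos_token_id; fall back to the value
        # in generation_config.json or, for Qwen's remote code, its `eod_id`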
        if self.model.eos_token_id is None:
            generation_config_file = osp.join(model_dir,
                                              'generation_config.json')
            if osp.exists(generation_config_file):
                with open(generation_config_file, 'r') as f:
                    cfg = json.load(f)
                    self.model.eos_token_id = cfg['eos_token_id']
            elif hasattr(self.model, 'eod_id'):  # Qwen remote
                self.model.eos_token_id = self.model.eod_id

    @property
    def vocab_size(self):
        """vocabulary size."""
        return self.model.vocab_size

    @property
    def bos_token_id(self):
        """begine of the sentence token id."""
        return self.model.bos_token_id

    @property
    def eos_token_id(self):
        """end of the sentence token id."""
        return self.model.eos_token_id

    @property
    def prefix_space_tokens(self):
        """tokens without prefix space."""
        if self._prefix_space_tokens is None:
            vocab = self.model.convert_ids_to_tokens(
                list(range(self.vocab_size)))
            self._prefix_space_tokens = {
                i
                for i, tok in enumerate(vocab)
                if tok.startswith('▁' if isinstance(tok, str) else b' ')
            }
        return self._prefix_space_tokens

    def _maybe_add_prefix_space(self, tokens, decoded):
        """maybe add prefix space for incremental decoding."""
        if len(tokens) and not decoded.startswith(' ') and\
                tokens[0] in self.prefix_space_tokens:
            return ' ' + decoded
        else:
            return decoded

    def encode(self, s: str):
        """Tokenize a prompt.

        Args:
            s (str): a prompt
        Returns:
            list[int]: token ids
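
        Examples:
            A minimal sketch; './model_dir' stands in for a real huggingface
            model directory. '<BOS>' and '<EOS>' are this module's sentinels,
            rewritten to the literal '<s>' and '</s>' tokens.

            >>> tokenizer = HuggingFaceTokenizer('./model_dir')
            >>> tokenizer.encode('<BOS>hi')  # encodes '<s>hi'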
        """
        add_special_tokens = False
        if s.find('<BOS>') != -1:
            s = s.replace('<BOS>', '<s>')
        if s == '<EOS>':
            s = '</s>'
        if len(s) == 0:
            add_special_tokens = True
        return self.model.encode(s, add_special_tokens=add_special_tokens)

    def decode(self, t: Sequence[int], offset: Optional[int] = None):
        """De-tokenize.

        Args:
            t (List[int]): a list of token ids
            offset (int): for incremental decoding. Defaults to None, which
                means it is not applied.
        Returns:
            str: text of the decoded tokens
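
        Examples:
            A sketch with illustrative ids; assumes `tokenizer` is a
            HuggingFaceTokenizer instance and `ids` a list of token ids.

            >>> full = tokenizer.decode(ids)            # the whole sequence
            >>> tail = tokenizer.decode(ids, offset=4)  # tokens 4..end only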
        """
        skip_special_tokens = True
        t = t[offset:]
        out_string = self.model.decode(t,
                                       skip_special_tokens=skip_special_tokens)
        if offset:
            out_string = self._maybe_add_prefix_space(t, out_string)
        return out_string

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.

        Args:
            s (str or Sequence[str]): prompts
        Returns:
            transformers.BatchEncoding: tokenization result with an
                'input_ids' field
        """
        add_special_tokens = False
        return self.model(s, add_special_tokens=add_special_tokens)


class Tokenizer:
    """Tokenize prompts or de-tokenize tokens into texts.

    Args:
        model_file (str): the path of the tokenizer model, or the directory
            that contains it
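
    Examples:
        A minimal sketch; './model_dir' stands in for a real model location.
        The huggingface backend is picked when tokenizer_config.json exists
        or tokenizer.model is missing; otherwise sentencepiece is used.

        >>> tokenizer = Tokenizer('./model_dir')
        >>> ids = tokenizer.encode('hi')
        >>> tokenizer.decode(ids)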
    """

    def __init__(self, model_file: str):
        if model_file.endswith('.model'):
            model_folder = osp.split(model_file)[0]
        else:
            model_folder = model_file
            model_file = osp.join(model_folder, 'tokenizer.model')
        tokenizer_config_file = osp.join(model_folder, 'tokenizer_config.json')

        model_file_exists = osp.exists(model_file)
        config_exists = osp.exists(tokenizer_config_file)
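        # prefer the huggingface tokenizer whenever a tokenizer_config.json
        # is present or the raw sentencepiece model file is missing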
        use_hf_model = config_exists or not model_file_exists

        if not use_hf_model:
            self.model = SentencePieceTokenizer(model_file)
        else:
            self.model = HuggingFaceTokenizer(model_folder)

    @property
    def vocab_size(self):
        """vocabulary size."""
        return self.model.vocab_size

    @property
    def bos_token_id(self):
        """begine of the sentence token id."""
        return self.model.bos_token_id

    @property
    def eos_token_id(self):
        """end of the sentence token id."""
        return self.model.eos_token_id

    def encode(self, s: str):
        """Tokenize a prompt.

        Args:
            s (str): a prompt
        Returns:
            list[int]: token ids
        """
        return self.model.encode(s)

    def decode(self, t: Sequence[int], offset: Optional[int] = None):
        """De-tokenize.

        Args:
            t (List[int]): a list of token ids
            offset (int): for incremental decoding. Defaults to None, which
                means it is not applied.
        Returns:
            str: text of the decoded tokens
        """
        return self.model.decode(t, offset)

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.

        Args:
            s (str or Sequence[str]): prompts
        Returns:
            dict-like: tokenization result with an 'input_ids' field
        """
        return self.model(s)