# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Sequence, Union

import torch
from torch.nn.utils.rnn import pad_sequence


class Tokenizer:
    """Tokenize prompts or de-tokenize tokens into texts.

    Args:
        model_file (str): the path of the tokenizer model
    """

    def __init__(self, model_file: str):
        if model_file.endswith('.model'):
            model_folder = osp.split(model_file)[0]
        else:
            model_folder = model_file
            model_file = osp.join(model_folder, 'tokenizer.model')
        tokenizer_config_file = osp.join(model_folder, 'tokenizer_config.json')

        model_file_exists = osp.exists(model_file)
        config_exists = osp.exists(tokenizer_config_file)
        # prefer the huggingface tokenizer when its config is present, or
        # when there is no sentencepiece model to fall back on
        use_hf_model = config_exists or not model_file_exists

        self.use_hf_model = use_hf_model
        if not self.use_hf_model:
            from sentencepiece import SentencePieceProcessor
            self.model = SentencePieceProcessor(model_file=model_file)
            self.vocab_size = self.model.vocab_size()
            self.bos_token_id = self.model.bos_id()
            self.eos_token_id = self.model.eos_id()
        else:
            from transformers import AutoTokenizer
            backend_tokenizer_file = osp.join(model_folder, 'tokenizer.json')
            if not osp.exists(backend_tokenizer_file) and model_file_exists:
                print('WARNING: Can not find tokenizer.json. '
                      'It may take a long time to initialize the tokenizer.')
            self.model = AutoTokenizer.from_pretrained(model_folder)
            self.vocab_size = self.model.vocab_size
            self.bos_token_id = self.model.bos_token_id
            self.eos_token_id = self.model.eos_token_id
            # save tokenizer.json so the next initialization is fast
            if not osp.exists(backend_tokenizer_file) and model_file_exists:
                self.model.backend_tokenizer.save(backend_tokenizer_file)
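
    # Backend selection at a glance, as a hedged sketch (the paths below are
    # hypothetical examples, not files shipped with this module):
    #   Tokenizer('./llama/tokenizer.model')  # -> sentencepiece backend, if
    #       the '.model' file exists and tokenizer_config.json does not
    #   Tokenizer('./hf_model')               # -> huggingface AutoTokenizer,
    #       if './hf_model/tokenizer_config.json' exists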

    def encode(self, s: str):
        """Tokenize a prompt.

        Args:
            s (str): a prompt
        Returns:
            list[int]: token ids
        """
        if not self.use_hf_model:
            # sentencepiece: strip the '<BOS>'/'<EOS>' markers from the
            # prompt and let the processor add the real special tokens
            add_bos = False
            add_eos = False
            if s.find('<BOS>') != -1:
                s = s.replace('<BOS>', '')
                add_bos = True
            if s == '<EOS>':
                s = ''
                add_eos = True
            return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
        else:
            # huggingface: rewrite the markers to the tokenizer's literal
            # special tokens; an empty prompt falls back to
            # add_special_tokens=True
            add_special_tokens = False
            if s.find('<BOS>') != -1:
                s = s.replace('<BOS>', '<s>')
            if s == '<EOS>':
                s = '</s>'
            if len(s) == 0:
                add_special_tokens = True
            return self.model.encode(s, add_special_tokens=add_special_tokens)
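
    # Illustrative sketch of the sentinel handling above; the exact ids
    # depend on the underlying vocabulary, so treat this as an assumption:
    #   tokenizer.encode('<BOS>hi')  # -> [bos_token_id, *ids_of('hi')]
    #   tokenizer.encode('<EOS>')    # -> [eos_token_id]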

    def decode(self, t: Sequence[int]):
        """De-tokenize.

        Args:
            t (List[int]): a list of token ids
        Returns:
            str: text of decoding tokens
        """
        if not self.use_hf_model:
            return self.model.Decode(t)
        else:
            # keep special tokens in the output so callers can see them
            skip_special_tokens = False
            return self.model.decode(t,
                                     skip_special_tokens=skip_special_tokens)
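
    # Round-trip sketch (assumed behavior): decode(encode('<BOS>hello'))
    # should give back 'hello', with the bos token possibly rendered as text
    # since skip_special_tokens is False.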


class Preprocessor:
    """Tokenize prompts.

    Args:
        tokenizer (Tokenizer): an instance of tokenizer
    """

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

    def __call__(self, *args, **kwargs):
        return self.infer(*args, **kwargs)

    def infer(self, prompts: Union[str, Sequence[str]]) -> tuple:
        """Tokenize the input prompts.

        Args:
            prompts(str | Sequence[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(torch.Tensor, torch.Tensor): prompt's token
            ids, ids' length and requested output length
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, Sequence):
            assert 0, f'str or Sequence[str] prompts are expected but got ' \
                      f'{type(prompts)}'

        start_ids = [
            torch.IntTensor(self.tokenizer.encode(prompt))
            for prompt in prompts
        ]
        start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])
        start_ids = pad_sequence(start_ids,
                                 batch_first=True,
                                 padding_value=self.eos_token_id)
        return start_ids, start_lengths
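
    # Shape sketch with hypothetical lengths n1 < n2 for a two-prompt batch:
    #   ids, lengths = Preprocessor(tokenizer)(['hi', 'hello world'])
    #   ids      # IntTensor of shape (2, n2); row 0 right-padded with eos
    #   lengths  # IntTensor [[n1], [n2]]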


class Postprocessor:
    """De-tokenize token ids.

    Args:
        tokenizer (Tokenizer): an instance of tokenizer
    """

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

    def __call__(self, *args, **kwargs):
        return self.infer(*args, **kwargs)

    def infer(self, output_ids: torch.Tensor, seqlen: torch.Tensor):
        """De-tokenize tokens for text.

        Args:
            output_ids(torch.Tensor): tokens' id
            seqlen(torch.Tensor): sequence length

        Returns:
            str: decoded tokens
        """
        outputs = []
        for tokens, _len in zip(output_ids, seqlen):
            output = self.tokenizer.decode(tokens[:_len])
            outputs.append(output)
        return outputs
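

if __name__ == '__main__':
    # Minimal round-trip smoke test, assuming './model_folder' is a
    # hypothetical huggingface-style tokenizer directory; point it at a
    # real tokenizer before running.
    tokenizer = Tokenizer('./model_folder')
    preprocess = Preprocessor(tokenizer)
    postprocess = Postprocessor(tokenizer)
    ids, lengths = preprocess(['hello world', 'hi'])
    # decoding the freshly encoded ids should reproduce the prompts
    print(postprocess(ids, lengths.flatten()))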