# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Sequence, Union

import torch
from torch.nn.utils.rnn import pad_sequence


class Tokenizer:
    """Tokenize prompts or de-tokenize generated token ids into texts.

    Args:
        model_file (str): path to the tokenizer model, either a
            sentencepiece ``.model`` file or a huggingface model folder
    """

    def __init__(self, model_file: str):
        if model_file.endswith('.model'):
            model_folder = osp.split(model_file)[0]
        else:
            # a model folder was given; look for the sentencepiece model
            # at <folder>/tokenizer.model
            model_folder = model_file
            model_file = osp.join(model_folder, 'tokenizer.model')
        tokenizer_config_file = osp.join(model_folder, 'tokenizer_config.json')

        model_file_exists = osp.exists(model_file)
        config_exists = osp.exists(tokenizer_config_file)
        # prefer the huggingface tokenizer when a tokenizer config is
        # present or when the sentencepiece model file is missing
        use_hf_model = config_exists or not model_file_exists

        self.use_hf_model = use_hf_model
        if not self.use_hf_model:
            from sentencepiece import SentencePieceProcessor
            self.model = SentencePieceProcessor(model_file=model_file)
            self.vocab_size = self.model.vocab_size()
            self.bos_token_id = self.model.bos_id()
            self.eos_token_id = self.model.eos_id()
        else:
            from transformers import AutoTokenizer
            backend_tokenizer_file = osp.join(model_folder, 'tokenizer.json')
            if not osp.exists(backend_tokenizer_file) and model_file_exists:
                print('WARNING: Cannot find tokenizer.json. '
                      'It may take a long time to initialize the tokenizer.')
            self.model = AutoTokenizer.from_pretrained(model_folder)
            self.vocab_size = self.model.vocab_size
            self.bos_token_id = self.model.bos_token_id
            self.eos_token_id = self.model.eos_token_id
            # save tokenizer.json to reuse
            if not osp.exists(backend_tokenizer_file) and model_file_exists:
                self.model.backend_tokenizer.save(backend_tokenizer_file)

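    # NOTE on the backend selection above (paths are illustrative):
    #   <folder>/tokenizer.model exists, no tokenizer_config.json
    #       -> sentencepiece backend
    #   tokenizer_config.json exists, or tokenizer.model is missing
    #       -> huggingface AutoTokenizer backend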
    def encode(self, s: str):
        """Tokenize a prompt.

        The sentinels ``<BOS>`` and ``<EOS>`` in ``s`` are mapped to the
        tokenizer's bos/eos special tokens.

        Args:
            s (str): a prompt
        Returns:
            list[int]: token ids
        """
        if not self.use_hf_model:
            add_bos = False
            add_eos = False
            if s.find('<BOS>') != -1:
                s = s.replace('<BOS>', '')
                add_bos = True
            if s == '<EOS>':
                s = ''
                add_eos = True
            return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
        else:
            add_special_tokens = False
            if s.find('<BOS>') != -1:
                s = s.replace('<BOS>', '<s>')
            if s == '<EOS>':
                s = '</s>'
            if len(s) == 0:
                add_special_tokens = True
            return self.model.encode(s, add_special_tokens=add_special_tokens)

    def decode(self, t: Sequence[int]):
        """De-tokenize token ids into a text.

        Args:
            t (Sequence[int]): token ids
        Returns:
            str: the decoded text
        """
        if not self.use_hf_model:
            return self.model.Decode(t)
        else:
            skip_special_tokens = False
            return self.model.decode(t,
                                     skip_special_tokens=skip_special_tokens)

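# A minimal usage sketch of the sentinel convention handled by
# `Tokenizer.encode`; the model path is a placeholder, not a file shipped
# with this module:
#
#   tokenizer = Tokenizer('/path/to/model/tokenizer.model')
#   ids = tokenizer.encode('<BOS>Hello, world')  # bos id + token ids
#   text = tokenizer.decode(ids)
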
class Preprocessor:
    """Tokenize raw prompts into a padded tensor of token ids.

    Args:
        tokenizer (Tokenizer): the tokenizer used to encode prompts
    """

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

    def __call__(self, *args, **kwargs):
        return self.infer(*args, **kwargs)

    def infer(self, prompts: Union[str, Sequence[str]]) -> tuple:
        """Tokenize the input prompts.

        Args:
            prompts (str | Sequence[str]): a single prompt or a batch of
                prompts

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: the prompts' token ids,
            right-padded into one tensor, and the length of each id
            sequence
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif isinstance(prompts, Sequence):
            prompts = list(prompts)
        else:
            assert 0, f'str or Sequence[str] prompts are expected but got ' \
                      f'{type(prompts)}'

        # encode each prompt and record its length, then right-pad the
        # batch with eos so the ids form one rectangular tensor
        start_ids = [
            torch.IntTensor(self.tokenizer.encode(prompt))
            for prompt in prompts
        ]
        start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])
        start_ids = pad_sequence(start_ids,
                                 batch_first=True,
                                 padding_value=self.eos_token_id)
        return start_ids, start_lengths

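# Shape sketch for `Preprocessor.infer`, assuming a hypothetical tokenizer
# whose two encodings have lengths 3 and 5:
#
#   preprocess = Preprocessor(tokenizer)
#   ids, lengths = preprocess(['short prompt', 'a longer prompt'])
#   # ids.shape -> (2, 5), right-padded with eos_token_id
#   # lengths   -> tensor([[3], [5]], dtype=torch.int32)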

class Postprocessor:
    """De-tokenize model outputs back into texts.

    Args:
        tokenizer (Tokenizer): the tokenizer used to decode token ids
    """

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id

    def __call__(self, *args, **kwargs):
        return self.infer(*args, **kwargs)

    def infer(self, output_ids: torch.Tensor, seqlen: torch.Tensor):
        """De-tokenize token ids into texts.

        Args:
            output_ids (torch.Tensor): token ids of each sequence in the
                batch
            seqlen (torch.Tensor): the valid length of each sequence

        Returns:
            List[str]: the decoded texts
        """
        outputs = []
        for tokens, _len in zip(output_ids, seqlen):
            output = self.tokenizer.decode(tokens[:_len])
            outputs.append(output)
        return outputs
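

# A minimal end-to-end sketch tying the three classes together; the model
# path is a placeholder and nothing below runs on import.
if __name__ == '__main__':
    tokenizer = Tokenizer('/path/to/model')  # placeholder path
    preprocess = Preprocessor(tokenizer)
    postprocess = Postprocessor(tokenizer)

    ids, lengths = preprocess(['hello world', 'hi'])
    # a model forward pass would go here; decoding the inputs back is a
    # simple round-trip sanity check
    texts = postprocess(ids, lengths.flatten())
    print(texts)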