from typing import List
import fire


class Tokenizer:
    """Thin wrapper presenting one interface over two tokenizer backends.

    A path ending in ``.model`` is loaded as a sentencepiece model via
    ``SentencePieceProcessor``; any other path is treated as a HuggingFace
    tokenizer location and loaded with ``AutoTokenizer.from_pretrained``.
    The relevant backend is imported lazily so only the one actually used
    needs to be installed.
    """

    def __init__(self, model_file: str):
        """Load the tokenizer and cache vocab size and special-token ids.

        Args:
            model_file: path to a sentencepiece ``.model`` file, or a
                HuggingFace tokenizer directory / hub id.
        """
        if model_file.endswith('.model'):
            from sentencepiece import SentencePieceProcessor
            self.model = SentencePieceProcessor(model_file=model_file)
            # sentencepiece exposes these as methods ...
            self.vocab_size = self.model.vocab_size()
            self.start_id = self.model.bos_id()
            self.end_id = self.model.eos_id()
            self.pad_id = self.model.pad_id()
        else:
            from transformers import AutoTokenizer
            self.model = AutoTokenizer.from_pretrained(model_file)
            # ... while HuggingFace tokenizers expose them as attributes.
            self.vocab_size = self.model.vocab_size
            self.start_id = self.model.bos_token_id
            self.end_id = self.model.eos_token_id
            self.pad_id = self.model.pad_token_id

        print(f'vocab_size = {self.vocab_size}')
        print(f'start_id = {self.start_id}')
        print(f'end_id = {self.end_id}')
        print(f'pad_id = {self.pad_id}')

    def encode(self, s: str) -> List[int]:
        """Tokenize ``s`` into a list of ids, including special tokens.

        Dispatches on the backend: sentencepiece models have an ``Encode``
        method (BOS prepended), HuggingFace tokenizers have ``encode``.
        """
        if hasattr(self.model, 'Encode'):
            return self.model.Encode(s, add_bos=True)
        else:
            return self.model.encode(s, add_special_tokens=True)

    def decode(self, t: List[int]) -> str:
        """Convert a list of token ids back into text."""
        if hasattr(self.model, 'Decode'):
            return self.model.Decode(t)
        else:
            return self.model.decode(t)


def main(model_file: str = '/data/llama/model/tokenizer.model',
         encode_file: str = None,
         decode_file: str = None):
    """Encode a file, decode a file of ids, or run an interactive REPL.

    Args:
        model_file: tokenizer model passed to :class:`Tokenizer`.
        encode_file: if given, encode its text and print comma-separated ids.
        decode_file: if given, read comma-separated ids and print the text.

    With neither file argument, reads lines from stdin: a line of
    space-separated integers is decoded; any other line is encoded.
    """
    tokenizer = Tokenizer(model_file)
    if encode_file:
        with open(encode_file, 'r') as f:
            xs = tokenizer.encode(f.read())
            print(','.join(map(str, xs)))
    elif decode_file:
        with open(decode_file, 'r') as f:
            # The encode path above writes comma-separated ids; parse them
            # back into ints since decode() expects List[int], not str.
            ids = [int(x) for x in f.read().strip().split(',') if x]
            print(tokenizer.decode(ids))
    else:
        first = True
        while True:
            try:
                s = input()
            except EOFError:
                break
            if not first:
                print('---------------------------------------------')
            first = False
            try:
                # A line of space-separated ints is treated as ids to decode;
                # int() raising ValueError means it is text to encode.
                xs = map(int, s.strip().split(' '))
                s = tokenizer.decode(list(xs))
                print(s)
            except ValueError:
                xs = tokenizer.encode(s)
                print(' '.join(map(str, xs)))


if __name__ == '__main__':
    # Expose main()'s keyword arguments as CLI flags via python-fire.
    fire.Fire(main)