# tokenizer.py
from typing import List
import fire


class Tokenizer:
    """Thin wrapper that unifies sentencepiece and HuggingFace tokenizers.

    A path ending in ``.model`` is loaded with sentencepiece; any other
    value is handed to ``transformers.AutoTokenizer.from_pretrained``.
    """

    def __init__(self, model_file: str):
        """Load the backend tokenizer and cache its special-token ids.

        Args:
            model_file: path to a sentencepiece ``.model`` file, or a
                HuggingFace model directory / hub name.
        """
        if model_file.endswith('.model'):
            from sentencepiece import SentencePieceProcessor
            backend = SentencePieceProcessor(model_file=model_file)
            self.model = backend
            self.vocab_size = backend.vocab_size()
            self.start_id = backend.bos_id()
            self.end_id = backend.eos_id()
            self.pad_id = backend.pad_id()
        else:
            from transformers import AutoTokenizer
            backend = AutoTokenizer.from_pretrained(model_file)
            self.model = backend
            self.vocab_size = backend.vocab_size
            self.start_id = backend.bos_token_id
            self.end_id = backend.eos_token_id
            self.pad_id = backend.pad_token_id

        print(f'vocab_size = {self.vocab_size}')
        print(f'start_id = {self.start_id}')
        print(f'end_id = {self.end_id}')
        print(f'pad_id = {self.pad_id}')

    def encode(self, s: str):
        """Tokenize ``s`` into a list of token ids (BOS/special tokens added)."""
        # sentencepiece exposes `Encode`; HuggingFace tokenizers do not.
        if not hasattr(self.model, 'Encode'):
            return self.model.encode(s, add_special_tokens=True)
        return self.model.Encode(s, add_bos=True)

    def decode(self, t: List[int]):
        """Convert a list of token ids back into text."""
        # sentencepiece exposes `Decode`; HuggingFace tokenizers do not.
        if not hasattr(self.model, 'Decode'):
            return self.model.decode(t)
        return self.model.Decode(t)


def main(model_file: str = '/data/llama/model/tokenizer.model',
         encode_file: str = None,
         decode_file: str = None):
    """Encode/decode text with the tokenizer at ``model_file``.

    Args:
        model_file: sentencepiece ``.model`` path or HuggingFace model name.
        encode_file: if given, read this text file, print its token ids
            as a comma-separated list, and exit.
        decode_file: if given, read comma-separated token ids (the format
            produced by ``encode_file``) from this file, print the decoded
            text, and exit.

    With neither file given, runs an interactive loop on stdin: a line of
    space-separated integers is decoded to text, any other line is encoded
    to space-separated ids.
    """
    tokenizer = Tokenizer(model_file)
    if encode_file:
        with open(encode_file, 'r') as f:
            xs = tokenizer.encode(f.read())
            print(','.join(map(str, xs)))
    elif decode_file:
        with open(decode_file, 'r') as f:
            # BUG FIX: `decode` expects a list of ints, but the raw file
            # contents (a string) were passed in. The file holds the
            # comma-separated ids written by the `encode_file` branch, so
            # parse them back into ints before decoding.
            ids = [int(tok) for tok in f.read().strip().split(',') if tok.strip()]
            print(tokenizer.decode(ids))
    else:
        first = True
        while True:
            try:
                s = input()
            except EOFError:
                break
            # Separate successive results with a divider (skip before the
            # first one).
            if not first:
                print('---------------------------------------------')
            first = False
            try:
                # A line of space-separated integers -> decode to text;
                # int() raising ValueError means it is ordinary text.
                xs = map(int, s.strip().split(' '))
                s = tokenizer.decode(list(xs))
                print(s)
            except ValueError:
                xs = tokenizer.encode(s)
                print(' '.join(map(str, xs)))


# Script entry point: expose `main`'s parameters as CLI flags via Fire.
if __name__ == '__main__':
    fire.Fire(main)