tokenizer.py
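"""Standalone tokenization utility.

Example invocations via the `fire` CLI below (paths are illustrative, not
shipped defaults):

    # Encode a text file; the ids are printed and also written to
    # start_ids.csv next to this script.
    python tokenizer.py --model_file /path/to/tokenizer.model \
        --encode_file prompt.txt

    # Decode comma-separated token ids, one sequence per line.
    python tokenizer.py --model_file /path/to/tokenizer.model \
        --decode_file token_ids.csv
"""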
import os.path as osp
from typing import List, Optional

import fire


class Tokenizer:
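    """Tokenize text or de-tokenize token ids.

    Wraps either a SentencePiece model (a '.model' file) or a HuggingFace
    tokenizer behind one interface, e.g. (path is illustrative):

        tokenizer = Tokenizer('/path/to/tokenizer.model')
        ids = tokenizer.encode('hello world')
        text = tokenizer.decode(ids)
    """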

    def __init__(self, model_file: str):
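        """Load the tokenizer.

        A `model_file` ending in '.model' is loaded with SentencePiece;
        anything else is passed to `transformers.AutoTokenizer` (a local
        directory or hub id).
        """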
        if model_file.endswith('.model'):
            from sentencepiece import SentencePieceProcessor
            self.model = SentencePieceProcessor(model_file=model_file)
            self.vocab_size = self.model.vocab_size()
            self.start_id = self.model.bos_id()
            self.end_id = self.model.eos_id()
            self.pad_id = self.model.pad_id()
        else:
            from transformers import AutoTokenizer
            self.model = AutoTokenizer.from_pretrained(model_file,
                                                       trust_remote_code=True)
            self.vocab_size = self.model.vocab_size
            self.start_id = self.model.bos_token_id
            self.end_id = self.model.eos_token_id
            self.pad_id = self.model.pad_token_id

    def encode(self, s: str):
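        """Encode `s` to token ids, adding BOS/special tokens."""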
        if hasattr(self.model, 'Encode'):
            return self.model.Encode(s, add_bos=True)
        else:
            return self.model.encode(s, add_special_tokens=True)

    def decode(self, t: List[int]):
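        """Decode a list of token ids back into a string."""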
        if hasattr(self.model, 'Decode'):
            return self.model.Decode(t)
        else:
            return self.model.decode(t)


def main(model_file: str = '/data/llama/model/tokenizer.model',
         encode_file: Optional[str] = None,
         decode_file: Optional[str] = None,
         encode_line: Optional[str] = None):
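    """Encode/decode helper driven by `fire`.

    At most one of `encode_file`, `decode_file` and `encode_line` is used;
    with none given, an interactive loop encodes text lines and decodes
    lines of space-separated integers.
    """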
    tokenizer = Tokenizer(model_file)
    if encode_file:
        with open(encode_file, 'r') as f:
            xs = tokenizer.encode(f.read())
            xs = ','.join(map(str, xs))
            print(xs)

        # Also persist the ids beside this script as start_ids.csv.
        output_dir = osp.dirname(osp.abspath(__file__))
        with open(osp.join(output_dir, 'start_ids.csv'), 'w') as f:
            f.write(xs)

    elif decode_file:
        with open(decode_file, 'r') as f:
            token_ids = f.read()
            token_ids = token_ids.splitlines()
            for _token_ids in token_ids:
                _token_ids = _token_ids.split(',')
                _token_ids = [int(token_id) for token_id in _token_ids]
                ys = tokenizer.decode(_token_ids)
                print(ys)
    elif encode_line:
        xs = tokenizer.encode(encode_line)
        xs = ','.join(map(str, xs))
        print(xs)
        output_dir = osp.dirname(osp.abspath(__file__))
        with open(osp.join(output_dir, 'start_ids.csv'), 'w') as f:
            f.write(xs)
    else:
        first = True
        while True:
            try:
                s = input()
            except EOFError:
                break
            if not first:
                print('---------------------------------------------')
            first = False
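            # Lines of space-separated integers are decoded back to text;
            # anything else is treated as text and encoded.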
            try:
                xs = map(int, s.strip().split(' '))
                s = tokenizer.decode(list(xs))
                print(s)
            except ValueError:
                xs = tokenizer.encode(s)
                print(' '.join(map(str, xs)))


if __name__ == '__main__':
    fire.Fire(main)