fields.py 1.89 KB
Newer Older
Zihao Ye's avatar
Zihao Ye committed
1
class Vocab:
    """Bidirectional token <-> index mapping loaded from a one-token-per-line file.

    Special tokens that are not ``None`` occupy the first indices, in the order
    ``init_token``, ``eos_token``, ``pad_token``, ``unk_token``; the tokens read
    from the file follow in file order.
    """

    def __init__(
        self, init_token=None, eos_token=None, pad_token=None, unk_token=None
    ):
        self.init_token = init_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        # Index -> token; populated by ``load``.
        self.vocab_lst = []
        # Token -> index; stays ``None`` until ``load`` has been called.
        self.vocab_dict = None

    def load(self, path):
        """Populate the vocabulary from the UTF-8 text file at *path*.

        Each line, with surrounding whitespace stripped, becomes one token.
        Previously loaded contents are replaced rather than appended to, so
        calling ``load`` again does not duplicate the special tokens. If the
        file cannot be opened, the vocabulary is left unchanged.

        If a token occurs more than once, string lookups resolve to the index
        of its last occurrence.

        Raises:
            OSError: if *path* cannot be opened.
        """
        specials = (
            self.init_token,
            self.eos_token,
            self.pad_token,
            self.unk_token,
        )
        tokens = [tok for tok in specials if tok is not None]
        with open(path, "r", encoding="utf-8") as f:
            # Stream lines from the file object instead of building readlines().
            tokens.extend(line.strip() for line in f)
        # Commit only after the file was read successfully.
        self.vocab_lst = tokens
        self.vocab_dict = {token: idx for idx, token in enumerate(tokens)}

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.vocab_lst)

    def __getitem__(self, key):
        """Map a token to its index, or an index (or slice) to token(s).

        An unknown string token maps to the index of ``unk_token``. When no
        ``unk_token` was configured, a ``KeyError`` naming the missing token
        is raised. Non-string keys are forwarded to ``vocab_lst``.
        """
        if isinstance(key, str):
            if key in self.vocab_dict:
                return self.vocab_dict[key]
            if self.unk_token is None:
                # Report the token that was asked for, not ``None``.
                raise KeyError(key)
            return self.vocab_dict[self.unk_token]
        return self.vocab_lst[key]
class Field:
    """Pipeline that turns raw input into vocabulary indices.

    Calling the field applies, in order: the optional ``preprocessing``
    callable, a per-token lookup in ``vocab``, and the optional
    ``postprocessing`` callable. A hook left as ``None`` passes its input
    through unchanged.
    """

    def __init__(self, vocab, preprocessing=None, postprocessing=None):
        self.vocab = vocab
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing

    def preprocess(self, x):
        """Return ``preprocessing(x)``, or *x* itself when no hook is set."""
        hook = self.preprocessing
        return x if hook is None else hook(x)

    def postprocess(self, x):
        """Return ``postprocessing(x)``, or *x* itself when no hook is set."""
        hook = self.postprocessing
        return x if hook is None else hook(x)

    def numericalize(self, x):
        """Return a list holding ``vocab[tok]`` for every token in *x*."""
        lookup = self.vocab
        return [lookup[tok] for tok in x]

    def __call__(self, x):
        """Preprocess *x*, map its tokens to indices, then postprocess."""
        staged = self.preprocess(x)
        ids = self.numericalize(staged)
        return self.postprocess(ids)