word_splitter.py 2.28 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


class SubwordSplitter:
    """Abstract interface for sub-word splitters.

    Both methods are placeholders; a subclass must override them. Calling
    the base implementations raises ``NotImplementedError``.
    """

    def process_line(self, string):
        """Apply sub-word processing to one line of text (override me)."""
        raise NotImplementedError

    def split(self, string):
        """Split ``string`` into sub-word tokens (override me)."""
        raise NotImplementedError


class NoneWordSplitter(object):
    """Identity splitter used when no sub-word segmentation is configured.

    The whole input string is kept as a single token, so every token is
    treated as an already-complete word.
    """

    def __init__(self, model):
        # ``model`` is accepted only so this class is constructed with one
        # positional argument like the other splitters; it is not used.
        del model

    def split(self, string):
        """Return ``string`` unchanged as a one-element token list."""
        tokens = [string]
        return tokens

    def process_line(self, string):
        """Return ``string`` unchanged as a one-element token list."""
        tokens = [string]
        return tokens

    def finished_word(self, string):
        """Every token is a finished word, so this is always ``True``."""
        return True

    def merge(self, list_of_string):
        """Concatenate the tokens with no separator between them."""
        separator = ""
        return separator.join(list_of_string)

    def last_full_word_step(self, tokens, step):
        """Return ``len(tokens)``; ``step`` is ignored."""
        return len(tokens)

    def end_idx_last_full_word(self, tokens):
        """Return ``len(tokens)``: every token already is a full word."""
        return len(tokens)


class BPEWordSplitter(object):
    """Word splitter backed by a subword-nmt BPE codes file.

    BPE tokens that are continued by the following token carry a trailing
    ``"@@"`` marker; a token without that marker ends a word.
    """

    def __init__(self, model_path):
        super().__init__()
        from subword_nmt.apply_bpe import BPE

        # Read the codes file as UTF-8 explicitly so behavior does not
        # depend on the platform's default locale encoding.
        with open(model_path, encoding="utf-8") as f:
            self.model = BPE(f)

    def split(self, string):
        """Segment ``string`` with BPE and return the list of tokens."""
        return self.model.process_line(string).split()

    def end_idx_last_full_word(self, tokens):
        """Return the end index (exclusive) of the last complete word.

        ``tokens[:idx]`` contains only complete words. A word is complete
        once a token without the ``"@@"`` continuation marker is seen, so
        the result is one past the last such token, or 0 if every token
        (or none at all) is still a continuation.
        """
        # Scan backwards: the first non-continuation token found closes the
        # last complete word.
        for idx in range(len(tokens) - 1, -1, -1):
            if not tokens[idx].endswith("@@"):
                return idx + 1
        return 0

    def merge(self, list_of_string):
        """Detokenize BPE tokens back into plain text.

        Each ``"@@"`` marker is removed together with the space after it so
        that a sub-word is glued to its successor; this mirrors subword-nmt's
        ``sed -r 's/(@@ )|(@@ ?$)//g'`` post-processing.
        """
        text = " ".join(list_of_string).replace("@@ ", "")
        if text.endswith("@@"):
            # A trailing marker belongs to an unfinished word; drop it.
            text = text[:-2]
        return text


class SentencePieceModelWordSplitter(object):
    """Word splitter backed by a SentencePiece model file.

    A piece whose first character is U+2581 is treated as the beginning of
    a word.
    """

    def __init__(self, model_path):
        super().__init__()
        import sentencepiece as spm

        processor = spm.SentencePieceProcessor()
        processor.Load(model_path)
        self.model = processor

    def split(self, string):
        """Encode ``string`` into a list of SentencePiece pieces."""
        pieces = self.model.EncodeAsPieces(string)
        return pieces

    def end_idx_last_full_word(self, tokens):
        """Return the end index (exclusive) of the last word known complete.

        The word that begins at the last boundary may still receive more
        pieces, so only the pieces before it are reported. With fewer than
        two word beginnings, 0 is returned.
        """
        word_starts = [pos for pos, piece in enumerate(tokens) if piece[0] == "\u2581"]
        return word_starts[-1] if len(word_starts) >= 2 else 0

    def merge(self, list_of_string):
        """Decode a list of pieces back into detokenized text."""
        return self.model.DecodePieces(list_of_string)


# Registry mapping a splitter type name to the class implementing it;
# the ``None`` key selects the no-op splitter.
SPLITTER_DICT = dict(
    [
        (None, NoneWordSplitter),
        ("BPE", BPEWordSplitter),
        ("SentencePieceModel", SentencePieceModelWordSplitter),
    ]
)