spm_encode.py 3.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import contextlib
import sys

import sentencepiece as spm


def main(argv=None):
    """Encode one or more parallel text files with a SentencePiece model.

    The input files are read in lockstep, one line from each per step. Every
    line is stripped and encoded either as pieces or as ids (rendered as
    strings). A group of lines is written to the outputs only if none of its
    lines is empty and every encoded line satisfies the optional
    ``--min-len`` / ``--max-len`` token-count bounds. Progress and summary
    counters are printed to stderr.

    Args:
        argv: Optional list of command-line arguments. When ``None``,
            argparse falls back to ``sys.argv[1:]``.

    Raises:
        SystemExit: On invalid command-line arguments, including a mismatch
            between the number of input and output paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True,
                        help="sentencepiece model to use for encoding")
    parser.add_argument("--inputs", nargs="+", default=['-'],
                        help="input files to filter/encode")
    parser.add_argument("--outputs", nargs="+", default=['-'],
                        help="path to save encoded outputs")
    parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
    parser.add_argument("--min-len", type=int, metavar="N",
                        help="filter sentence pairs with fewer than N tokens")
    parser.add_argument("--max-len", type=int, metavar="N",
                        help="filter sentence pairs with more than N tokens")
    args = parser.parse_args(argv)

    # Report this as a usage error rather than asserting: assert statements
    # are stripped when Python runs with -O, which would skip the check.
    if len(args.inputs) != len(args.outputs):
        parser.error("number of input and output paths should match")

    sp = spm.SentencePieceProcessor()
    sp.Load(args.model)

    if args.output_format == "piece":
        def encode(text):
            return sp.EncodeAsPieces(text)
    elif args.output_format == "id":
        def encode(text):
            # Ids are converted to strings so they can be joined with spaces.
            return list(map(str, sp.EncodeAsIds(text)))
    else:
        # argparse's `choices` should make this unreachable; kept as a guard.
        raise NotImplementedError

    def valid(tokens):
        # A bound left unset (None) imposes no constraint.
        return (
            (args.min_len is None or len(tokens) >= args.min_len)
            and (args.max_len is None or len(tokens) <= args.max_len)
        )

    with contextlib.ExitStack() as stack:
        # "-" selects the standard stream; real files are registered with the
        # ExitStack so they are closed on exit, while stdin/stdout stay open.
        inputs = [
            sys.stdin if path == "-"
            else stack.enter_context(open(path, "r", encoding="utf-8"))
            for path in args.inputs
        ]
        outputs = [
            sys.stdout if path == "-"
            else stack.enter_context(open(path, "w", encoding="utf-8"))
            for path in args.outputs
        ]

        stats = {
            "num_empty": 0,
            "num_filtered": 0,
        }

        def encode_line(line):
            """Return the encoded tokens, or None if the line is skipped."""
            line = line.strip()
            if not line:
                stats["num_empty"] += 1
                return None
            tokens = encode(line)
            if not valid(tokens):
                stats["num_filtered"] += 1
                return None
            return tokens

        # NOTE: zip() stops at the shortest input, so trailing lines of any
        # longer input file are not processed.
        for i, lines in enumerate(zip(*inputs), start=1):
            enc_lines = list(map(encode_line, lines))
            # Drop the whole group if any one of its lines was rejected, so
            # the outputs stay aligned line by line.
            if not any(enc_line is None for enc_line in enc_lines):
                for enc_line, output_h in zip(enc_lines, outputs):
                    print(" ".join(enc_line), file=output_h)
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)

        print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
        print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)


# Run the encoder only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()