# Beam Search Module
import argparse
import os

import numpy as np
from tqdm import tqdm

from dataset import *
from modules import *
# Beam width for beam-search decoding; also forwarded to the test batch
# iterator (k=k) so each sentence carries k candidate hypotheses.
k = 5
if __name__ == "__main__":
    # Evaluate a trained translation model on the test split, write the
    # predictions and references to disk, and score them with the BLEU script.
    argparser = argparse.ArgumentParser("testing translation model")
    # type=int so an explicit "--gpu -1" parses to the int -1 and the
    # CPU check below still matches (argparse would otherwise keep "-1").
    argparser.add_argument("--gpu", default=-1, type=int, help="gpu id")
    argparser.add_argument("--N", default=6, type=int, help="num of layers")
    argparser.add_argument("--dataset", default="multi30k", help="dataset")
    # type=int so a user-supplied batch size is not passed on as a string.
    argparser.add_argument("--batch", default=64, type=int, help="batch size")
    argparser.add_argument(
        "--universal", action="store_true", help="use universal transformer"
    )
    argparser.add_argument(
        "--checkpoint", type=int, help="checkpoint: you must specify it"
    )
    argparser.add_argument(
        "--print", action="store_true", help="whether to print translated text"
    )
    args = argparser.parse_args()

    # Rebuild the checkpoint-name suffix from every CLI option except the
    # ones that do not identify a training run (batch/gpu/print).
    # The loop variables are named arg_name/arg_value so they cannot be
    # confused with the module-level beam size `k`.
    args_filter = ["batch", "gpu", "print"]
    exp_setting = "-".join(
        "{}".format(arg_value)
        for arg_name, arg_value in vars(args).items()
        if arg_name not in args_filter
    )
    device = "cpu" if args.gpu == -1 else "cuda:{}".format(args.gpu)

    dataset = get_dataset(args.dataset)
    V = dataset.vocab_size
    dim_model = 512

    graph_pool = GraphPool()
    model = make_model(V, V, N=args.N, dim_model=dim_model)

    # Load weights onto CPU first; the map_location lambda keeps tensors on
    # CPU regardless of the device they were saved from, then .to(device)
    # moves the whole model in one step.
    with open("checkpoints/{}.pkl".format(exp_setting), "rb") as f:
        model.load_state_dict(
            th.load(f, map_location=lambda storage, loc: storage)
        )
    model = model.to(device)
    model.eval()

    test_iter = dataset(
        graph_pool, mode="test", batch_size=args.batch, device=device, k=k
    )

    # `with` guarantees both files are flushed and closed (even on error)
    # before bleu.sh reads them; the originals leaked handles on exceptions.
    with open("pred.txt", "w") as fpred, open("ref.txt", "w") as fref:
        for g in test_iter:
            with th.no_grad():
                # Beam-search decode: beam width k, length penalty alpha=0.6.
                output = model.infer(
                    g, dataset.MAX_LENGTH, dataset.eos_id, k, alpha=0.6
                )
            for line in dataset.get_sequence(output):
                if args.print:
                    print(line)
                print(line, file=fpred)
        # Reference translations are written once, after all predictions.
        for line in dataset.tgt["test"]:
            print(line.strip(), file=fref)

    # NOTE(review): os.system with a fixed command string; arguments are
    # constants, so no injection risk here.
    os.system(r"bash scripts/bleu.sh pred.txt ref.txt")
    os.remove("pred.txt")
    os.remove("ref.txt")