translation_train.py 4.31 KB
Newer Older
Zihao Ye's avatar
Zihao Ye committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
In the current version, Multi30k is used as the default training and validation set.
Training the model on WMT14 requires multi-GPU support.
"""
from modules import *
from parallel import *
from loss import * 
from optims import *
from dataset import *
from modules.config import *
from modules.viz import *
from tqdm import tqdm
import numpy as np
import argparse

def run_epoch(data_iter, model, loss_compute, is_train=True):
    """Run one pass over ``data_iter`` and print the epoch's statistics.

    Args:
        data_iter: Iterable of batches. In the single-device case each batch is
            a graph exposing ``tgt_y`` and ``n_tokens``. When ``model`` is a
            list of replicas, each batch is presumably a list holding one such
            graph per device (TODO confirm against the dataset iterator).
        model: A model, or a list of per-device replicas for multi-device runs.
        loss_compute: Callable invoked as ``loss_compute(output, tgt_y, n_tokens)``
            for every batch. Its ``avg_loss`` and ``accuracy`` attributes are
            printed at the end of the epoch.
        is_train: Whether autograd is enabled for the forward passes. It also
            controls the ACT-loss backward pass for a ``UTransformer``.
    """
    universal = isinstance(model, UTransformer)
    for g in tqdm(data_iter):
        with T.set_grad_enabled(is_train):
            if isinstance(model, list):
                gs = g
                # Use only as many replicas as there are graphs in this batch,
                # without shrinking ``model`` for the batches that follow.
                replicas = model[:len(gs)]
                output = parallel_apply(replicas, gs)
                tgt_y = [graph.tgt_y for graph in gs]
                n_tokens = [graph.n_tokens for graph in gs]
            else:
                if universal:
                    output, loss_act = model(g)
                    if is_train:
                        # retain_graph keeps the forward graph alive for the
                        # backward pass presumably performed inside loss_compute.
                        loss_act.backward(retain_graph=True)
                else:
                    output = model(g)
                tgt_y = g.tgt_y
                n_tokens = g.n_tokens
            # Called for its side effects; the returned loss is not needed here.
            loss_compute(output, tgt_y, n_tokens)

    if universal:
        total = model.stat[0]
        for step in range(1, model.MAX_DEPTH + 1):
            # Guard against an empty epoch (no nodes counted at step 0).
            ratio = 1.0 * model.stat[step] / total if total else 0.0
            # Scale to a real percentage: the label below ends in '%'.
            print("nodes entering step {}: {:.2f}%".format(step, 100.0 * ratio))
        model.reset_stat()
    print('average loss: {}'.format(loss_compute.avg_loss))
    print('accuracy: {}'.format(loss_compute.accuracy))

if __name__ == '__main__':
    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')
    np.random.seed(1111)
    argparser = argparse.ArgumentParser('training translation model')
    argparser.add_argument('--gpus', default='-1', type=str, help='gpu id')
    argparser.add_argument('--N', default=6, type=int, help='enc/dec layers')
    argparser.add_argument('--dataset', default='multi30k', help='dataset')
    argparser.add_argument('--batch', default=128, type=int, help='batch size')
    argparser.add_argument('--viz', action='store_true', help='visualize attention')
    argparser.add_argument('--universal', action='store_true', help='use universal transformer')
    args = argparser.parse_args()
    args_filter = ['batch', 'gpus', 'viz']
    exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter)
    devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')]

    dataset = get_dataset(args.dataset)

    V = dataset.vocab_size
    criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
    dim_model = 512

    graph_pool = GraphPool()
    model = make_model(V, V, N=args.N, dim_model=dim_model, universal=args.universal)

    # Sharing weights between Encoder & Decoder
    model.src_embed.lut.weight = model.tgt_embed.lut.weight
    model.generator.proj.weight = model.tgt_embed.lut.weight

    model, criterion = model.to(devices[0]), criterion.to(devices[0])
    model_opt = NoamOpt(dim_model, 1, 400,
                        T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
    if len(devices) > 1:
        model, criterion = map(nn.parallel.replicate, [model, criterion], [devices, devices])
    loss_compute = SimpleLossCompute if len(devices) == 1 else MultiGPULossCompute

    for epoch in range(100):
        train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, devices=devices)
        valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, devices=devices)
        print('Epoch: {} Training...'.format(epoch))
        model.train(True)
        run_epoch(train_iter, model,
                      loss_compute(criterion, model_opt), is_train=True)
        print('Epoch: {} Evaluating...'.format(epoch))
        model.att_weight_map = None
        model.eval()
        run_epoch(valid_iter, model,
                      loss_compute(criterion, None), is_train=False)
        # Visualize attention
        if args.viz:
            src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
            tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
            draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch))

        print('----------------------------------')
        with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f:
            th.save(model.state_dict(), f)