import argparse
import os
import time
from functools import partial

import numpy as np
import torch
import torch.distributed as dist
from dataset import *
from loss import *
from modules import *
from modules.config import *
from optims import *


def run_epoch(
    epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True
):
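    """Run one epoch of training (or evaluation) over ``data_iter``."""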
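    # The Universal Transformer variant additionally returns an ACT halting loss.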
    universal = isinstance(model, UTransformer)
    with loss_compute:
        for i, g in enumerate(data_iter):
            with torch.set_grad_enabled(is_train):
                if universal:
                    output, loss_act = model(g)
                    if is_train:
                        loss_act.backward(retain_graph=True)
                else:
                    output = model(g)
                tgt_y = g.tgt_y
                n_tokens = g.n_tokens
                loss = loss_compute(output, tgt_y, n_tokens)

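    # Report how many nodes were still active at each ACT halting step.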
    if universal:
        for step in range(1, model.MAX_DEPTH + 1):
            print(
                "nodes entering step {}: {:.2f}%".format(
                    step, (100.0 * model.stat[step] / model.stat[0])
                )
            )
        model.reset_stat()
    print(
        "Epoch {} {}: Dev {} average loss: {}, accuracy {}".format(
            epoch,
            "Training" if is_train else "Evaluating",
            dev_rank,
            loss_compute.avg_loss,
            loss_compute.accuracy,
        )
    )


def run(dev_id, args):
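    """Per-process entry point: join the NCCL process group, then run main()."""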
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip=args.master_ip, master_port=args.master_port
    )
    world_size = args.ngpu
    dist.init_process_group(
        backend="nccl",
        init_method=dist_init_method,
        world_size=world_size,
        rank=dev_id,
    )
    gpu_rank = dist.get_rank()
    assert gpu_rank == dev_id
    main(dev_id, args)


def main(dev_id, args):
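    """Build dataset, model, and optimizer on one device; train and evaluate."""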
    if dev_id == -1:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:{}".format(dev_id))
        # Set the default CUDA device for this process
        torch.cuda.set_device(device)
    # Prepare dataset
    dataset = get_dataset(args.dataset)
    V = dataset.vocab_size
    criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
    dim_model = 512
    # Build graph pool
    graph_pool = GraphPool()
    # Create model
    model = make_model(
        V, V, N=args.N, dim_model=dim_model, universal=args.universal
    )
    # Share embedding weights between encoder, decoder, and the generator
    model.src_embed.lut.weight = model.tgt_embed.lut.weight
    model.generator.proj.weight = model.tgt_embed.lut.weight
    # Move model to corresponding device
    model, criterion = model.to(device), criterion.to(device)
    # Loss computation: multi-GPU vs. single-device
    if args.ngpu > 1:
        dev_rank = dev_id  # current device id
        ndev = args.ngpu  # total number of devices
        loss_compute = partial(
            MultiGPULossCompute, criterion, args.ngpu, args.grad_accum, model
        )
    else:  # cpu or single gpu case
        dev_rank = 0
        ndev = 1
        loss_compute = partial(SimpleLossCompute, criterion, args.grad_accum)

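    # Average the freshly initialized parameters across workers so that
    # every replica starts from identical weights.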
    if ndev > 1:
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data /= ndev

    # Optimizer: Adam wrapped in the Noam learning-rate schedule (warm-up, then decay)
    model_opt = NoamOpt(
        dim_model,
        0.1,
        4000,
        torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9),
    )

    # Train & evaluate
    for epoch in range(100):
        start = time.time()
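        # Each device iterates over its own shard of the training set,
        # selected by dev_rank out of ndev.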
        train_iter = dataset(
            graph_pool,
            mode="train",
            batch_size=args.batch,
            device=device,
            dev_rank=dev_rank,
            ndev=ndev,
        )
        model.train(True)
        run_epoch(
            epoch,
            train_iter,
            dev_rank,
            ndev,
            model,
            loss_compute(opt=model_opt),
            is_train=True,
        )
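        # Only the rank-0 device runs validation, attention visualization,
        # and checkpointing.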
        if dev_rank == 0:
            model.att_weight_map = None
            model.eval()
            valid_iter = dataset(
                graph_pool,
                mode="valid",
                batch_size=args.batch,
                device=device,
                dev_rank=dev_rank,
                ndev=1,
            )
            run_epoch(
                epoch,
                valid_iter,
                dev_rank,
                1,
                model,
                loss_compute(opt=None),
                is_train=False,
            )
            end = time.time()
            print("epoch time: {}".format(end - start))

            # Experiment tag used for figure and checkpoint names: all
            # hyper-parameter values, skipping run-environment arguments
            args_filter = [
                "batch",
                "gpus",
                "viz",
                "master_ip",
                "master_port",
                "grad_accum",
                "ngpu",
            ]
            exp_setting = "-".join(
                "{}".format(v)
                for k, v in vars(args).items()
                if k not in args_filter
            )

            # Visualize attention
            if args.viz:
                src_seq = dataset.get_seq_by_id(
                    VIZ_IDX, mode="valid", field="src"
                )
                tgt_seq = dataset.get_seq_by_id(
                    VIZ_IDX, mode="valid", field="tgt"
                )[:-1]
                draw_atts(
                    model.att_weight_map,
                    src_seq,
                    tgt_seq,
                    exp_setting,
                    "epoch_{}".format(epoch),
                )

            # Save a checkpoint of the model weights for this epoch
            with open(
                "checkpoints/{}-{}.pkl".format(exp_setting, epoch), "wb"
            ) as f:
                torch.save(model.state_dict(), f)


if __name__ == "__main__":
    if not os.path.exists("checkpoints"):
        os.makedirs("checkpoints")
    np.random.seed(1111)
    argparser = argparse.ArgumentParser("training translation model")
    argparser.add_argument("--gpus", default="-1", type=str, help="gpu id")
    argparser.add_argument("--N", default=6, type=int, help="enc/dec layers")
    argparser.add_argument("--dataset", default="multi30k", help="dataset")
    argparser.add_argument("--batch", default=128, type=int, help="batch size")
    argparser.add_argument(
        "--viz", action="store_true", help="visualize attention"
    )
    argparser.add_argument(
        "--universal", action="store_true", help="use universal transformer"
    )
    argparser.add_argument(
        "--master-ip", type=str, default="127.0.0.1", help="master ip address"
    )
    argparser.add_argument(
        "--master-port", type=str, default="12345", help="master port"
    )
    argparser.add_argument(
        "--grad-accum",
        type=int,
        default=1,
        help="accumulate gradients for this many times " "then update weights",
    )
    args = argparser.parse_args()
    print(args)

    devices = list(map(int, args.gpus.split(",")))
    if len(devices) == 1:
        args.ngpu = 0 if devices[0] < 0 else 1
        main(devices[0], args)
    else:
        args.ngpu = len(devices)
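        # Use the "spawn" start method; fork-based workers are unsafe
        # once CUDA has been initialized.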
        mp = torch.multiprocessing.get_context("spawn")
        procs = []
        for dev_id in devices:
            procs.append(
                mp.Process(target=run, args=(dev_id, args), daemon=True)
            )
            procs[-1].start()
        for p in procs:
            p.join()