train.py 5.31 KB
Newer Older
1
2
3
4
5
import math
import os
import pickle
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from graphwriter import *
from opts import *
from utlis import *

sys.path.append("./pycocoevalcap")

from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
18
19
20
21


def train_one_epoch(model, dataloader, optimizer, args, epoch):
    """Run one full training epoch over *dataloader*.

    Standard teacher-forced NLL training: forward pass, backward pass,
    gradient clipping, optimizer step per batch; the model is checkpointed
    at the end of the epoch.

    Args:
        model: model mapping a batch dict to per-token log-probabilities
            (last dim is the vocabulary size; assumed to be log-softmax
            output, since F.nll_loss is applied directly).
        dataloader: iterable of batch dicts with key ``"tgt_text"`` (target
            token ids; index 0 is treated as padding and ignored).
        optimizer: torch optimizer over ``model.parameters()``.
        args: namespace providing ``clip`` (max gradient norm) and
            ``save_model`` (checkpoint path prefix).
        epoch: current epoch index, used for logging and the checkpoint name.

    Raises:
        ValueError: if the training loss becomes NaN.
    """
    model.train()
    tloss = 0.0
    tcnt = 0.0
    st_time = time.time()
    with tqdm(dataloader, desc="Train Ep " + str(epoch), mininterval=60) as tq:
        for batch in tq:
            pred = model(batch)
            # pred holds log-probabilities, so plain NLL here is
            # cross-entropy; padding (index 0) contributes no loss.
            nll_loss = F.nll_loss(
                pred.view(-1, pred.shape[-1]),
                batch["tgt_text"].view(-1),
                ignore_index=0,
            )
            loss = nll_loss
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            loss = loss.item()
            # Abort early on divergence instead of training on NaNs.
            if math.isnan(loss):
                raise ValueError("NaN appear")
            tloss += loss * len(batch["tgt_text"])
            tcnt += len(batch["tgt_text"])
            tq.set_postfix({"loss": tloss / tcnt}, refresh=False)
    # Fix: max_memory_cached() is deprecated (renamed max_memory_reserved),
    # and an unguarded CUDA query crashes on CPU-only machines.
    gpu_gb = (
        torch.cuda.max_memory_reserved() / 1024.0 / 1024.0 / 1024.0
        if torch.cuda.is_available()
        else 0.0
    )
    print(
        "Train Ep ",
        str(epoch),
        "AVG Loss ",
        tloss / tcnt,
        "Steps ",
        tcnt,
        "Time ",
        time.time() - st_time,
        "GPU",
        gpu_gb,
    )
    # NOTE: saves the whole model object (not just state_dict); the epoch is
    # taken modulo 100 so old checkpoints are overwritten cyclically.
    torch.save(model, args.save_model + str(epoch % 100))

58
59

# Lowest average validation loss observed so far (shared with eval_it(),
# which updates it whenever it saves a new "best" checkpoint).  Starts at a
# large sentinel value so the first evaluation always saves.
val_loss = 2**31
60
61


62
63
64
def eval_it(model, dataloader, args, epoch):
    """Evaluate *model* on *dataloader* with padded-NLL loss.

    Tracks the running average loss, prints an epoch summary, and saves a
    "best" checkpoint whenever this epoch's average beats the module-level
    ``val_loss`` seen so far.
    """
    global val_loss
    model.eval()
    total_loss = 0.0
    seen = 0.0
    started = time.time()
    with tqdm(dataloader, desc="Eval Ep " + str(epoch), mininterval=60) as bar:
        for batch in bar:
            targets = batch["tgt_text"]
            with torch.no_grad():
                scores = model(batch)
                # scores are log-probabilities; padding index 0 is ignored.
                batch_loss = F.nll_loss(
                    scores.view(-1, scores.shape[-1]),
                    targets.view(-1),
                    ignore_index=0,
                ).item()
            total_loss += batch_loss * len(targets)
            seen += len(targets)
            bar.set_postfix({"loss": total_loss / seen}, refresh=False)
    print(
        "Eval Ep ",
        str(epoch),
        "AVG Loss ",
        total_loss / seen,
        "Steps ",
        seen,
        "Time ",
        time.time() - started,
    )
    if total_loss / seen < val_loss:
        print("Saving best model ", "Ep ", epoch, " loss ", total_loss / seen)
        torch.save(model, args.save_model + "best")
        val_loss = total_loss / seen
96
97
98
99
100
101
102
103
104


def test(model, dataloader, args):
    """Decode the test set with beam search and report BLEU / METEOR / ROUGE-L.

    Writes references to ``tmp_gold.txt`` and hypotheses to ``tmp_pred.txt``,
    then scores hypotheses against references with the pycocoevalcap scorers.

    Args:
        model: trained model whose forward accepts a ``beam_size`` keyword and
            returns decoded sequences.
        dataloader: iterable of batch dicts with key ``"tgt_text"``.
        args: namespace providing ``beam_size`` (plus whatever write_txt reads).
    """
    scorer = Bleu(4)
    m_scorer = Meteor()
    r_scorer = Rouge()
    hyp = []
    ref = []
    model.eval()
    # Fix: open the dump files with context managers so they are closed even
    # if decoding raises (the originals leaked handles on any exception).
    with open("tmp_gold.txt", "w") as gold_file, open("tmp_pred.txt", "w") as pred_file:
        with tqdm(dataloader, desc="Test ", mininterval=1) as tq:
            for batch in tq:
                with torch.no_grad():
                    seq = model(batch, beam_size=args.beam_size)
                r = write_txt(batch, batch["tgt_text"], gold_file, args)
                h = write_txt(batch, seq, pred_file, args)
                hyp.extend(h)
                ref.extend(r)
    # The coco scorers expect {id: [sentence, ...]} dicts keyed consistently.
    hyp = dict(zip(range(len(hyp)), hyp))
    ref = dict(zip(range(len(ref)), ref))
    # Guard the sample print so an empty test set doesn't IndexError.
    if hyp and ref:
        print(hyp[0], ref[0])
    print("BLEU INP", len(hyp), len(ref))
    print("BLEU", scorer.compute_score(ref, hyp)[0])
    print("METEOR", m_scorer.compute_score(ref, hyp)[0])
    print("ROUGE_L", r_scorer.compute_score(ref, hyp)[0])


def main(args):
    """Entry point: build datasets and dataloaders, then train or test.

    If ``args.save_dataset`` exists it is unpickled as a dataset cache;
    otherwise datasets are built from ``args.fnames`` and saved there.
    With ``args.test`` set, a saved model is loaded and scored on the test
    split; otherwise the model trains for ``args.epoch`` epochs with a
    validation pass after each one.
    """
    if os.path.exists(args.save_dataset):
        # Fix: context manager so the cache file handle is closed (the
        # original left the opened file object dangling).
        # NOTE(review): pickle.load is only safe on the cache this script
        # wrote itself -- never point save_dataset at untrusted data.
        with open(args.save_dataset, "rb") as f:
            train_dataset, valid_dataset, test_dataset = pickle.load(f)
    else:
        train_dataset, valid_dataset, test_dataset = get_datasets(
            args.fnames, device=args.device, save=args.save_dataset
        )
    # Attach the vocabularies built on the training split to args so the
    # model and the eval/test paths share them.
    args = vocab_config(
        args,
        train_dataset.ent_vocab,
        train_dataset.rel_vocab,
        train_dataset.text_vocab,
        train_dataset.ent_text_vocab,
        train_dataset.title_vocab,
    )
    # Training uses length-bucketed batches; eval/test keep natural order.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=BucketSampler(train_dataset, batch_size=args.batch_size),
        collate_fn=train_dataset.batch_fn,
    )
    valid_dataloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=train_dataset.batch_fn,
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=train_dataset.batch_fn,
    )

    model = GraphWriter(args)
    model.to(args.device)
    if args.test:
        # The freshly built model is replaced wholesale by the checkpoint;
        # only the (possibly updated) CLI args are re-attached.
        model = torch.load(args.save_model)
        model.args = args
        print(model)
        test(model, test_dataloader, args)
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=0.9,
        )
        print(model)
        for epoch in range(args.epoch):
            train_one_epoch(model, train_dataloader, optimizer, args, epoch)
            eval_it(model, valid_dataloader, args, epoch)

180
181

if __name__ == "__main__":
    # Parse command-line options and hand off to the driver.
    main(get_args())