"""Train, validate and test a GraphWriter model.

Run as a script. Options come from ``get_args()`` (star-imported from
``opts``). With ``args.test`` set, a saved model is loaded and decoded on the
test split and BLEU (n=4), METEOR and ROUGE-L are printed; otherwise the model
is trained with SGD for ``args.epoch`` epochs, validating after each epoch.
"""
import inspect
import math
import os
import pickle
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# NOTE(review): star imports hide where GraphWriter, BucketSampler, write_txt,
# get_datasets, vocab_config and get_args come from; prefer explicit names.
from graphwriter import *
from utlis import *
from opts import *

# Extend the import path with ./pycocoevalcap before importing the scorers.
sys.path.append('./pycocoevalcap')
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor


def _batch_nll(pred, batch):
    """Return the NLL loss of ``pred`` against ``batch['tgt_text']``.

    ``pred`` is flattened to ``(-1, pred.shape[-1])`` and the target to 1-D.
    Target index 0 is ignored (presumably the padding id -- TODO confirm in
    ``utlis``). ``F.nll_loss`` expects log-probabilities, so the model is
    assumed to end in a log-softmax.
    """
    # reshape (not view) so non-contiguous outputs/targets are also accepted.
    return F.nll_loss(pred.reshape(-1, pred.shape[-1]),
                      batch['tgt_text'].reshape(-1), ignore_index=0)


def _mean(total, count):
    """Return ``total / count``, or NaN when nothing was counted."""
    return total / count if count else float('nan')


def _peak_gpu_gib():
    """Return the peak CUDA memory reserved by PyTorch in GiB (0.0 without CUDA)."""
    if not torch.cuda.is_available():
        return 0.0
    # max_memory_cached() is the deprecated name of max_memory_reserved().
    peak = getattr(torch.cuda, 'max_memory_reserved', None) or torch.cuda.max_memory_cached
    return peak() / 1024.0 ** 3


def _load_model(path, device):
    """Load a whole pickled model from ``path`` with its tensors on ``device``.

    PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
    rejects whole-module pickles such as those written by this script, so
    ``weights_only=False`` is passed when the installed version supports it.
    NOTE(review): this unpickles arbitrary objects; only load trusted files.
    """
    kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


def train_one_epoch(model, dataloader, optimizer, args, epoch):
    """Run one training epoch, then save a checkpoint of the whole model.

    Args:
        model: module called as ``model(batch)``; its output is scored with NLL.
        dataloader: iterable of batch dicts that contain ``'tgt_text'``.
        optimizer: optimizer stepping ``model``'s parameters.
        args: options; reads ``args.clip`` (max gradient norm) and
            ``args.save_model`` (checkpoint path prefix).
        epoch: epoch index, used for logging and the checkpoint suffix.

    Raises:
        ValueError: if a batch loss is NaN (checked before the update is applied).
    """
    model.train()
    tloss = 0.
    tcnt = 0.
    st_time = time.time()
    with tqdm(dataloader, desc='Train Ep ' + str(epoch), mininterval=60) as tq:
        for batch in tq:
            pred = model(batch)
            loss = _batch_nll(pred, batch)
            loss_val = loss.item()
            # Fail before backward/step so a NaN never reaches the parameters.
            if math.isnan(loss_val):
                raise ValueError('NaN appear')
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            # Weight each batch's loss by its number of target sequences.
            n_seq = len(batch['tgt_text'])
            tloss += loss_val * n_seq
            tcnt += n_seq
            tq.set_postfix({'loss': _mean(tloss, tcnt)}, refresh=False)
    print('Train Ep ', str(epoch), 'AVG Loss ', _mean(tloss, tcnt), 'Steps ', tcnt,
          'Time ', time.time() - st_time, 'GPU', _peak_gpu_gib())
    # Rolling checkpoints: the suffix wraps around every 100 epochs.
    torch.save(model, args.save_model + str(epoch % 100))


# Lowest average validation loss seen so far; updated by eval_it().
val_loss = 2**31


def eval_it(model, dataloader, args, epoch):
    """Compute the average validation loss and save the model if it is the best.

    The module-level ``val_loss`` holds the best loss so far. When the new
    average is lower, the model is saved to ``args.save_model + 'best'`` and
    ``val_loss`` is updated. An empty ``dataloader`` yields a NaN average,
    which never compares lower, so nothing is saved in that case.
    """
    global val_loss
    model.eval()
    tloss = 0.
    tcnt = 0.
    st_time = time.time()
    with tqdm(dataloader, desc='Eval Ep ' + str(epoch), mininterval=60) as tq:
        with torch.no_grad():
            for batch in tq:
                pred = model(batch)
                loss_val = _batch_nll(pred, batch).item()
                n_seq = len(batch['tgt_text'])
                tloss += loss_val * n_seq
                tcnt += n_seq
                tq.set_postfix({'loss': _mean(tloss, tcnt)}, refresh=False)
    avg_loss = _mean(tloss, tcnt)
    print('Eval Ep ', str(epoch), 'AVG Loss ', avg_loss, 'Steps ', tcnt,
          'Time ', time.time() - st_time)
    if avg_loss < val_loss:
        print('Saving best model ', 'Ep ', epoch, ' loss ', avg_loss)
        torch.save(model, args.save_model + 'best')
        val_loss = avg_loss


def test(model, dataloader, args):
    """Decode ``dataloader`` with beam search and print BLEU, METEOR and ROUGE-L.

    References and hypotheses are written via ``write_txt`` to
    ``tmp_gold.txt`` and ``tmp_pred.txt`` in the working directory; the
    strings it returns are what gets scored. Reads ``args.beam_size``.
    """
    scorer = Bleu(4)
    m_scorer = Meteor()
    r_scorer = Rouge()
    hyp = []
    ref = []
    model.eval()
    # Both files are closed (and flushed) even if decoding raises.
    with open('tmp_gold.txt', 'w') as gold_file, open('tmp_pred.txt', 'w') as pred_file:
        with tqdm(dataloader, desc='Test ', mininterval=1) as tq:
            for batch in tq:
                with torch.no_grad():
                    seq = model(batch, beam_size=args.beam_size)
                ref.extend(write_txt(batch, batch['tgt_text'], gold_file, args))
                hyp.extend(write_txt(batch, seq, pred_file, args))
    if not hyp:
        print('No test samples decoded; skipping scoring.')
        return
    # compute_score receives refs and hyps keyed by the same positional ids.
    hyp = dict(enumerate(hyp))
    ref = dict(enumerate(ref))
    print(hyp[0], ref[0])
    print('BLEU INP', len(hyp), len(ref))
    print('BLEU', scorer.compute_score(ref, hyp)[0])
    print('METEOR', m_scorer.compute_score(ref, hyp)[0])
    print('ROUGE_L', r_scorer.compute_score(ref, hyp)[0])


def main(args):
    """Build or load the datasets, build the model, then test or train it.

    Datasets are read from the pickle at ``args.save_dataset`` when it exists,
    otherwise built by ``get_datasets`` (which is passed that path as ``save``).
    """
    if os.path.exists(args.save_dataset):
        # NOTE(review): pickle.load can execute arbitrary code; only load a
        # cache this script produced.
        with open(args.save_dataset, 'rb') as f:
            train_dataset, valid_dataset, test_dataset = pickle.load(f)
    else:
        train_dataset, valid_dataset, test_dataset = get_datasets(
            args.fnames, device=args.device, save=args.save_dataset)
    args = vocab_config(args, train_dataset.ent_vocab, train_dataset.rel_vocab,
                        train_dataset.text_vocab, train_dataset.ent_text_vocab,
                        train_dataset.title_vocab)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=BucketSampler(train_dataset, batch_size=args.batch_size),
        collate_fn=train_dataset.batch_fn)
    # Validation and test batches are collated with the training set's batch_fn.
    valid_dataloader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size,
        shuffle=False, collate_fn=train_dataset.batch_fn)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size,
        shuffle=False, collate_fn=train_dataset.batch_fn)
    model = GraphWriter(args)
    model.to(args.device)
    if args.test:
        # NOTE(review): the model built above is discarded in test mode; it is
        # kept in case GraphWriter.__init__ has side effects on args -- verify.
        # map_location places a checkpoint saved on another device on args.device.
        model = _load_model(args.save_model, args.device)
        model.args = args
        print(model)
        test(model, test_dataloader, args)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    weight_decay=args.weight_decay, momentum=0.9)
        print(model)
        for epoch in range(args.epoch):
            train_one_epoch(model, train_dataloader, optimizer, args, epoch)
            eval_it(model, valid_dataloader, args, epoch)


if __name__ == '__main__':
    args = get_args()
    main(args)