"""nnp_training.py: train a neural network potential with TorchANI and
ignite."""

import torch
import ignite
import torchani
import model

import tqdm
import timeit
import tensorboardX
import math

import argparse
import json

# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('training_path',
                    help='Path of the training set; can be an HDF5 file \
                          or a directory containing HDF5 files')
parser.add_argument('validation_path',
                    help='Path of the validation set; can be an HDF5 file \
                          or a directory containing HDF5 files')
parser.add_argument('--model_checkpoint',
                    help='Checkpoint file for the model',
                    default='model.pt')
parser.add_argument('-m', '--max_epochs',
                    help='Maximum number of epochs',
                    default=300, type=int)
parser.add_argument('--training_rmse_every',
                    help='Epoch interval for computing the training set RMSE',
                    default=20, type=int)
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
                    help='Number of conformations in each batch',
                    default=1024, type=int)
parser.add_argument('--log',
                    help='Log directory for tensorboardX',
                    default=None)
parser.add_argument('--optimizer',
                    help='Optimizer used to train the model',
                    default='Adam')
parser.add_argument('--optim_args',
                    help='Arguments to the optimizer, as a JSON string',
                    default='{}')
parser.add_argument('--early_stopping',
                    help='Stop after this many epochs without improvement',
                    default=math.inf, type=int)
# from here on, `parser` holds the parsed arguments, not the ArgumentParser
parser = parser.parse_args()

# set up the training
device = torch.device(parser.device)
writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer()

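# `model` is the companion module imported above; it appears to define the
# network, the species constants (`consts`) and the energy shifter
# (`shift_energy`) used below.  Judging by its name, get_or_create_model()
# restores the network from the checkpoint when one exists and builds a
# fresh one otherwise.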
nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
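# BatchedANIDataset reads conformations from the HDF5 file(s), splits them
# into batches, and applies the transform pipeline; subtracting the self
# atomic energies here leaves only the part the network has to learn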
training = torchani.data.BatchedANIDataset(
    parser.training_path, model.consts.species_to_tensor,
    parser.batch_size, device=device,
    transform=[model.shift_energy.subtract_from_dataset])
validation = torchani.data.BatchedANIDataset(
    parser.validation_path, model.consts.species_to_tensor,
    parser.batch_size, device=device,
    transform=[model.shift_energy.subtract_from_dataset])
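# Container adapts the model to the (input, target) convention expected by
# ignite's supervised trainer/evaluator, keyed by 'energies'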
container = torchani.ignite.Container({'energies': nnp})

# look up the optimizer class by name in torch.optim and pass the JSON
# --optim_args dict as its keyword arguments,
# e.g. --optimizer SGD --optim_args '{"lr": 1e-3}'
parser.optim_args = json.loads(parser.optim_args)
optimizer_cls = getattr(torch.optim, parser.optimizer)
optimizer = optimizer_cls(nnp.parameters(), **parser.optim_args)

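# the trainer minimizes the MSE of the predicted energies; the evaluator
# computes an RMSE metric and is shared by the validation and training
# evaluations below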
trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
        'RMSE': torchani.ignite.RMSEMetric('energies')
    })


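# 1 hartree = 627.509 kcal/mol, so all RMSEs below are logged in kcal/mol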
def hartree2kcal(x):
    return 627.509 * x


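# ignite fires the handlers below at the matching points of the training
# loop: STARTED once per run, EPOCH_STARTED/EPOCH_COMPLETED once per epoch,
# ITERATION_COMPLETED after every batch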
@trainer.on(ignite.engine.Events.STARTED)
def initialize(trainer):
    trainer.state.best_validation_rmse = math.inf
    trainer.state.no_improve_count = 0


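# tqdm progress bar over the batches of each epoch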
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    trainer.state.tqdm = tqdm.tqdm(total=len(training), desc='epoch')


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    trainer.state.tqdm.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    trainer.state.tqdm.close()


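# at the start of each epoch: evaluate, checkpoint the model when the
# validation RMSE improves, and stop early after --early_stopping epochs
# without improvement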
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def validation_and_checkpoint(trainer):
    def evaluate(dataset, name):
        # run the shared evaluator; its metrics are reset on each run
        evaluator.run(dataset)
        metrics = evaluator.state.metrics
        rmse = hartree2kcal(metrics['RMSE'])
        writer.add_scalar(name, rmse, trainer.state.epoch)
        return rmse

    # compute validation RMSE
    rmse = evaluate(validation, 'validation_rmse_vs_epoch')

    # compute the training set RMSE every --training_rmse_every epochs
    if trainer.state.epoch % parser.training_rmse_every == 1:
        evaluate(training, 'training_rmse_vs_epoch')

    # handle best validation RMSE
    if rmse < trainer.state.best_validation_rmse:
        trainer.state.no_improve_count = 0
        trainer.state.best_validation_rmse = rmse
        writer.add_scalar('best_validation_rmse_vs_epoch', rmse,
                          trainer.state.epoch)
        torch.save(nnp.state_dict(), parser.model_checkpoint)
    else:
        trainer.state.no_improve_count += 1
    writer.add_scalar('no_improve_count_vs_epoch',
                      trainer.state.no_improve_count,
                      trainer.state.epoch)

    if trainer.state.no_improve_count > parser.early_stopping:
        trainer.terminate()


@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def log_time(trainer):
    elapsed = round(timeit.default_timer() - start, 2)
    writer.add_scalar('time_vs_epoch', elapsed, trainer.state.epoch)


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def log_loss(trainer):
    iteration = trainer.state.iteration
    writer.add_scalar('loss_vs_iteration', trainer.state.output, iteration)


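# run the training loop; the handlers above take care of logging,
# checkpointing and early stopping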
trainer.run(training, max_epochs=parser.max_epochs)