"megatron/vscode:/vscode.git/clone" did not exist on "f17a39337675ebd0d8a4a5b3246e7b86402f70ef"
nnp_training.py 3.88 KB
Newer Older
import argparse
import json
import math
import timeit

import torch
import ignite
import torchani
import tensorboardX
import tqdm

import model

# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    help='Path of the dataset; can be an HDF5 file '
                         'or a directory containing HDF5 files')
parser.add_argument('--dataset_checkpoint',
                    help='Checkpoint file for datasets',
                    default='dataset-checkpoint.dat')
parser.add_argument('--model_checkpoint',
                    help='Checkpoint file for model',
                    default='model.pt')
parser.add_argument('-m', '--max_epochs',
                    help='Maximum number of epochs',
                    default=10, type=int)
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--chunk_size',
                    help='Number of conformations in each chunk',
                    default=256, type=int)
parser.add_argument('--batch_chunks',
                    help='Number of chunks in each minibatch',
                    default=4, type=int)
parser.add_argument('--log',
                    help='Log directory for tensorboardX',
                    default=None)
parser.add_argument('--optimizer',
                    help='Optimizer used to train the model',
                    default='Adam')
parser.add_argument('--optim_args',
                    help='Optimizer arguments, in JSON format',
                    default='{}')
args = parser.parse_args()
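
# Example invocation (the dataset path and optimizer settings here are
# hypothetical, shown only for illustration):
#   python nnp_training.py data/ani_gdb_s01.h5 --max_epochs 50 \
#       --optimizer Adam --optim_args '{"lr": 1e-3}'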

# set up the training
device = torch.device(args.device)
writer = tensorboardX.SummaryWriter(log_dir=args.log)
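# Scalars written here can be inspected with `tensorboard --logdir <log dir>`.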
start = timeit.default_timer()

# use the checkpoint path from the command line instead of a hardcoded
# /tmp/model.pt, which silently ignored the --model_checkpoint option
nnp, shift_energy = model.get_or_create_model(args.model_checkpoint,
                                              True, device=device)
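
# subtract_from_dataset removes the per-atom self energies (as estimated by
# the energy shifter) from the dataset labels, so the network only has to fit
# the much smaller remaining part of the energy.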
training, validation, testing = torchani.data.load_or_create(
    args.dataset_checkpoint, args.dataset_path, args.chunk_size,
    device=device, transform=[shift_energy.subtract_from_dataset])

training = torchani.data.dataloader(training, args.batch_chunks)
validation = torchani.data.dataloader(validation, args.batch_chunks)
container = torchani.ignite.Container({'energies': nnp})

optim_args = json.loads(args.optim_args)
optimizer_cls = getattr(torch.optim, args.optimizer)
optimizer = optimizer_cls(nnp.parameters(), **optim_args)
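# Any optimizer class in torch.optim can be selected by name this way; its
# keyword arguments are passed as JSON, e.g.
#   --optimizer SGD --optim_args '{"lr": 1e-3, "momentum": 0.9}'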

trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
    'RMSE': torchani.ignite.RMSEMetric('energies')
})
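# The trainer performs one optimizer step per minibatch and stores that
# batch's loss in trainer.state.output; the evaluator accumulates the RMSE
# metric over a full pass of whichever loader it is run on.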


def hartree2kcal(x):
    # 1 Hartree = 627.509 kcal/mol
    return 627.509 * x


@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    trainer.state.tqdm = tqdm.tqdm(total=len(training), desc='epoch')


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    trainer.state.tqdm.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    trainer.state.tqdm.close()


@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def log_validation_results(trainer):
    evaluator.run(validation)
    metrics = evaluator.state.metrics
    rmse = hartree2kcal(metrics['RMSE'])
    writer.add_scalar('validation_rmse_vs_epoch', rmse, trainer.state.epoch)
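# Note: because this handler fires on EPOCH_STARTED, the value logged for
# epoch N is the validation RMSE before that epoch's updates, so the first
# point reflects the initial (untrained or freshly loaded) model.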


@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def log_time(trainer):
    elapsed = round(timeit.default_timer() - start, 2)
    writer.add_scalar('time_vs_epoch', elapsed, trainer.state.epoch)


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
    # trainer.state.output holds the MSE loss for the batch, so take the
    # square root before converting to get an RMSE in kcal/mol
    iteration = trainer.state.iteration
    rmse = hartree2kcal(math.sqrt(trainer.state.output))
    writer.add_scalar('training_rmse_vs_iteration', rmse, iteration)


trainer.run(training, max_epochs=args.max_epochs)