# training-benchmark-nsys-profile.py
import argparse
import functools

import pkbar
import torch
import torchani
from torchani.units import hartree2kcalmol
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57


# Number of batches run before profiling is switched on: unless --dry-run is
# given, the main loop starts the CUDA profiler when its batch counter equals
# this value.
WARM_UP_BATCHES = 50
# Nominal number of batches to profile after warm-up; the main loop uses
# WARM_UP_BATCHES + PROFILE_BATCHES as its stopping threshold.
PROFILE_BATCHES = 10


def atomic():
    """Build one per-species fully connected network.

    The stack is Linear(384, 128) -> Linear(128, 128) -> Linear(128, 64)
    -> Linear(64, 1), with a ``CELU(0.1)`` activation after every layer
    except the last.  The 384 input features are presumably the AEV length
    produced by the AEV computer configured in the main script (TODO confirm).

    Returns:
        torch.nn.Sequential: the freshly initialised network.
    """
    widths = (384, 128, 128, 64)
    layers = []
    # Hidden layers: each Linear is immediately followed by its activation.
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(torch.nn.Linear(fan_in, fan_out))
        layers.append(torch.nn.CELU(0.1))
    # Scalar output head, no activation.
    layers.append(torch.nn.Linear(widths[-1], 1))
    return torch.nn.Sequential(*layers)


def time_func(key, func):
    """Wrap ``func`` so that every call is enclosed in an NVTX range.

    Args:
        key: Name given to the NVTX range pushed around each call.
        func: Callable to wrap.

    Returns:
        A callable with the same call signature and return value as ``func``.
        ``functools.wraps`` copies the name and docstring of ``func`` onto it.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        torch.cuda.nvtx.range_push(key)
        try:
            return func(*args, **kwargs)
        finally:
            # Pop even when ``func`` raises, otherwise the range stays open
            # and every later push/pop pair is mismatched.
            torch.cuda.nvtx.range_pop()

    return wrapper


def enable_timers(model):
    """Install NVTX-range wrappers on the AEV helpers and the network forward.

    Each listed attribute of ``torchani.aev`` is replaced in place by a
    ``time_func`` wrapper whose range name equals the attribute name, and
    ``model[1].forward`` is replaced by a wrapper named ``'nn forward'``.

    Args:
        model: Indexable container of modules; only ``model[1]`` is touched.
    """
    aev_function_names = (
        'cutoff_cosine',
        'radial_terms',
        'angular_terms',
        'compute_shifts',
        'neighbor_pairs',
        'triu_index',
        'cumsum_from_zero',
        'triple_by_molecule',
        'compute_aev',
    )
    for fn_name in aev_function_names:
        original = getattr(torchani.aev, fn_name)
        setattr(torchani.aev, fn_name, time_func(fn_name, original))

    network = model[1]
    network.forward = time_func('nn forward', network.forward)


if __name__ == "__main__":
    # Script flow: parse arguments, build the AEV computer + ANI network,
    # load the dataset, train for WARM_UP_BATCHES batches, then train for
    # PROFILE_BATCHES more batches with the CUDA profiler and NVTX ranges on.

    # parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_path',
                        help='Path of the dataset, can be a hdf5 file '
                             'or a directory containing hdf5 files')
    parser.add_argument('-b', '--batch_size',
                        help='Number of conformations of each batch',
                        default=2560, type=int)
    parser.add_argument('-d', '--dry-run',
                        help='just run it in a CI without GPU',
                        action='store_true')
    args = parser.parse_args()
    # A dry run stays on the CPU and never touches the CUDA profiler or NVTX.
    device = torch.device('cpu' if args.dry_run else 'cuda')

    # AEV hyperparameters.
    Rcr = 5.2000e+00
    Rca = 3.5000e+00
    EtaR = torch.tensor([1.6000000e+01], device=device)
    ShfR = torch.tensor([
        9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00,
        1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00,
        3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00,
        4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00,
    ], device=device)
    Zeta = torch.tensor([3.2000000e+01], device=device)
    ShfZ = torch.tensor([
        1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00,
        1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00,
    ], device=device)
    EtaA = torch.tensor([8.0000000e+00], device=device)
    ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
    num_species = 4
    aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)

    # One atomic network per species, chained after the AEV computer.
    nn = torchani.ANIModel([atomic() for _ in range(num_species)])
    model = torch.nn.Sequential(aev_computer, nn).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
    # reduction='none' keeps per-conformation errors so they can be scaled by
    # the atom count before averaging.
    mse = torch.nn.MSELoss(reduction='none')

    print('=> loading dataset...')
    shifter = torchani.EnergyShifter(None)
    # Materialise all batches up front so every epoch iterates the same list.
    dataset = list(
        torchani.data.load(args.dataset_path)
        .subtract_self_energies(shifter)
        .species_to_indices()
        .shuffle()
        .collate(args.batch_size)
    )

    print('=> start warming up')
    total_batches = WARM_UP_BATCHES + PROFILE_BATCHES
    total_batch_counter = 0
    profiling = False   # True once the CUDA profiler has been started
    finished = False
    # A non-empty epoch contributes at least one batch, so this many epochs
    # always suffice to reach the batch budget; the loop exits early once the
    # budget is met.
    for epoch in range(total_batches):

        print('Epoch: %d/inf' % (epoch + 1,))
        progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)

        for i, properties in enumerate(dataset):

            if not args.dry_run and total_batch_counter == WARM_UP_BATCHES:
                print('=> warm up finished, start profiling')
                enable_timers(model)
                torch.cuda.cudart().cudaProfilerStart()
                profiling = True

            if profiling:
                torch.cuda.nvtx.range_push("batch{}".format(total_batch_counter))

            species = properties['species'].to(device)
            coordinates = properties['coordinates'].to(device).float()
            true_energies = properties['energies'].to(device).float()
            # Entries with species >= 0 are counted as atoms (negative values
            # are presumably padding -- confirm against torchani.data).
            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
            with torch.autograd.profiler.emit_nvtx(enabled=profiling, record_shapes=True):
                _, predicted_energies = model((species, coordinates))
            squared_errors = mse(predicted_energies, true_energies)
            loss = (squared_errors / num_atoms.sqrt()).mean()
            # Root of the mean squared error, converted for display only.
            rmse = hartree2kcalmol(squared_errors.mean().sqrt()).detach().cpu().numpy()

            # Clear gradients from the previous batch; backward() accumulates.
            optimizer.zero_grad()

            if profiling:
                torch.cuda.nvtx.range_push("backward")
            with torch.autograd.profiler.emit_nvtx(enabled=profiling, record_shapes=True):
                loss.backward()
            if profiling:
                torch.cuda.nvtx.range_pop()

            if profiling:
                torch.cuda.nvtx.range_push("optimizer.step()")
            with torch.autograd.profiler.emit_nvtx(enabled=profiling, record_shapes=True):
                optimizer.step()
            if profiling:
                torch.cuda.nvtx.range_pop()

            progbar.update(i, values=[("rmse", rmse)])

            if profiling:
                # Closes the per-batch range opened at the top of the iteration.
                torch.cuda.nvtx.range_pop()

            total_batch_counter += 1
            # >= so that exactly PROFILE_BATCHES batches follow the warm-up.
            if total_batch_counter >= total_batches:
                finished = True
                break

        if finished:
            print('=> profiling terminate after {} batches'.format(PROFILE_BATCHES))
            break

    if profiling:
        # Pair the earlier cudaProfilerStart() so the capture range is closed.
        torch.cuda.cudart().cudaProfilerStop()