benchmark.py 1.12 KB
Newer Older
1
import torchani
Xiang Gao's avatar
Xiang Gao committed
2
3


4
class ANIBenchmark:
    """Time the AEV, energy and force stages of a torchani model.

    The model is constructed with ``benchmark=True``; the benchmark methods
    read its ``timers['aev']``, ``timers['nn']`` and ``timers['derivative']``
    entries after running it and then call ``reset_timers()``.
    """

    # (key in the returned dict, name of the model timer it reports)
    _TIMER_KEYS = (
        ('aev', 'aev'),
        ('energy', 'nn'),
        ('force', 'derivative'),
    )

    def __init__(self, device):
        """Build the AEV computer and the benchmarked model.

        Args:
            device: device passed to ``torchani.SortedAEV`` and used to move
                the input coordinates in the benchmark methods.
        """
        # ``object.__init__`` accepts no extra arguments, so ``device`` is
        # stored here instead of being forwarded; the benchmark methods read
        # ``self.device``.
        super().__init__()
        self.device = device
        self.aev_computer = torchani.SortedAEV(device=device)
        self.model = torchani.ModelOnAEV(
            self.aev_computer, benchmark=True, derivative=True, from_nc=None)

    def _collect_timers(self):
        """Return the model's stage timers keyed by stage and reset them."""
        ret = {key: self.model.timers[name] for key, name in self._TIMER_KEYS}
        self.model.reset_timers()
        return ret

    def oneByOne(self, coordinates, species):
        """Time the model when each conformation is evaluated separately.

        Args:
            coordinates: tensor with one conformation per entry along the
                first dimension, sliced as ``coordinates[i:i+1, :, :]``
                (presumably ``(conformations, atoms, 3)`` -- confirm against
                callers).
            species: passed unchanged to every model call.

        Returns:
            dict with keys ``'aev'``, ``'energy'`` and ``'force'`` holding the
            model's ``'aev'``, ``'nn'`` and ``'derivative'`` timers summed
            over all per-conformation calls.
        """
        # Drop timings left over from earlier model calls (e.g. a warm-up)
        # so the result covers only this benchmark run.
        self.model.reset_timers()
        coordinates = coordinates.to(self.device)
        for i in range(coordinates.shape[0]):
            # Slice instead of index to keep the leading batch dimension.
            self.model(coordinates[i:i + 1, :, :], species)
        return self._collect_timers()

    def inBatch(self, coordinates, species):
        """Time the model when all conformations are evaluated in one call.

        Args:
            coordinates: batch of conformations passed to the model as-is
                (after being moved to ``self.device``).
            species: passed unchanged to the model.

        Returns:
            dict with keys ``'aev'``, ``'energy'`` and ``'force'`` holding the
            model's ``'aev'``, ``'nn'`` and ``'derivative'`` timers for the
            single batched call.
        """
        # Drop timings left over from earlier model calls.
        self.model.reset_timers()
        self.model(coordinates.to(self.device), species)
        return self._collect_timers()