Unverified Commit f967877b authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

Update NN in training benchmark (#453)

ANI1x NNs are updated to be consistent with what we used in the paper
parent fec59ac3
...@@ -8,18 +8,45 @@ from torchani.units import hartree2kcalmol ...@@ -8,18 +8,45 @@ from torchani.units import hartree2kcalmol
synchronize = False synchronize = False
H_network = torch.nn.Sequential(
def atomic(): torch.nn.Linear(384, 160),
model = torch.nn.Sequential( torch.nn.CELU(0.1),
torch.nn.Linear(384, 128), torch.nn.Linear(160, 128),
torch.nn.CELU(0.1), torch.nn.CELU(0.1),
torch.nn.Linear(128, 128), torch.nn.Linear(128, 96),
torch.nn.CELU(0.1), torch.nn.CELU(0.1),
torch.nn.Linear(128, 64), torch.nn.Linear(96, 1)
torch.nn.CELU(0.1), )
torch.nn.Linear(64, 1)
) C_network = torch.nn.Sequential(
return model torch.nn.Linear(384, 144),
torch.nn.CELU(0.1),
torch.nn.Linear(144, 112),
torch.nn.CELU(0.1),
torch.nn.Linear(112, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
N_network = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 112),
torch.nn.CELU(0.1),
torch.nn.Linear(112, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
O_network = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 112),
torch.nn.CELU(0.1),
torch.nn.Linear(112, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
def time_func(key, func): def time_func(key, func):
...@@ -71,7 +98,7 @@ if __name__ == "__main__": ...@@ -71,7 +98,7 @@ if __name__ == "__main__":
num_species = 4 num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species) aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
nn = torchani.ANIModel([atomic() for _ in range(4)]) nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
model = torch.nn.Sequential(aev_computer, nn).to(parser.device) model = torch.nn.Sequential(aev_computer, nn).to(parser.device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.000001) optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
mse = torch.nn.MSELoss(reduction='none') mse = torch.nn.MSELoss(reduction='none')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment