Unverified Commit 06cdce78 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

get AEV dimension from aev_computer (#442)

parent 62eb2236
......@@ -92,9 +92,10 @@ print('Self atomic energies: ', energy_shifter.self_energies)
#
###############################################################################
# Now let's define atomic neural networks.
aev_dim = aev_computer.aev_length
H_network = torch.nn.Sequential(
torch.nn.Linear(384, 160),
torch.nn.Linear(aev_dim, 160),
torch.nn.CELU(0.1),
torch.nn.Linear(160, 128),
torch.nn.CELU(0.1),
......@@ -104,7 +105,7 @@ H_network = torch.nn.Sequential(
)
C_network = torch.nn.Sequential(
torch.nn.Linear(384, 144),
torch.nn.Linear(aev_dim, 144),
torch.nn.CELU(0.1),
torch.nn.Linear(144, 112),
torch.nn.CELU(0.1),
......@@ -114,7 +115,7 @@ C_network = torch.nn.Sequential(
)
N_network = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.Linear(aev_dim, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 112),
torch.nn.CELU(0.1),
......@@ -124,7 +125,7 @@ N_network = torch.nn.Sequential(
)
O_network = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.Linear(aev_dim, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 112),
torch.nn.CELU(0.1),
......
......@@ -57,9 +57,10 @@ print('Self atomic energies: ', energy_shifter.self_energies)
###############################################################################
# The code to define networks, optimizers, are mostly the same
aev_dim = aev_computer.aev_length
H_network = torch.nn.Sequential(
torch.nn.Linear(384, 160),
torch.nn.Linear(aev_dim, 160),
torch.nn.CELU(0.1),
torch.nn.Linear(160, 128),
torch.nn.CELU(0.1),
......@@ -69,7 +70,7 @@ H_network = torch.nn.Sequential(
)
C_network = torch.nn.Sequential(
torch.nn.Linear(384, 144),
torch.nn.Linear(aev_dim, 144),
torch.nn.CELU(0.1),
torch.nn.Linear(144, 112),
torch.nn.CELU(0.1),
......@@ -79,7 +80,7 @@ C_network = torch.nn.Sequential(
)
N_network = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.Linear(aev_dim, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 112),
torch.nn.CELU(0.1),
......@@ -89,7 +90,7 @@ N_network = torch.nn.Sequential(
)
O_network = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.Linear(aev_dim, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 112),
torch.nn.CELU(0.1),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment