Commit e1f08644 authored by lunar, committed by Zihao Ye

[Model] Schnet & MGCN (#726)

* [Model] SchNet

* [Model] SchNet model

* [Model] Schnet Model fix device-related bug

* [Model] SchNet fix bugs

* [Model] SchNet fix some bugs

* [Model] Schnet 🎨 code indent format

* [Model] SchNet  fix some typos
parent 747a8bee
# -*- coding:utf-8 -*-
"""Example dataloader of Tencent Alchemy Dataset
https://alchemy.tencent.com/
"""
import os
import zipfile
import os.path as osp
from rdkit import Chem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
import dgl
from dgl.data.utils import download
import torch
from collections import defaultdict
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pathlib
import pandas as pd
import numpy as np
_urls = {'Alchemy': 'https://alchemy.tencent.com/data/'}
class AlchemyBatcher:
def __init__(self, graph=None, label=None):
self.graph = graph
self.label = label
def batcher():
def batcher_dev(batch):
graphs, labels = zip(*batch)
batch_graphs = dgl.batch(graphs)
labels = torch.stack(labels, 0)
return AlchemyBatcher(graph=batch_graphs, label=labels)
return batcher_dev
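# batcher() returns a collate_fn for DataLoader: it merges a list of
# (graph, label) pairs into a single batched graph and a stacked label tensor.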
class TencentAlchemyDataset(Dataset):
fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
chem_feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)
def alchemy_nodes(self, mol):
"""Featurization for all atoms in a molecule. The atom indices
will be preserved.
Args:
mol : rdkit.Chem.rdchem.Mol
RDKit molecule object
        Returns:
atom_feats_dict : dict
Dictionary for atom features
"""
atom_feats_dict = defaultdict(list)
is_donor = defaultdict(int)
is_acceptor = defaultdict(int)
fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
mol_featurizer = ChemicalFeatures.BuildFeatureFactory(fdef_name)
mol_feats = mol_featurizer.GetFeaturesForMol(mol)
mol_conformers = mol.GetConformers()
assert len(mol_conformers) == 1
geom = mol_conformers[0].GetPositions()
for i in range(len(mol_feats)):
if mol_feats[i].GetFamily() == 'Donor':
node_list = mol_feats[i].GetAtomIds()
for u in node_list:
is_donor[u] = 1
elif mol_feats[i].GetFamily() == 'Acceptor':
node_list = mol_feats[i].GetAtomIds()
for u in node_list:
is_acceptor[u] = 1
num_atoms = mol.GetNumAtoms()
for u in range(num_atoms):
atom = mol.GetAtomWithIdx(u)
symbol = atom.GetSymbol()
atom_type = atom.GetAtomicNum()
aromatic = atom.GetIsAromatic()
hybridization = atom.GetHybridization()
num_h = atom.GetTotalNumHs()
atom_feats_dict['pos'].append(torch.FloatTensor(geom[u]))
atom_feats_dict['node_type'].append(atom_type)
h_u = []
h_u += [
int(symbol == x) for x in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']
]
h_u.append(atom_type)
h_u.append(is_acceptor[u])
h_u.append(is_donor[u])
h_u.append(int(aromatic))
h_u += [
int(hybridization == x)
for x in (Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3)
]
h_u.append(num_h)
atom_feats_dict['n_feat'].append(torch.FloatTensor(h_u))
atom_feats_dict['n_feat'] = torch.stack(atom_feats_dict['n_feat'],
dim=0)
atom_feats_dict['pos'] = torch.stack(atom_feats_dict['pos'], dim=0)
atom_feats_dict['node_type'] = torch.LongTensor(
atom_feats_dict['node_type'])
return atom_feats_dict
def alchemy_edges(self, mol, self_loop=True):
"""Featurization for all bonds in a molecule. The bond indices
will be preserved.
        Args:
            mol : rdkit.Chem.rdchem.Mol
                RDKit molecule object
            self_loop : bool
                Whether self loops are included
        Returns:
            bond_feats_dict : dict
                Dictionary for bond features
        """
bond_feats_dict = defaultdict(list)
mol_conformers = mol.GetConformers()
assert len(mol_conformers) == 1
geom = mol_conformers[0].GetPositions()
num_atoms = mol.GetNumAtoms()
for u in range(num_atoms):
for v in range(num_atoms):
if u == v and not self_loop:
continue
e_uv = mol.GetBondBetweenAtoms(u, v)
if e_uv is None:
bond_type = None
else:
bond_type = e_uv.GetBondType()
bond_feats_dict['e_feat'].append([
float(bond_type == x)
for x in (Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC, None)
])
bond_feats_dict['distance'].append(
np.linalg.norm(geom[u] - geom[v]))
bond_feats_dict['e_feat'] = torch.FloatTensor(
bond_feats_dict['e_feat'])
bond_feats_dict['distance'] = torch.FloatTensor(
bond_feats_dict['distance']).reshape(-1, 1)
return bond_feats_dict
def sdf_to_dgl(self, sdf_file, self_loop=False):
"""
Read sdf file and convert to dgl_graph
Args:
sdf_file: path of sdf file
self_loop: Whetaher to add self loop
Returns:
g: DGLGraph
l: related labels
"""
        with open(str(sdf_file)) as f:
            sdf = f.read()
        mol = Chem.MolFromMolBlock(sdf, removeHs=False)
        if mol is None:
            # Skip molecules RDKit fails to parse; the caller checks for None.
            return None
g = dgl.DGLGraph()
# add nodes
num_atoms = mol.GetNumAtoms()
atom_feats = self.alchemy_nodes(mol)
g.add_nodes(num=num_atoms, data=atom_feats)
# add edges
        # The model we are interested in assumes a complete graph.
        # If that is not the case, use the code below instead:
#
# for bond in mol.GetBonds():
# u = bond.GetBeginAtomIdx()
# v = bond.GetEndAtomIdx()
if self_loop:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms)],
[j for i in range(num_atoms) for j in range(num_atoms)])
else:
g.add_edges(
[i for i in range(num_atoms) for j in range(num_atoms - 1)], [
j for i in range(num_atoms)
for j in range(num_atoms) if i != j
])
bond_feats = self.alchemy_edges(mol, self_loop)
g.edata.update(bond_feats)
# for val/test set, labels are molecule ID
l = torch.FloatTensor(self.target.loc[int(sdf_file.stem)].tolist()) \
if self.mode == 'dev' else torch.LongTensor([int(sdf_file.stem)])
return (g, l)
def __init__(self, mode='dev', transform=None):
assert mode in ['dev', 'valid',
'test'], "mode should be dev/valid/test"
self.mode = mode
self.transform = transform
self.file_dir = pathlib.Path('./Alchemy_data', mode)
self.zip_file_path = pathlib.Path('./Alchemy_data', '%s.zip' % mode)
download(_urls['Alchemy'] + "%s.zip" % mode,
path=str(self.zip_file_path))
if not os.path.exists(str(self.file_dir)):
archive = zipfile.ZipFile(str(self.zip_file_path))
archive.extractall('./Alchemy_data')
archive.close()
self._load()
def _load(self):
if self.mode == 'dev':
target_file = pathlib.Path(self.file_dir, "dev_target.csv")
self.target = pd.read_csv(target_file,
index_col=0,
usecols=[
'gdb_idx',
] +
['property_%d' % x for x in range(12)])
self.target = self.target[['property_%d' % x for x in range(12)]]
sdf_dir = pathlib.Path(self.file_dir, "sdf")
self.graphs, self.labels = [], []
for sdf_file in sdf_dir.glob("**/*.sdf"):
result = self.sdf_to_dgl(sdf_file)
if result is None:
continue
self.graphs.append(result[0])
self.labels.append(result[1])
self.normalize()
print(len(self.graphs), "loaded!")
def normalize(self, mean=None, std=None):
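        # Compute the per-target mean/std of the labels; the labels themselves
        # are left unchanged, and the models use these statistics to rescale
        # their predictions (see set_mean_std in sch.py / mgcn.py).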
labels = np.array([i.numpy() for i in self.labels])
if mean is None:
mean = np.mean(labels, axis=0)
if std is None:
std = np.std(labels, axis=0)
self.mean = mean
self.std = std
def __len__(self):
return len(self.graphs)
def __getitem__(self, idx):
g, l = self.graphs[idx], self.labels[idx]
if self.transform:
g = self.transform(g)
return g, l
if __name__ == '__main__':
alchemy_dataset = TencentAlchemyDataset()
device = torch.device('cpu')
# To speed up the training with multi-process data loader,
# the num_workers could be set to > 1.
alchemy_loader = DataLoader(dataset=alchemy_dataset,
batch_size=20,
collate_fn=batcher(),
shuffle=False,
num_workers=0)
for step, batch in enumerate(alchemy_loader):
print("bs =", batch.graph.batch_size)
print('feature size =', batch.graph.ndata['n_feat'].size())
print('pos size =', batch.graph.ndata['pos'].size())
print('edge feature size =', batch.graph.edata['e_feat'].size())
print('edge distance size =', batch.graph.edata['distance'].size())
print('label size=', batch.label.size())
print(dgl.sum_nodes(batch.graph, 'n_feat').size())
break
SchNet & MGCN
============
- K. T. Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko, K.-R. Müller.
SchNet: A continuous-filter convolutional neural network for modeling quantum interactions. Advances in Neural Information Processing Systems 30, pp. 992-1002 (2017) [link](http://papers.nips.cc/paper/6700-schnet-a-continuous-filter-convolutional-neural-network-for-modeling-quantum-interactions)
- C. Lu, Q. Liu, C. Wang, Z. Huang, P. Lin, L. He, Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective. The 33rd AAAI Conference on Artificial Intelligence (2019) [link](https://arxiv.org/abs/1906.11081)
Dependencies
------------
- PyTorch 1.0+
- dgl 0.3+
- RDKit (required if using the [Alchemy dataset](https://arxiv.org/abs/1906.09427))
Usage
-----
Example usage on Alchemy dataset:
+ SchNet: expected MAE 0.065
```bash
python train.py --model sch --epochs 250
```
+ MGCN: expected MAE 0.050
```bash
python train.py --model mgcn --epochs 250
```
*On a Tesla V100, SchNet takes 80s/epoch and MGCN takes 110s/epoch.*
Codes
-----
The folder contains five Python files:
- `sch.py` the implementation of SchNet model.
- `mgcn.py` the implementation of the Multilevel Graph Convolutional Network (MGCN).
- `layers.py` layers contained by two models above.
- `Alchemy_dataset.py` example dataloader of [Tencent Alchemy](https://alchemy.tencent.com) dataset.
- `train.py` example training code.
Modify `train.py` to switch between different implementations.
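For programmatic use, here is a minimal sketch (not part of this PR) that wires the dataset, the batcher, and SchNet together for a single forward pass on CPU; it assumes the files above are on the Python path and that the Alchemy dev set can be downloaded:
```py
from torch.utils.data import DataLoader

from Alchemy_dataset import TencentAlchemyDataset, batcher
from sch import SchNetModel

dataset = TencentAlchemyDataset(mode='dev')
loader = DataLoader(dataset, batch_size=20, collate_fn=batcher())

model = SchNetModel(norm=True, output_dim=12)
# De-normalize predictions with the label statistics computed by the dataset.
model.set_mean_std(dataset.mean, dataset.std, device="cpu")

batch = next(iter(loader))
pred = model(batch.graph)
print(pred.shape, batch.label.shape)  # both (20, 12)
```
Swap `SchNetModel` for `MGCNModel` (from `mgcn.py`) to try the other model.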
# -*- coding: utf-8 -*-
import torch as th
import numpy as np
import torch.nn as nn
import dgl.function as fn
from torch.nn import Softplus
class AtomEmbedding(nn.Module):
"""
Convert the atom(node) list to atom embeddings.
The atom with the same element share the same initial embeddding.
"""
def __init__(self, dim=128, type_num=100, pre_train=None):
"""
Randomly init the element embeddings.
Args:
dim: the dim of embeddings
type_num: the largest atomic number of atoms in the dataset
pre_train: the pre_trained embeddings
"""
super().__init__()
self._dim = dim
self._type_num = type_num
if pre_train is not None:
self.embedding = nn.Embedding.from_pretrained(pre_train,
padding_idx=0)
else:
self.embedding = nn.Embedding(type_num, dim, padding_idx=0)
def forward(self, g, p_name="node"):
"""Input type is dgl graph"""
atom_list = g.ndata["node_type"]
g.ndata[p_name] = self.embedding(atom_list)
return g.ndata[p_name]
class EdgeEmbedding(nn.Module):
"""
Convert the edge to embedding.
The edge links same pair of atoms share the same initial embedding.
"""
def __init__(self, dim=128, edge_num=3000, pre_train=None):
"""
Randomly init the edge embeddings.
Args:
dim: the dim of embeddings
edge_num: the maximum type of edges
pre_train: the pre_trained embeddings
"""
super().__init__()
self._dim = dim
self._edge_num = edge_num
if pre_train is not None:
self.embedding = nn.Embedding.from_pretrained(pre_train,
padding_idx=0)
else:
self.embedding = nn.Embedding(edge_num, dim, padding_idx=0)
def generate_edge_type(self, edges):
"""
Generate the edge type based on the src&dst atom type of the edge.
Note that C-O and O-C are the same edge type.
To map a pair of nodes to one number, we use an unordered pairing function here
See more detail in this disscussion:
https://math.stackexchange.com/questions/23503/create-unique-number-from-2-numbers
Note that, the edge_num should larger than the square of maximum atomic number
in the dataset.
"""
atom_type_x = edges.src["node_type"]
atom_type_y = edges.dst["node_type"]
        return {
            "type":
            atom_type_x * atom_type_y +
            # integer division keeps the type a LongTensor for the embedding
            (th.abs(atom_type_x - atom_type_y) - 1)**2 // 4
        }
def forward(self, g, p_name="edge_f"):
g.apply_edges(self.generate_edge_type)
g.edata[p_name] = self.embedding(g.edata["type"])
return g.edata[p_name]
class ShiftSoftplus(Softplus):
"""
Shiftsoft plus activation function:
1/beta * (log(1 + exp**(beta * x)) - log(shift))
"""
def __init__(self, beta=1, shift=2, threshold=20):
super().__init__(beta, threshold)
self.shift = shift
self.softplus = Softplus(beta, threshold)
def forward(self, input):
return self.softplus(input) - np.log(float(self.shift))
class RBFLayer(nn.Module):
"""
Radial basis functions Layer.
e(d) = exp(- gamma * ||d - mu_k||^2)
default settings:
gamma = 10
0 <= mu_k <= 30 for k=1~300
"""
def __init__(self, low=0, high=30, gap=0.1, dim=1):
super().__init__()
self._low = low
self._high = high
self._gap = gap
self._dim = dim
self._n_centers = int(np.ceil((high - low) / gap))
centers = np.linspace(low, high, self._n_centers)
self.centers = th.tensor(centers, dtype=th.float, requires_grad=False)
self.centers = nn.Parameter(self.centers, requires_grad=False)
self._fan_out = self._dim * self._n_centers
self._gap = centers[1] - centers[0]
def dis2rbf(self, edges):
dist = edges.data["distance"]
radial = dist - self.centers
coef = -1 / self._gap
rbf = th.exp(coef * (radial**2))
return {"rbf": rbf}
def forward(self, g):
"""Convert distance scalar to rbf vector"""
g.apply_edges(self.dis2rbf)
return g.edata["rbf"]
class CFConv(nn.Module):
"""
The continuous-filter convolution layer in SchNet.
One CFConv contains one rbf layer and three linear layer
(two of them have activation funct).
"""
def __init__(self, rbf_dim, dim=64, act="sp"):
"""
Args:
rbf_dim: the dimsion of the RBF layer
dim: the dimension of linear layers
act: activation function (default shifted softplus)
"""
super().__init__()
self._rbf_dim = rbf_dim
self._dim = dim
self.linear_layer1 = nn.Linear(self._rbf_dim, self._dim)
self.linear_layer2 = nn.Linear(self._dim, self._dim)
if act == "sp":
self.activation = nn.Softplus(beta=0.5, threshold=14)
else:
self.activation = act
def update_edge(self, edges):
rbf = edges.data["rbf"]
h = self.linear_layer1(rbf)
h = self.activation(h)
h = self.linear_layer2(h)
return {"h": h}
def forward(self, g):
g.apply_edges(self.update_edge)
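        # Continuous-filter convolution: scale each source node's features by
        # the distance-generated filter 'h', then sum over incoming edges.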
g.update_all(message_func=fn.u_mul_e('new_node', 'h', 'neighbor_info'),
reduce_func=fn.sum('neighbor_info', 'new_node'))
return g.ndata["new_node"]
class Interaction(nn.Module):
"""
The interaction layer in the SchNet model.
"""
def __init__(self, rbf_dim, dim):
super().__init__()
self._node_dim = dim
self.activation = nn.Softplus(beta=0.5, threshold=14)
self.node_layer1 = nn.Linear(dim, dim, bias=False)
self.cfconv = CFConv(rbf_dim, dim, act=self.activation)
self.node_layer2 = nn.Linear(dim, dim)
self.node_layer3 = nn.Linear(dim, dim)
def forward(self, g):
g.ndata["new_node"] = self.node_layer1(g.ndata["node"])
cf_node = self.cfconv(g)
cf_node_1 = self.node_layer2(cf_node)
cf_node_1a = self.activation(cf_node_1)
new_node = self.node_layer3(cf_node_1a)
g.ndata["node"] = g.ndata["node"] + new_node
return g.ndata["node"]
class VEConv(nn.Module):
"""
The Vertex-Edge convolution layer in MGCN which take edge & vertex features
in consideratoin at the same time.
"""
def __init__(self, rbf_dim, dim=64, update_edge=True):
"""
Args:
rbf_dim: the dimension of the RBF layer
dim: the dimension of linear layers
update_edge: whether update the edge emebedding in each conv-layer
"""
super().__init__()
self._rbf_dim = rbf_dim
self._dim = dim
self._update_edge = update_edge
self.linear_layer1 = nn.Linear(self._rbf_dim, self._dim)
self.linear_layer2 = nn.Linear(self._dim, self._dim)
self.linear_layer3 = nn.Linear(self._dim, self._dim)
self.activation = nn.Softplus(beta=0.5, threshold=14)
def update_rbf(self, edges):
rbf = edges.data["rbf"]
h = self.linear_layer1(rbf)
h = self.activation(h)
h = self.linear_layer2(h)
return {"h": h}
def update_edge(self, edges):
edge_f = edges.data["edge_f"]
h = self.linear_layer3(edge_f)
return {"edge_f": h}
def forward(self, g):
g.apply_edges(self.update_rbf)
if self._update_edge:
g.apply_edges(self.update_edge)
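        # Aggregate two messages per edge: node features filtered by the
        # RBF-generated weights, plus the (optionally updated) edge embedding.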
g.update_all(
message_func=[
fn.u_mul_e("new_node", "h", "m_0"),
fn.copy_e("edge_f", "m_1")],
reduce_func=[
fn.sum("m_0", "new_node_0"),
fn.sum("m_1", "new_node_1")])
g.ndata["new_node"] = g.ndata.pop("new_node_0") + g.ndata.pop(
"new_node_1")
return g.ndata["new_node"]
class MultiLevelInteraction(nn.Module):
"""
The multilevel interaction in the MGCN model.
"""
def __init__(self, rbf_dim, dim):
super().__init__()
self._atom_dim = dim
self.activation = nn.Softplus(beta=0.5, threshold=14)
self.node_layer1 = nn.Linear(dim, dim, bias=True)
self.edge_layer1 = nn.Linear(dim, dim, bias=True)
self.conv_layer = VEConv(rbf_dim, dim)
self.node_layer2 = nn.Linear(dim, dim)
self.node_layer3 = nn.Linear(dim, dim)
def forward(self, g, level=1):
g.ndata["new_node"] = self.node_layer1(g.ndata["node_%s" %
(level - 1)])
node = self.conv_layer(g)
g.edata["edge_f"] = self.activation(self.edge_layer1(
g.edata["edge_f"]))
node_1 = self.node_layer2(node)
node_1a = self.activation(node_1)
new_node = self.node_layer3(node_1a)
g.ndata["node_%s" % (level)] = g.ndata["node_%s" %
(level - 1)] + new_node
return g.ndata["node_%s" % (level)]
# -*- coding:utf-8 -*-
import dgl
import torch as th
import torch.nn as nn
from layers import AtomEmbedding, RBFLayer, EdgeEmbedding, \
MultiLevelInteraction
class MGCNModel(nn.Module):
"""
MGCN Model from:
Chengqiang Lu, et al.
Molecular Property Prediction: A Multilevel
Quantum Interactions Modeling Perspective. (AAAI'2019)
"""
def __init__(self,
dim=128,
output_dim=1,
edge_dim=128,
cutoff=5.0,
width=1,
n_conv=3,
norm=False,
atom_ref=None,
pre_train=None):
"""
Args:
dim: dimension of feature maps
out_put_dim: the num of target propperties to predict
edge_dim: dimension of edge feature
cutoff: the maximum distance between nodes
width: width in the RBF layer
n_conv: number of convolutional layers
norm: normalization
atom_ref: atom reference
used as the initial value of atom embeddings,
or set to None with random initialization
pre_train: pre_trained node embeddings
"""
super().__init__()
self.name = "MGCN"
self._dim = dim
self.output_dim = output_dim
self.edge_dim = edge_dim
self.cutoff = cutoff
self.width = width
self.n_conv = n_conv
self.atom_ref = atom_ref
self.norm = norm
self.activation = nn.Softplus(beta=1, threshold=20)
if atom_ref is not None:
self.e0 = AtomEmbedding(1, pre_train=atom_ref)
if pre_train is None:
self.embedding_layer = AtomEmbedding(dim)
else:
self.embedding_layer = AtomEmbedding(pre_train=pre_train)
self.edge_embedding_layer = EdgeEmbedding(dim=edge_dim)
self.rbf_layer = RBFLayer(0, cutoff, width)
self.conv_layers = nn.ModuleList([
MultiLevelInteraction(self.rbf_layer._fan_out, dim)
for i in range(n_conv)
])
self.node_dense_layer1 = nn.Linear(dim * (self.n_conv + 1), 64)
self.node_dense_layer2 = nn.Linear(64, output_dim)
def set_mean_std(self, mean, std, device):
self.mean_per_node = th.tensor(mean, device=device)
self.std_per_node = th.tensor(std, device=device)
def forward(self, g):
self.embedding_layer(g, "node_0")
if self.atom_ref is not None:
self.e0(g, "e0")
self.rbf_layer(g)
self.edge_embedding_layer(g)
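        # Each conv layer produces a new representation level (node_1 ...
        # node_n_conv); all levels, including the initial node_0, are
        # concatenated below.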
for idx in range(self.n_conv):
self.conv_layers[idx](g, idx + 1)
node_embeddings = tuple(g.ndata["node_%d" % (i)]
for i in range(self.n_conv + 1))
g.ndata["node"] = th.cat(node_embeddings, 1)
# concat multilevel representations
node = self.node_dense_layer1(g.ndata["node"])
node = self.activation(node)
res = self.node_dense_layer2(node)
g.ndata["res"] = res
if self.atom_ref is not None:
g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]
if self.norm:
g.ndata["res"] = g.ndata[
"res"] * self.std_per_node + self.mean_per_node
res = dgl.sum_nodes(g, "res")
return res
if __name__ == "__main__":
g = dgl.DGLGraph()
g.add_nodes(2)
g.add_edges([0, 0, 1, 1], [1, 0, 1, 0])
g.edata["distance"] = th.tensor([1.0, 3.0, 2.0, 4.0]).reshape(-1, 1)
g.ndata["node_type"] = th.LongTensor([1, 2])
model = MGCNModel(dim=2, edge_dim=2)
node = model(g)
print(node)
# -*- coding:utf-8 -*-
import dgl
import torch as th
import torch.nn as nn
from layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
class SchNetModel(nn.Module):
"""
SchNet Model from:
Schütt, Kristof, et al.
SchNet: A continuous-filter convolutional neural network
for modeling quantum interactions. (NIPS'2017)
"""
def __init__(self,
dim=64,
cutoff=5.0,
output_dim=1,
width=1,
n_conv=3,
norm=False,
atom_ref=None,
pre_train=None):
"""
Args:
dim: dimension of features
output_dim: dimension of prediction
cutoff: radius cutoff
width: width in the RBF function
n_conv: number of interaction layers
atom_ref: used as the initial value of atom embeddings,
or set to None with random initialization
norm: normalization
"""
super().__init__()
self.name = "SchNet"
self._dim = dim
self.cutoff = cutoff
self.width = width
self.n_conv = n_conv
self.atom_ref = atom_ref
self.norm = norm
self.activation = ShiftSoftplus()
if atom_ref is not None:
self.e0 = AtomEmbedding(1, pre_train=atom_ref)
if pre_train is None:
self.embedding_layer = AtomEmbedding(dim)
else:
self.embedding_layer = AtomEmbedding(pre_train=pre_train)
self.rbf_layer = RBFLayer(0, cutoff, width)
self.conv_layers = nn.ModuleList(
[Interaction(self.rbf_layer._fan_out, dim) for i in range(n_conv)])
self.atom_dense_layer1 = nn.Linear(dim, 64)
self.atom_dense_layer2 = nn.Linear(64, output_dim)
def set_mean_std(self, mean, std, device="cpu"):
self.mean_per_atom = th.tensor(mean, device=device)
self.std_per_atom = th.tensor(std, device=device)
def forward(self, g):
"""g is the DGL.graph"""
self.embedding_layer(g)
if self.atom_ref is not None:
self.e0(g, "e0")
self.rbf_layer(g)
for idx in range(self.n_conv):
self.conv_layers[idx](g)
atom = self.atom_dense_layer1(g.ndata["node"])
atom = self.activation(atom)
res = self.atom_dense_layer2(atom)
g.ndata["res"] = res
if self.atom_ref is not None:
g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]
if self.norm:
g.ndata["res"] = g.ndata[
"res"] * self.std_per_atom + self.mean_per_atom
res = dgl.sum_nodes(g, "res")
return res
if __name__ == "__main__":
g = dgl.DGLGraph()
g.add_nodes(2)
g.add_edges([0, 0, 1, 1], [1, 0, 1, 0])
g.edata["distance"] = th.tensor([1.0, 3.0, 2.0, 4.0]).reshape(-1, 1)
g.ndata["node_type"] = th.LongTensor([1, 2])
model = SchNetModel(dim=2)
    atom = model(g)
    print(atom)
# -*- coding:utf-8 -*-
"""Sample training code
"""
import argparse
import torch as th
import torch.nn as nn
from sch import SchNetModel
from mgcn import MGCNModel
from torch.utils.data import DataLoader
from Alchemy_dataset import TencentAlchemyDataset, batcher
def train(model="sch", epochs=80, device=th.device("cpu")):
print("start")
alchemy_dataset = TencentAlchemyDataset()
alchemy_loader = DataLoader(dataset=alchemy_dataset,
batch_size=20,
collate_fn=batcher(),
shuffle=False,
num_workers=0)
if model == "sch":
model = SchNetModel(norm=True, output_dim=12)
elif model == "mgcn":
model = MGCNModel(norm=True, output_dim=12)
model.set_mean_std(alchemy_dataset.mean, alchemy_dataset.std, device)
model.to(device)
loss_fn = nn.MSELoss()
MAE_fn = nn.L1Loss()
optimizer = th.optim.Adam(model.parameters(), lr=0.0001)
for epoch in range(epochs):
w_loss, w_mae = 0, 0
model.train()
for idx, batch in enumerate(alchemy_loader):
batch.graph.to(device)
batch.label = batch.label.to(device)
res = model(batch.graph)
loss = loss_fn(res, batch.label)
mae = MAE_fn(res, batch.label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
w_mae += mae.detach().item()
w_loss += loss.detach().item()
        w_mae /= idx + 1
        w_loss /= idx + 1
        print("Epoch {:2d}, loss: {:.7f}, mae: {:.7f}".format(
            epoch, w_loss, w_mae))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-M",
"--model",
help="model name (sch or mgcn)",
default="sch")
parser.add_argument("--epochs", help="number of epochs", default=250)
device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
args = parser.parse_args()
train(args.model, int(args.epochs), device)