".github/vscode:/vscode.git/clone" did not exist on "a0549fea4469251a8021d20d6550cf061a1cdb84"
Unverified commit 165c67cc authored by Mufei Li, committed by GitHub

[DGL-LifeSci] Pre-trained GIN (#1558)

* Update

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
parent 3e696922
......@@ -44,6 +44,11 @@ Weave
.. automodule:: dgllife.model.gnn.weave
:members:
GIN
---
.. automodule:: dgllife.model.gnn.gin
:members:
WLN
---
.. automodule:: dgllife.model.gnn.wln
......
......@@ -54,6 +54,11 @@ Weave Predictor
.. automodule:: dgllife.model.model_zoo.weave_predictor
:members:
GIN Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.gin_predictor
:members:
Generative Models
-----------------
......
......@@ -114,6 +114,7 @@ For using featurization methods like above in creating node features:
dgllife.utils.BaseAtomFeaturizer.feat_size
dgllife.utils.CanonicalAtomFeaturizer
dgllife.utils.CanonicalAtomFeaturizer.feat_size
dgllife.utils.PretrainAtomFeaturizer
Featurization for Edges
```````````````````````
......@@ -134,6 +135,7 @@ We consider the following bond descriptors:
dgllife.utils.bond_is_in_ring_one_hot
dgllife.utils.bond_is_in_ring
dgllife.utils.bond_stereo_one_hot
dgllife.utils.bond_direction_one_hot
For using featurization methods like above in creating edge features:
......@@ -144,3 +146,4 @@ For using featurization methods like above in creating edge features:
dgllife.utils.BaseBondFeaturizer.feat_size
dgllife.utils.CanonicalBondFeaturizer
dgllife.utils.CanonicalBondFeaturizer.feat_size
dgllife.utils.PretrainBondFeaturizer
......@@ -7,3 +7,4 @@ from .mpnn import *
from .schnet import *
from .wln import *
from .weave import *
from .gin import *
"""Graph Isomorphism Networks."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['GIN']
# pylint: disable=W0221, C0103
class GINLayer(nn.Module):
r"""Single Layer GIN from `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
Parameters
----------
num_edge_emb_list : list of int
num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g. num_edge_emb_list[0] can be
the number of bond types and num_edge_emb_list[1] can be the number of
bond direction types.
emb_dim : int
The size of each embedding vector.
batch_norm : bool
Whether to apply batch normalization to the output of message passing.
Default to True.
activation : None or callable
Activation function to apply to the output node representations.
Default to None.
"""
def __init__(self, num_edge_emb_list, emb_dim, batch_norm=True, activation=None):
super(GINLayer, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(emb_dim, 2 * emb_dim),
nn.ReLU(),
nn.Linear(2 * emb_dim, emb_dim)
)
self.edge_embeddings = nn.ModuleList()
for num_emb in num_edge_emb_list:
emb_module = nn.Embedding(num_emb, emb_dim)
nn.init.xavier_uniform_(emb_module.weight.data)
self.edge_embeddings.append(emb_module)
if batch_norm:
self.bn = nn.BatchNorm1d(emb_dim)
else:
self.bn = None
self.activation = activation
def forward(self, g, node_feats, categorical_edge_feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : FloatTensor of shape (N, emb_dim)
* Input node features
* N is the total number of nodes in the batch of graphs
* emb_dim is the input node feature size, which must match emb_dim in initialization
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
* len(categorical_edge_feats) should be the same as len(self.edge_embeddings)
* E is the total number of edges in the batch of graphs
Returns
-------
node_feats : float32 tensor of shape (N, emb_dim)
Output node representations
"""
edge_embeds = []
for i, feats in enumerate(categorical_edge_feats):
edge_embeds.append(self.edge_embeddings[i](feats))
edge_embeds = torch.stack(edge_embeds, dim=0).sum(0)
g = g.local_var()
g.ndata['feat'] = node_feats
g.edata['feat'] = edge_embeds
g.update_all(fn.u_add_e('feat', 'feat', 'm'), fn.sum('m', 'feat'))
node_feats = self.mlp(g.ndata.pop('feat'))
if self.bn is not None:
node_feats = self.bn(node_feats)
if self.activation is not None:
node_feats = self.activation(node_feats)
return node_feats
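# A minimal usage sketch for GINLayer (illustrative only, not part of the
# module; the toy graph and feature sizes are hypothetical, and the DGLGraph
# construction follows the style used in this repo's tests):
#
#     import dgl
#     import torch
#     g = dgl.DGLGraph([(0, 1), (0, 2), (1, 2)])      # 3 nodes, 3 edges
#     layer = GINLayer(num_edge_emb_list=[3, 4], emb_dim=16)
#     node_feats = torch.randn(3, 16)                 # (N, emb_dim)
#     edge_feats = [torch.LongTensor([0, 0, 1]),      # one LongTensor per
#                   torch.LongTensor([2, 3, 2])]      # categorical edge feature
#     out = layer(g, node_feats, edge_feats)          # shape (3, 16)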
class GIN(nn.Module):
r"""Graph Isomorphism Network from `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
This module is for updating node representations only.
Parameters
----------
num_node_emb_list : list of int
num_node_emb_list[i] gives the number of items to embed for the
        i-th categorical node feature variable. E.g. num_node_emb_list[0] can be
the number of atom types and num_node_emb_list[1] can be the number of
atom chirality types.
num_edge_emb_list : list of int
num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g. num_edge_emb_list[0] can be
the number of bond types and num_edge_emb_list[1] can be the number of
bond direction types.
num_layers : int
Number of GIN layers to use. Default to 5.
emb_dim : int
The size of each embedding vector. Default to 300.
JK : str
JK for jumping knowledge as in `Representation Learning on Graphs with
Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
        how to combine the node representations across all layers for the final output.
        This argument can be one of ``'concat'``, ``'last'``, ``'max'`` and ``'sum'``.
Default to 'last'.
* ``'concat'``: concatenate the output node representations from all GIN layers
* ``'last'``: use the node representations from the last GIN layer
* ``'max'``: apply max pooling to the node representations across all GIN layers
* ``'sum'``: sum the output node representations from all GIN layers
dropout : float
        Dropout to apply to the output of each GIN layer. Default to 0.5.
"""
def __init__(self, num_node_emb_list, num_edge_emb_list,
num_layers=5, emb_dim=300, JK='last', dropout=0.5):
super(GIN, self).__init__()
self.num_layers = num_layers
self.JK = JK
self.dropout = nn.Dropout(dropout)
if num_layers < 2:
raise ValueError('Number of GNN layers must be greater '
'than 1, got {:d}'.format(num_layers))
self.node_embeddings = nn.ModuleList()
for num_emb in num_node_emb_list:
emb_module = nn.Embedding(num_emb, emb_dim)
nn.init.xavier_uniform_(emb_module.weight.data)
self.node_embeddings.append(emb_module)
self.gnn_layers = nn.ModuleList()
for layer in range(num_layers):
if layer == num_layers - 1:
self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim))
else:
self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim, activation=F.relu))
def forward(self, g, categorical_node_feats, categorical_edge_feats):
"""Update node representations
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
categorical_node_feats : list of LongTensor of shape (N)
* Input categorical node features
* len(categorical_node_feats) should be the same as len(self.node_embeddings)
* N is the total number of nodes in the batch of graphs
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
* len(categorical_edge_feats) should be the same as
len(num_edge_emb_list) in the arguments
* E is the total number of edges in the batch of graphs
Returns
-------
final_node_feats : float32 tensor of shape (N, M)
            Output node representations, where N is the number of nodes and
            M is the output size. In particular, M will be emb_dim * (num_layers + 1)
if self.JK == 'concat' and emb_dim otherwise.
"""
node_embeds = []
for i, feats in enumerate(categorical_node_feats):
node_embeds.append(self.node_embeddings[i](feats))
node_embeds = torch.stack(node_embeds, dim=0).sum(0)
all_layer_node_feats = [node_embeds]
for layer in range(self.num_layers):
node_feats = self.gnn_layers[layer](g, all_layer_node_feats[layer],
categorical_edge_feats)
node_feats = self.dropout(node_feats)
all_layer_node_feats.append(node_feats)
if self.JK == 'concat':
final_node_feats = torch.cat(all_layer_node_feats, dim=1)
elif self.JK == 'last':
final_node_feats = all_layer_node_feats[-1]
elif self.JK == 'max':
all_layer_node_feats = [h.unsqueeze_(0) for h in all_layer_node_feats]
final_node_feats = torch.max(torch.cat(all_layer_node_feats, dim=0), dim=0)[0]
elif self.JK == 'sum':
all_layer_node_feats = [h.unsqueeze_(0) for h in all_layer_node_feats]
final_node_feats = torch.sum(torch.cat(all_layer_node_feats, dim=0), dim=0)
else:
            raise ValueError("Expect self.JK to be 'concat', 'last', "
                             "'max' or 'sum', got {}".format(self.JK))
return final_node_feats
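# A minimal usage sketch for GIN (illustrative only; the sizes mirror the
# unit tests added in this commit):
#
#     import dgl
#     import torch
#     g = dgl.DGLGraph([(0, 1), (0, 2), (1, 2)])
#     gnn = GIN(num_node_emb_list=[3, 5], num_edge_emb_list=[3, 4],
#               num_layers=2, emb_dim=10, JK='last', dropout=0.1)
#     node_feats = [torch.LongTensor([0, 1, 0]), torch.LongTensor([2, 3, 4])]
#     edge_feats = [torch.LongTensor([0, 0, 1]), torch.LongTensor([2, 3, 2])]
#     out = gnn(g, node_feats, edge_feats)            # shape (3, 10)
#     # With JK='concat', the output size becomes emb_dim * (num_layers + 1).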
......@@ -12,3 +12,4 @@ from .acnn import *
from .wln_reaction_center import *
from .wln_reaction_ranking import *
from .weave_predictor import *
from .gin_predictor import *
"""GIN-based model for regression and classification on graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch.nn as nn
from dgl.nn.pytorch.glob import GlobalAttentionPooling, SumPooling, AvgPooling, MaxPooling
from ..gnn.gin import GIN
__all__ = ['GINPredictor']
# pylint: disable=W0221
class GINPredictor(nn.Module):
"""GIN-based model for regression and classification on graphs.
GIN was first introduced in `How Powerful Are Graph Neural Networks
<https://arxiv.org/abs/1810.00826>`__ for general graph property
prediction problems. It was further extended in `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
for pre-training and semi-supervised learning on large-scale datasets.
For classification tasks, the output will be logits, i.e. values before
sigmoid or softmax.
Parameters
----------
num_node_emb_list : list of int
num_node_emb_list[i] gives the number of items to embed for the
        i-th categorical node feature variable. E.g. num_node_emb_list[0] can be
the number of atom types and num_node_emb_list[1] can be the number of
atom chirality types.
num_edge_emb_list : list of int
num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g. num_edge_emb_list[0] can be
the number of bond types and num_edge_emb_list[1] can be the number of
bond direction types.
num_layers : int
Number of GIN layers to use. Default to 5.
emb_dim : int
The size of each embedding vector. Default to 300.
JK : str
JK for jumping knowledge as in `Representation Learning on Graphs with
Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
        how to combine the node representations across all layers for the final output.
        This argument can be one of ``'concat'``, ``'last'``, ``'max'`` and
``'sum'``. Default to 'last'.
* ``'concat'``: concatenate the output node representations from all GIN layers
* ``'last'``: use the node representations from the last GIN layer
* ``'max'``: apply max pooling to the node representations across all GIN layers
* ``'sum'``: sum the output node representations from all GIN layers
dropout : float
Dropout to apply to the output of each GIN layer. Default to 0.5.
readout : str
Readout for computing graph representations out of node representations, which
can be ``'sum'``, ``'mean'``, ``'max'``, or ``'attention'``. Default to 'mean'.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
"""
def __init__(self, num_node_emb_list, num_edge_emb_list, num_layers=5,
emb_dim=300, JK='last', dropout=0.5, readout='mean', n_tasks=1):
super(GINPredictor, self).__init__()
if num_layers < 2:
raise ValueError('Number of GNN layers must be greater '
'than 1, got {:d}'.format(num_layers))
self.gnn = GIN(num_node_emb_list=num_node_emb_list,
num_edge_emb_list=num_edge_emb_list,
num_layers=num_layers,
emb_dim=emb_dim,
JK=JK,
dropout=dropout)
if readout == 'sum':
self.readout = SumPooling()
elif readout == 'mean':
self.readout = AvgPooling()
elif readout == 'max':
self.readout = MaxPooling()
elif readout == 'attention':
if JK == 'concat':
self.readout = GlobalAttentionPooling(
gate_nn=nn.Linear((num_layers + 1) * emb_dim, 1))
else:
self.readout = GlobalAttentionPooling(
gate_nn=nn.Linear(emb_dim, 1))
else:
raise ValueError("Expect readout to be 'sum', 'mean', "
"'max' or 'attention', got {}".format(readout))
if JK == 'concat':
self.predict = nn.Linear((num_layers + 1) * emb_dim, n_tasks)
else:
self.predict = nn.Linear(emb_dim, n_tasks)
def forward(self, g, categorical_node_feats, categorical_edge_feats):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
categorical_node_feats : list of LongTensor of shape (N)
* Input categorical node features
* len(categorical_node_feats) should be the same as len(num_node_emb_list)
* N is the total number of nodes in the batch of graphs
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
* len(categorical_edge_feats) should be the same as
len(num_edge_emb_list) in the arguments
* E is the total number of edges in the batch of graphs
Returns
-------
FloatTensor of shape (B, n_tasks)
* Predictions on graphs
* B for the number of graphs in the batch
"""
node_feats = self.gnn(g, categorical_node_feats, categorical_edge_feats)
graph_feats = self.readout(g, node_feats)
return self.predict(graph_feats)
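# A minimal usage sketch for GINPredictor (illustrative only; sizes mirror
# the unit tests added in this commit):
#
#     model = GINPredictor(num_node_emb_list=[3, 5], num_edge_emb_list=[3, 4],
#                          num_layers=2, emb_dim=10, JK='last',
#                          readout='mean', n_tasks=2)
#     # g, node_feats and edge_feats constructed as in the GIN sketch above
#     preds = model(g, node_feats, edge_feats)        # shape (1, 2) for one graph
#     # For classification, apply sigmoid/softmax to these logits yourself.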
......@@ -8,7 +8,7 @@ from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_arc
from rdkit import Chem
from ..model import GCNPredictor, GATPredictor, AttentiveFPPredictor, DGMG, DGLJTNNVAE, \
WLNReactionCenter, WLNReactionRanking, WeavePredictor
WLNReactionCenter, WLNReactionRanking, WeavePredictor, GIN
__all__ = ['load_pretrained']
......@@ -24,6 +24,10 @@ URL = {
'JTNN_ZINC': 'pre_trained/JTNN_ZINC.pth',
'wln_center_uspto': 'dgllife/pre_trained/wln_center_uspto_v3.pth',
'wln_rank_uspto': 'dgllife/pre_trained/wln_rank_uspto.pth',
'gin_supervised_contextpred': 'dgllife/pre_trained/gin_supervised_contextpred.pth',
'gin_supervised_infomax': 'dgllife/pre_trained/gin_supervised_infomax.pth',
'gin_supervised_edgepred': 'dgllife/pre_trained/gin_supervised_edgepred.pth',
'gin_supervised_masking': 'dgllife/pre_trained/gin_supervised_masking.pth'
}
def download_and_load_checkpoint(model_name, model, model_postfix,
......@@ -86,6 +90,14 @@ def load_pretrained(model_name, log=True):
* ``'JTNN_ZINC'``: A JTNN model pre-trained on ZINC for molecule generation
* ``'wln_center_uspto'``: A WLN model pre-trained on USPTO for reaction prediction
* ``'wln_rank_uspto'``: A WLN model pre-trained on USPTO for candidate product ranking
* ``'gin_supervised_contextpred'``: A GIN model pre-trained with supervised learning
and context prediction
* ``'gin_supervised_infomax'``: A GIN model pre-trained with supervised learning
and deep graph infomax
* ``'gin_supervised_edgepred'``: A GIN model pre-trained with supervised learning
and edge prediction
* ``'gin_supervised_masking'``: A GIN model pre-trained with supervised learning
and attribute masking
log : bool
Whether to print progress for model loading
......@@ -170,4 +182,13 @@ def load_pretrained(model_name, log=True):
node_hidden_feats=500,
num_encode_gnn_layers=3)
elif model_name in ['gin_supervised_contextpred', 'gin_supervised_infomax',
'gin_supervised_edgepred', 'gin_supervised_masking']:
model = GIN(num_node_emb_list=[120, 3],
num_edge_emb_list=[6, 3],
num_layers=5,
emb_dim=300,
JK='last',
dropout=0.5)
return download_and_load_checkpoint(model_name, model, URL[model_name], log=log)
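# A minimal sketch for loading one of the new GIN checkpoints (assuming
# load_pretrained is re-exported from dgllife.model, as for the other
# pre-trained models in this repo):
#
#     from dgllife.model import load_pretrained
#     model = load_pretrained('gin_supervised_contextpred')
#     model.eval()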
......@@ -41,15 +41,18 @@ __all__ = ['one_hot_encoding',
'BaseAtomFeaturizer',
'CanonicalAtomFeaturizer',
'WeaveAtomFeaturizer',
'PretrainAtomFeaturizer',
'bond_type_one_hot',
'bond_is_conjugated_one_hot',
'bond_is_conjugated',
'bond_is_in_ring_one_hot',
'bond_is_in_ring',
'bond_stereo_one_hot',
'bond_direction_one_hot',
'BaseBondFeaturizer',
'CanonicalBondFeaturizer',
'WeaveEdgeFeaturizer']
'WeaveEdgeFeaturizer',
'PretrainBondFeaturizer']
def one_hot_encoding(x, allowable_set, encode_unknown=False):
"""One-hot encoding.
......@@ -1069,6 +1072,70 @@ class WeaveAtomFeaturizer(object):
return {self._atom_data_field: F.zerocopy_from_numpy(atom_features.astype(np.float32))}
class PretrainAtomFeaturizer(object):
"""AtomFeaturizer in Strategies for Pre-training Graph Neural Networks.
The atom featurization performed in `Strategies for Pre-training Graph Neural Networks
<https://arxiv.org/abs/1905.12265>`__, which considers:
* atomic number
* chirality
Parameters
----------
atomic_number_types : list of int or None
        Atomic number types to consider for categorical (index) encoding, used as
        inputs to embedding layers. If None, we will use a default choice of 1-118.
chiral_types : list of Chem.rdchem.ChiralType or None
        Atom chirality types to consider for categorical encoding. If None, we will use a default
choice of ``Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER``.
"""
def __init__(self, atomic_number_types=None, chiral_types=None):
if atomic_number_types is None:
atomic_number_types = list(range(1, 119))
self._atomic_number_types = atomic_number_types
if chiral_types is None:
chiral_types = [
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
Chem.rdchem.ChiralType.CHI_OTHER
]
self._chiral_types = chiral_types
def __call__(self, mol):
"""Featurizes the input molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
dict
            Mapping 'atomic_number' and 'chirality_type' separately to an int64
            tensor of shape (N,), where N is the number of atoms
"""
atom_features = []
num_atoms = mol.GetNumAtoms()
for i in range(num_atoms):
atom = mol.GetAtomWithIdx(i)
atom_features.append([
self._atomic_number_types.index(atom.GetAtomicNum()),
self._chiral_types.index(atom.GetChiralTag())
])
atom_features = np.stack(atom_features)
atom_features = F.zerocopy_from_numpy(atom_features.astype(np.int64))
return {
'atomic_number': atom_features[:, 0],
'chirality_type': atom_features[:, 1]
}
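# A minimal usage sketch for PretrainAtomFeaturizer (illustrative only; ethanol
# is a hypothetical input, whose atomic numbers 6, 6, 8 map to indices 5, 5, 7
# in the default 1-118 range):
#
#     from rdkit import Chem
#     mol = Chem.MolFromSmiles('CCO')
#     feats = PretrainAtomFeaturizer()(mol)
#     # feats['atomic_number']  -> tensor([5, 5, 7])
#     # feats['chirality_type'] -> tensor([0, 0, 0]), all CHI_UNSPECIFIED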
def bond_type_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the type of a bond.
......@@ -1226,6 +1293,35 @@ def bond_stereo_one_hot(bond, allowable_set=None, encode_unknown=False):
Chem.rdchem.BondStereo.STEREOTRANS]
return one_hot_encoding(bond.GetStereo(), allowable_set, encode_unknown)
def bond_direction_one_hot(bond, allowable_set=None, encode_unknown=False):
"""One hot encoding for the direction of a bond.
Parameters
----------
bond : rdkit.Chem.rdchem.Bond
RDKit bond instance.
allowable_set : list of Chem.rdchem.BondDir
Bond directions to consider. Default: ``Chem.rdchem.BondDir.NONE``,
``Chem.rdchem.BondDir.ENDUPRIGHT``, ``Chem.rdchem.BondDir.ENDDOWNRIGHT``.
encode_unknown : bool
If True, map inputs not in the allowable set to the
additional last element. (Default: False)
Returns
-------
list
List of boolean values where at most one value is True.
See Also
--------
one_hot_encoding
"""
if allowable_set is None:
allowable_set = [Chem.rdchem.BondDir.NONE,
Chem.rdchem.BondDir.ENDUPRIGHT,
Chem.rdchem.BondDir.ENDDOWNRIGHT]
return one_hot_encoding(bond.GetBondDir(), allowable_set, encode_unknown)
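# Example (fumaric acid, the molecule exercised by the unit tests below):
#
#     from rdkit import Chem
#     mol = Chem.MolFromSmiles('O=C(O)/C=C/C(=O)O')
#     bond_direction_one_hot(mol.GetBondWithIdx(0))   # [1, 0, 0], BondDir.NONE
#     bond_direction_one_hot(mol.GetBondWithIdx(2))   # [0, 1, 0], ENDUPRIGHT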
class BaseBondFeaturizer(object):
"""An abstract class for bond featurizers.
Loop over all bonds in a molecule and featurize them with the ``featurizer_funcs``.
......@@ -1479,3 +1575,80 @@ class WeaveEdgeFeaturizer(object):
return {self._edge_data_field: torch.cat([distance_indicators,
bond_indicators,
ring_mate_indicators], dim=1)}
class PretrainBondFeaturizer(object):
"""BondFeaturizer in Strategies for Pre-training Graph Neural Networks.
The bond featurization performed in `Strategies for Pre-training Graph Neural Networks
<https://arxiv.org/abs/1905.12265>`__, which considers:
* bond type
* bond direction
Parameters
----------
bond_types : list of Chem.rdchem.BondType or None
Bond types to consider. Default to ``Chem.rdchem.BondType.SINGLE``,
``Chem.rdchem.BondType.DOUBLE``, ``Chem.rdchem.BondType.TRIPLE``,
``Chem.rdchem.BondType.AROMATIC``.
bond_direction_types : list of Chem.rdchem.BondDir or None
Bond directions to consider. Default to ``Chem.rdchem.BondDir.NONE``,
``Chem.rdchem.BondDir.ENDUPRIGHT``, ``Chem.rdchem.BondDir.ENDDOWNRIGHT``.
self_loop : bool
Whether self loops will be added. Default to True.
"""
def __init__(self, bond_types=None, bond_direction_types=None, self_loop=True):
if bond_types is None:
bond_types = [
Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC
]
self._bond_types = bond_types
if bond_direction_types is None:
bond_direction_types = [
Chem.rdchem.BondDir.NONE,
Chem.rdchem.BondDir.ENDUPRIGHT,
Chem.rdchem.BondDir.ENDDOWNRIGHT
]
self._bond_direction_types = bond_direction_types
self._self_loop = self_loop
def __call__(self, mol):
"""Featurizes the input molecule.
Parameters
----------
mol : rdkit.Chem.rdchem.Mol
RDKit molecule instance.
Returns
-------
dict
            Mapping 'bond_type' and 'bond_direction_type' separately to an int64
            tensor of shape (E,), where E is the number of edges, including self
            loops if added
"""
edge_features = []
num_bonds = mol.GetNumBonds()
# Compute features for each bond
for i in range(num_bonds):
bond = mol.GetBondWithIdx(i)
bond_feats = [
self._bond_types.index(bond.GetBondType()),
self._bond_direction_types.index(bond.GetBondDir())
]
edge_features.extend([bond_feats, bond_feats.copy()])
        if self._self_loop:
            self_loop_features = torch.zeros((mol.GetNumAtoms(), 2), dtype=torch.int64)
            # Mark self loops with an extra bond type category
            self_loop_features[:, 0] = len(self._bond_types)

        if num_bonds == 0:
            assert self._self_loop, \
                'The molecule has no bonds; set self_loop to True to get self-loop edges.'
            edge_features = self_loop_features
        else:
            edge_features = np.stack(edge_features)
            edge_features = F.zerocopy_from_numpy(edge_features.astype(np.int64))
            if self._self_loop:
                edge_features = torch.cat([edge_features, self_loop_features], dim=0)
return {'bond_type': edge_features[:, 0], 'bond_direction_type': edge_features[:, 1]}
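# A minimal usage sketch for PretrainBondFeaturizer (illustrative only; the
# molecule is the one from the unit tests below, with 7 bonds and 8 atoms):
#
#     from rdkit import Chem
#     mol = Chem.MolFromSmiles('O=C(O)/C=C/C(=O)O')
#     feats = PretrainBondFeaturizer()(mol)
#     # Each bond contributes two directed edges and, with self_loop=True,
#     # each atom adds a self loop, so both returned tensors have
#     # 2 * 7 + 8 = 22 entries; self loops get bond type len(bond_types) = 4.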
......@@ -43,6 +43,22 @@ def test_graph6():
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)
def test_graph7():
"""Graph with categorical node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g1, torch.LongTensor([0, 1, 0]), torch.LongTensor([2, 3, 4]), \
torch.LongTensor([0, 0, 1]), torch.LongTensor([2, 3, 2])
def test_graph8():
"""Batched graph with categorical node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 1, 0, 2, 2]), \
torch.LongTensor([2, 3, 4, 1, 0, 1, 2, 2]), \
torch.LongTensor([0, 0, 1, 2, 1, 0, 0]), \
torch.LongTensor([2, 3, 2, 0, 1, 2, 1])
def test_gcn():
if torch.cuda.is_available():
device = torch.device('cuda:0')
......@@ -261,6 +277,37 @@ def test_weave():
assert gnn(g, node_feats, edge_feats).shape == torch.Size([3, 2])
assert gnn(bg, batch_node_feats, batch_edge_feats).shape == torch.Size([8, 2])
def test_gin():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats1, node_feats2, edge_feats1, edge_feats2 = test_graph7()
node_feats1, node_feats2 = node_feats1.to(device), node_feats2.to(device)
edge_feats1, edge_feats2 = edge_feats1.to(device), edge_feats2.to(device)
bg, batch_node_feats1, batch_node_feats2, \
batch_edge_feats1, batch_edge_feats2 = test_graph8()
batch_node_feats1, batch_node_feats2 = batch_node_feats1.to(device), \
batch_node_feats2.to(device)
batch_edge_feats1, batch_edge_feats2 = batch_edge_feats1.to(device), \
batch_edge_feats2.to(device)
# Test default setting
gnn = GIN(num_node_emb_list=[3, 5], num_edge_emb_list=[3, 4]).to(device)
assert gnn(g, [node_feats1, node_feats2], [edge_feats1, edge_feats2]).shape \
== torch.Size([3, 300])
assert gnn(bg, [batch_node_feats1, batch_node_feats2],
[batch_edge_feats1, batch_edge_feats2]).shape == torch.Size([8, 300])
# Test configured setting
gnn = GIN(num_node_emb_list=[3, 5], num_edge_emb_list=[3, 4],
num_layers=2, emb_dim=10, JK='concat', dropout=0.1).to(device)
assert gnn(g, [node_feats1, node_feats2], [edge_feats1, edge_feats2]).shape \
== torch.Size([3, 30])
assert gnn(bg, [batch_node_feats1, batch_node_feats2],
[batch_edge_feats1, batch_edge_feats2]).shape == torch.Size([8, 30])
if __name__ == '__main__':
test_gcn()
test_gat()
......@@ -270,3 +317,4 @@ if __name__ == '__main__':
test_mpnn_gnn()
test_wln()
test_weave()
test_gin()
......@@ -45,6 +45,22 @@ def test_graph6():
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)
def test_graph7():
"""Graph with categorical node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
return g1, torch.LongTensor([0, 1, 0]), torch.LongTensor([2, 3, 4]), \
torch.LongTensor([0, 0, 1]), torch.LongTensor([2, 3, 2])
def test_graph8():
"""Batched graph with categorical node and edge features."""
g1 = DGLGraph([(0, 1), (0, 2), (1, 2)])
g2 = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
bg = dgl.batch([g1, g2])
return bg, torch.LongTensor([0, 1, 0, 2, 1, 0, 2, 2]), \
torch.LongTensor([2, 3, 4, 1, 0, 1, 2, 2]), \
torch.LongTensor([0, 0, 1, 2, 1, 0, 0]), \
torch.LongTensor([2, 3, 2, 0, 1, 2, 1])
def test_mlp_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
......@@ -263,6 +279,38 @@ def test_weave_predictor():
assert weave_predictor(bg, batch_node_feats, batch_edge_feats).shape == \
torch.Size([2, 2])
def test_gin_predictor():
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
g, node_feats1, node_feats2, edge_feats1, edge_feats2 = test_graph7()
node_feats1, node_feats2 = node_feats1.to(device), node_feats2.to(device)
edge_feats1, edge_feats2 = edge_feats1.to(device), edge_feats2.to(device)
bg, batch_node_feats1, batch_node_feats2, \
batch_edge_feats1, batch_edge_feats2 = test_graph8()
batch_node_feats1, batch_node_feats2 = batch_node_feats1.to(device), \
batch_node_feats2.to(device)
batch_edge_feats1, batch_edge_feats2 = batch_edge_feats1.to(device), \
batch_edge_feats2.to(device)
num_node_emb_list = [3, 5]
num_edge_emb_list = [3, 4]
for JK in ['concat', 'last', 'max', 'sum']:
for readout in ['sum', 'mean', 'max', 'attention']:
model = GINPredictor(num_node_emb_list=num_node_emb_list,
num_edge_emb_list=num_edge_emb_list,
num_layers=2,
emb_dim=10,
JK=JK,
readout=readout,
n_tasks=2).to(device)
assert model(g, [node_feats1, node_feats2], [edge_feats1, edge_feats2]).shape \
== torch.Size([1, 2])
assert model(bg, [batch_node_feats1, batch_node_feats2],
[batch_edge_feats1, batch_edge_feats2]).shape == torch.Size([2, 2])
if __name__ == '__main__':
test_mlp_predictor()
test_gcn_predictor()
......@@ -272,3 +320,4 @@ if __name__ == '__main__':
test_mgcn_predictor()
test_mpnn_predictor()
test_weave_predictor()
test_gin_predictor()
......@@ -19,6 +19,9 @@ def test_mol1():
def test_mol2():
return Chem.MolFromSmiles('C1=CC2=CC=CC=CC2=C1')
def test_mol3():
return Chem.MolFromSmiles('O=C(O)/C=C/C(=O)O')
def test_atom_type_one_hot():
mol = test_mol1()
assert atom_type_one_hot(mol.GetAtomWithIdx(0), ['C', 'O']) == [1, 0]
......@@ -226,6 +229,14 @@ def test_weave_atom_featurizer():
1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000]]), rtol=1e-3)
def test_pretrain_atom_featurizer():
featurizer = PretrainAtomFeaturizer()
mol = test_mol1()
feats = featurizer(mol)
assert list(feats.keys()) == ['atomic_number', 'chirality_type']
assert torch.allclose(feats['atomic_number'], torch.tensor([[5, 5, 7]]))
assert torch.allclose(feats['chirality_type'], torch.tensor([[0, 0, 0]]))
def test_bond_type_one_hot():
mol = test_mol1()
assert bond_type_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0, 0]
......@@ -260,6 +271,11 @@ def test_bond_stereo_one_hot():
mol = test_mol1()
assert bond_stereo_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0, 0, 0, 0]
def test_bond_direction_one_hot():
mol = test_mol3()
assert bond_direction_one_hot(mol.GetBondWithIdx(0)) == [1, 0, 0]
assert bond_direction_one_hot(mol.GetBondWithIdx(2)) == [0, 1, 0]
class TestBondFeaturizer(BaseBondFeaturizer):
def __init__(self):
super(TestBondFeaturizer, self).__init__(
......@@ -309,6 +325,23 @@ def test_weave_edge_featurizer():
[1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))
def test_pretrain_bond_featurizer():
mol = test_mol3()
test_featurizer = PretrainBondFeaturizer()
feats = test_featurizer(mol)
assert torch.allclose(feats['bond_type'].nonzero(),
torch.tensor([[0], [1], [6], [7], [10], [11], [14], [15],
[16], [17], [18], [19], [20], [21]]))
assert torch.allclose(feats['bond_direction_type'].nonzero(),
torch.tensor([[4], [5], [8], [9]]))
test_featurizer = PretrainBondFeaturizer(self_loop=False)
feats = test_featurizer(mol)
assert torch.allclose(feats['bond_type'].nonzero(),
torch.tensor([[0], [1], [6], [7], [10], [11]]))
assert torch.allclose(feats['bond_direction_type'].nonzero(),
torch.tensor([[4], [5], [8], [9]]))
if __name__ == '__main__':
test_one_hot_encoding()
test_atom_type_one_hot()
......@@ -338,12 +371,15 @@ if __name__ == '__main__':
test_base_atom_featurizer()
test_canonical_atom_featurizer()
test_weave_atom_featurizer()
test_pretrain_atom_featurizer()
test_bond_type_one_hot()
test_bond_is_conjugated_one_hot()
test_bond_is_conjugated()
test_bond_is_in_ring_one_hot()
test_bond_is_in_ring()
test_bond_stereo_one_hot()
test_bond_direction_one_hot()
test_base_bond_featurizer()
test_canonical_bond_featurizer()
test_weave_edge_featurizer()
test_pretrain_bond_featurizer()