Unverified Commit 68a978d4 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Model] Fixes JTNN for 0.5 (#1879)

* jtnn and fixes

* make metagraph a method

* fix test

* fix
parent faa1dc56
...@@ -6,7 +6,7 @@ This is a direct modification from https://github.com/wengong-jin/icml18-jtnn ...@@ -6,7 +6,7 @@ This is a direct modification from https://github.com/wengong-jin/icml18-jtnn
Dependencies Dependencies
-------------- --------------
* PyTorch 0.4.1+ * PyTorch 0.4.1+
* RDKit * RDKit=2018.09.3.0
* requests * requests
How to run How to run
......
...@@ -84,9 +84,9 @@ class JTNNDataset(Dataset): ...@@ -84,9 +84,9 @@ class JTNNDataset(Dataset):
cand_graphs = [] cand_graphs = []
atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC) atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
bond_x_dec = torch.zeros(0, BOND_FDIM_DEC) bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
tree_mess_src_e = torch.zeros(0, 2).long() tree_mess_src_e = torch.zeros(0, 2).int()
tree_mess_tgt_e = torch.zeros(0, 2).long() tree_mess_tgt_e = torch.zeros(0, 2).int()
tree_mess_tgt_n = torch.zeros(0).long() tree_mess_tgt_n = torch.zeros(0).int()
# prebuild the stereoisomers # prebuild the stereoisomers
cands = mol_tree.stereo_cands cands = mol_tree.stereo_cands
...@@ -143,7 +143,7 @@ class JTNNCollator(object): ...@@ -143,7 +143,7 @@ class JTNNCollator(object):
mol_trees = _unpack_field(examples, 'mol_tree') mol_trees = _unpack_field(examples, 'mol_tree')
wid = _unpack_field(examples, 'wid') wid = _unpack_field(examples, 'wid')
for _wid, mol_tree in zip(wid, mol_trees): for _wid, mol_tree in zip(wid, mol_trees):
mol_tree.ndata['wid'] = torch.LongTensor(_wid) mol_tree.graph.ndata['wid'] = torch.LongTensor(_wid)
# TODO: either support pickling or get around ctypes pointers using scipy # TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs # batch molecule graphs
...@@ -176,7 +176,7 @@ class JTNNCollator(object): ...@@ -176,7 +176,7 @@ class JTNNCollator(object):
tree_mess_src_e[i] += n_tree_nodes tree_mess_src_e[i] += n_tree_nodes
tree_mess_tgt_n[i] += n_graph_nodes tree_mess_tgt_n[i] += n_graph_nodes
n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i]) n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i])
n_tree_nodes += mol_trees[i].number_of_nodes() n_tree_nodes += mol_trees[i].graph.number_of_nodes()
cand_batch_idx.extend([i] * len(cand_graphs[i])) cand_batch_idx.extend([i] * len(cand_graphs[i]))
tree_mess_tgt_e = torch.cat(tree_mess_tgt_e) tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)
tree_mess_src_e = torch.cat(tree_mess_src_e) tree_mess_src_e = torch.cat(tree_mess_src_e)
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from .nnutils import cuda from .nnutils import cuda, line_graph
import rdkit.Chem as Chem import rdkit.Chem as Chem
from dgl import DGLGraph, mean_nodes import dgl
from dgl import mean_nodes
import dgl.function as DGLF import dgl.function as DGLF
import os import os
...@@ -50,12 +51,11 @@ def mol2dgl_single(cand_batch): ...@@ -50,12 +51,11 @@ def mol2dgl_single(cand_batch):
ctr_node = mol_tree.nodes_dict[ctr_node_id] ctr_node = mol_tree.nodes_dict[ctr_node_id]
ctr_bid = ctr_node['idx'] ctr_bid = ctr_node['idx']
g = DGLGraph() mol_tree_graph = getattr(mol_tree, 'graph', mol_tree)
for i, atom in enumerate(mol.GetAtoms()): for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx() assert i == atom.GetIdx()
atom_x.append(atom_features(atom)) atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
bond_src = [] bond_src = []
bond_dst = [] bond_dst = []
...@@ -78,24 +78,24 @@ def mol2dgl_single(cand_batch): ...@@ -78,24 +78,24 @@ def mol2dgl_single(cand_batch):
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1 x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1 y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid: if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree.has_edge_between(x_bid, y_bid): if mol_tree_graph.has_edges_between(x_bid, y_bid):
tree_mess_target_edges.append((begin_idx + n_nodes, end_idx + n_nodes)) tree_mess_target_edges.append((begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid)) tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes) tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree.has_edge_between(y_bid, x_bid): if mol_tree_graph.has_edges_between(y_bid, x_bid):
tree_mess_target_edges.append((end_idx + n_nodes, begin_idx + n_nodes)) tree_mess_target_edges.append((end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid)) tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes) tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += n_atoms n_nodes += n_atoms
g.add_edges(bond_src, bond_dst) g = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
cand_graphs.append(g) cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \ return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \ torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \ torch.IntTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \ torch.IntTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes) torch.IntTensor(tree_mess_target_nodes)
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg') mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
...@@ -174,7 +174,7 @@ class DGLJTMPN(nn.Module): ...@@ -174,7 +174,7 @@ class DGLJTMPN(nn.Module):
n_samples = len(cand_graphs) n_samples = len(cand_graphs)
cand_line_graph = cand_graphs.line_graph(backtracking=False, shared=True) cand_line_graph = line_graph(cand_graphs, backtracking=False, shared=True)
n_nodes = cand_graphs.number_of_nodes() n_nodes = cand_graphs.number_of_nodes()
n_edges = cand_graphs.number_of_edges() n_edges = cand_graphs.number_of_edges()
...@@ -222,20 +222,27 @@ class DGLJTMPN(nn.Module): ...@@ -222,20 +222,27 @@ class DGLJTMPN(nn.Module):
if PAPER: if PAPER:
src_u, src_v = tree_mess_src_edges.unbind(1) src_u, src_v = tree_mess_src_edges.unbind(1)
tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1) tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m'] src_u = src_u.to(mol_tree_batch.device)
src_v = src_v.to(mol_tree_batch.device)
eid = mol_tree_batch.edge_ids(src_u.int(), src_v.int()).long()
alpha = mol_tree_batch.edata['m'][eid]
cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
else: else:
src_u, src_v = tree_mess_src_edges.unbind(1) src_u, src_v = tree_mess_src_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m'] src_u = src_u.to(mol_tree_batch.device)
src_v = src_v.to(mol_tree_batch.device)
eid = mol_tree_batch.edge_ids(src_u.int(), src_v.int()).long()
alpha = mol_tree_batch.edata['m'][eid]
node_idx = (tree_mess_tgt_nodes node_idx = (tree_mess_tgt_nodes
.to(device=zero_node_state.device)[:, None] .to(device=zero_node_state.device)[:, None]
.expand_as(alpha)) .expand_as(alpha))
node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha) node_alpha = zero_node_state.clone().scatter_add(0, node_idx.long(), alpha)
cand_graphs.ndata['alpha'] = node_alpha cand_graphs.ndata['alpha'] = node_alpha
cand_graphs.apply_edges( cand_graphs.apply_edges(
func=lambda edges: {'alpha': edges.src['alpha']}, func=lambda edges: {'alpha': edges.src['alpha']},
) )
cand_line_graph.ndata.update(cand_graphs.edata)
for i in range(self.depth - 1): for i in range(self.depth - 1):
cand_line_graph.update_all( cand_line_graph.update_all(
mpn_loopy_bp_msg, mpn_loopy_bp_msg,
...@@ -243,6 +250,7 @@ class DGLJTMPN(nn.Module): ...@@ -243,6 +250,7 @@ class DGLJTMPN(nn.Module):
self.loopy_bp_updater, self.loopy_bp_updater,
) )
cand_graphs.edata.update(cand_line_graph.ndata)
cand_graphs.update_all( cand_graphs.update_all(
mpn_gather_msg, mpn_gather_msg,
mpn_gather_reduce, mpn_gather_reduce,
......
...@@ -3,7 +3,7 @@ import torch.nn as nn ...@@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .mol_tree_nx import DGLMolTree from .mol_tree_nx import DGLMolTree
from .chemutils import enum_assemble_nx, get_mol from .chemutils import enum_assemble_nx, get_mol
from .nnutils import GRUUpdate, cuda from .nnutils import GRUUpdate, cuda, line_graph, tocpu
from dgl import batch, dfs_labeled_edges_generator from dgl import batch, dfs_labeled_edges_generator
import dgl.function as DGLF import dgl.function as DGLF
import numpy as np import numpy as np
...@@ -13,6 +13,7 @@ MAX_DECODE_LEN = 100 ...@@ -13,6 +13,7 @@ MAX_DECODE_LEN = 100
def dfs_order(forest, roots): def dfs_order(forest, roots):
forest = tocpu(forest)
edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True) edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
for e, l in zip(*edges): for e, l in zip(*edges):
# I exploited the fact that the reverse edge ID equal to 1 xor forward # I exploited the fact that the reverse edge ID equal to 1 xor forward
...@@ -55,7 +56,7 @@ def have_slots(fa_slots, ch_slots): ...@@ -55,7 +56,7 @@ def have_slots(fa_slots, ch_slots):
def can_assemble(mol_tree, u, v_node_dict): def can_assemble(mol_tree, u, v_node_dict):
u_node_dict = mol_tree.nodes_dict[u] u_node_dict = mol_tree.nodes_dict[u]
u_neighbors = mol_tree.successors(u) u_neighbors = mol_tree.graph.successors(u)
u_neighbors_node_dict = [ u_neighbors_node_dict = [
mol_tree.nodes_dict[_u] mol_tree.nodes_dict[_u]
for _u in u_neighbors for _u in u_neighbors
...@@ -106,13 +107,13 @@ class DGLJTNNDecoder(nn.Module): ...@@ -106,13 +107,13 @@ class DGLJTNNDecoder(nn.Module):
ground truth tree ground truth tree
''' '''
mol_tree_batch = batch(mol_trees) mol_tree_batch = batch(mol_trees)
mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False, shared=True) mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False, shared=True)
n_trees = len(mol_trees) n_trees = len(mol_trees)
return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec) return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec): def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes) node_offset = np.cumsum(np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0))
root_ids = node_offset[:-1] root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes() n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges() n_edges = mol_tree_batch.number_of_edges()
...@@ -120,7 +121,7 @@ class DGLJTNNDecoder(nn.Module): ...@@ -120,7 +121,7 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_batch.ndata.update({ mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']), 'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)), 'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
'new': cuda(torch.ones(n_nodes).byte()), # whether it's newly generated node 'new': cuda(torch.ones(n_nodes).bool()), # whether it's newly generated node
}) })
mol_tree_batch.edata.update({ mol_tree_batch.edata.update({
...@@ -162,22 +163,26 @@ class DGLJTNNDecoder(nn.Module): ...@@ -162,22 +163,26 @@ class DGLJTNNDecoder(nn.Module):
# Traverse the tree and predict on children # Traverse the tree and predict on children
for eid, p in dfs_order(mol_tree_batch, root_ids): for eid, p in dfs_order(mol_tree_batch, root_ids):
eid = eid.to(mol_tree_batch.device)
p = p.to(mol_tree_batch.device)
u, v = mol_tree_batch.find_edges(eid) u, v = mol_tree_batch.find_edges(eid)
p_target_list = torch.zeros_like(root_out_degrees) p_target_list = torch.zeros_like(root_out_degrees)
p_target_list[root_out_degrees > 0] = 1 - p p_target_list[root_out_degrees > 0] = (1 - p).int()
p_target_list = p_target_list[root_out_degrees >= 0] p_target_list = p_target_list[root_out_degrees >= 0]
p_targets.append(torch.tensor(p_target_list)) p_targets.append(torch.tensor(p_target_list))
root_out_degrees -= (root_out_degrees == 0).long() root_out_degrees -= (root_out_degrees == 0).int()
root_out_degrees -= torch.tensor(np.isin(root_ids, v).astype('int64')) root_out_degrees -= torch.tensor(np.isin(root_ids, v.cpu().numpy())).to(root_out_degrees)
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
mol_tree_batch_lg.pull( mol_tree_batch_lg.pull(
eid, eid,
dec_tree_edge_msg, dec_tree_edge_msg,
dec_tree_edge_reduce, dec_tree_edge_reduce,
self.dec_tree_edge_update, self.dec_tree_edge_update,
) )
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
is_new = mol_tree_batch.nodes[v].data['new'] is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull( mol_tree_batch.pull(
v, v,
...@@ -185,6 +190,7 @@ class DGLJTNNDecoder(nn.Module): ...@@ -185,6 +190,7 @@ class DGLJTNNDecoder(nn.Module):
dec_tree_node_reduce, dec_tree_node_reduce,
dec_tree_node_update, dec_tree_node_update,
) )
# Extract # Extract
n_repr = mol_tree_batch.nodes[v].data n_repr = mol_tree_batch.nodes[v].data
h = n_repr['h'] h = n_repr['h']
...@@ -200,7 +206,10 @@ class DGLJTNNDecoder(nn.Module): ...@@ -200,7 +206,10 @@ class DGLJTNNDecoder(nn.Module):
if q_input.shape[0] > 0: if q_input.shape[0] > 0:
q_inputs.append(q_input) q_inputs.append(q_input)
q_targets.append(q_target) q_targets.append(q_target)
p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long()) p_targets.append(torch.zeros(
(root_out_degrees == 0).sum(),
device=root_out_degrees.device,
dtype=torch.int32))
# Batch compute the stop/label prediction losses # Batch compute the stop/label prediction losses
p_inputs = torch.cat(p_inputs, 0) p_inputs = torch.cat(p_inputs, 0)
...@@ -231,6 +240,8 @@ class DGLJTNNDecoder(nn.Module): ...@@ -231,6 +240,8 @@ class DGLJTNNDecoder(nn.Module):
assert mol_vec.shape[0] == 1 assert mol_vec.shape[0] == 1
mol_tree = DGLMolTree(None) mol_tree = DGLMolTree(None)
mol_tree.graph = mol_tree.graph.to(mol_vec.device)
mol_tree_graph = mol_tree.graph
init_hidden = cuda(torch.zeros(1, self.hidden_size)) init_hidden = cuda(torch.zeros(1, self.hidden_size))
...@@ -240,11 +251,11 @@ class DGLJTNNDecoder(nn.Module): ...@@ -240,11 +251,11 @@ class DGLJTNNDecoder(nn.Module):
_, root_wid = torch.max(root_score, 1) _, root_wid = torch.max(root_score, 1)
root_wid = root_wid.view(1) root_wid = root_wid.view(1)
mol_tree.add_nodes(1) # root mol_tree_graph.add_nodes(1) # root
mol_tree.nodes[0].data['wid'] = root_wid mol_tree_graph.ndata['wid'] = root_wid
mol_tree.nodes[0].data['x'] = self.embedding(root_wid) mol_tree_graph.ndata['x'] = self.embedding(root_wid)
mol_tree.nodes[0].data['h'] = init_hidden mol_tree_graph.ndata['h'] = init_hidden
mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0])) mol_tree_graph.ndata['fail'] = cuda(torch.tensor([0]))
mol_tree.nodes_dict[0] = root_node_dict = create_node_dict( mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
self.vocab.get_smiles(root_wid)) self.vocab.get_smiles(root_wid))
...@@ -259,27 +270,27 @@ class DGLJTNNDecoder(nn.Module): ...@@ -259,27 +270,27 @@ class DGLJTNNDecoder(nn.Module):
for step in range(MAX_DECODE_LEN): for step in range(MAX_DECODE_LEN):
u, u_slots = stack[-1] u, u_slots = stack[-1]
udata = mol_tree.nodes[u].data x = mol_tree_graph.ndata['x'][u:u+1]
x = udata['x'] h = mol_tree_graph.ndata['h'][u:u+1]
h = udata['h']
# Predict stop # Predict stop
p_input = torch.cat([x, h, mol_vec], 1) p_input = torch.cat([x, h, mol_vec], 1)
p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input)))) p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
p_score[:] = 0
backtrack = (p_score.item() < 0.5) backtrack = (p_score.item() < 0.5)
if not backtrack: if not backtrack:
# Predict next clique. Note that the prediction may fail due # Predict next clique. Note that the prediction may fail due
# to lack of assemblable components # to lack of assemblable components
mol_tree.add_nodes(1) mol_tree_graph.add_nodes(1)
new_node_id += 1 new_node_id += 1
v = new_node_id v = new_node_id
mol_tree.add_edges(u, v) mol_tree_graph.add_edges(u, v)
uv = new_edge_id uv = new_edge_id
new_edge_id += 1 new_edge_id += 1
if first: if first:
mol_tree.edata.update({ mol_tree_graph.edata.update({
's': cuda(torch.zeros(1, self.hidden_size)), 's': cuda(torch.zeros(1, self.hidden_size)),
'm': cuda(torch.zeros(1, self.hidden_size)), 'm': cuda(torch.zeros(1, self.hidden_size)),
'r': cuda(torch.zeros(1, self.hidden_size)), 'r': cuda(torch.zeros(1, self.hidden_size)),
...@@ -291,26 +302,26 @@ class DGLJTNNDecoder(nn.Module): ...@@ -291,26 +302,26 @@ class DGLJTNNDecoder(nn.Module):
}) })
first = False first = False
mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x'] mol_tree_graph.edata['src_x'][uv] = mol_tree_graph.ndata['x'][u]
# keeping dst_x 0 is fine as h on new edge doesn't depend on that. # keeping dst_x 0 is fine as h on new edge doesn't depend on that.
# DGL doesn't dynamically maintain a line graph. # DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True) mol_tree_graph_lg = line_graph(mol_tree_graph, backtracking=False, shared=True)
mol_tree_lg.pull( mol_tree_graph_lg.pull(
uv, uv,
dec_tree_edge_msg, dec_tree_edge_msg,
dec_tree_edge_reduce, dec_tree_edge_reduce,
self.dec_tree_edge_update.update_zm, self.dec_tree_edge_update.update_zm,
) )
mol_tree.pull( mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(
v, v,
dec_tree_node_msg, dec_tree_node_msg,
dec_tree_node_reduce, dec_tree_node_reduce,
) )
vdata = mol_tree.nodes[v].data h_v = mol_tree_graph.ndata['h'][v:v+1]
h_v = vdata['h']
q_input = torch.cat([h_v, mol_vec], 1) q_input = torch.cat([h_v, mol_vec], 1)
q_score = torch.softmax(self.W_o(torch.relu(self.W(q_input))), -1) q_score = torch.softmax(self.W_o(torch.relu(self.W(q_input))), -1)
_, sort_wid = torch.sort(q_score, 1, descending=True) _, sort_wid = torch.sort(q_score, 1, descending=True)
...@@ -329,49 +340,51 @@ class DGLJTNNDecoder(nn.Module): ...@@ -329,49 +340,51 @@ class DGLJTNNDecoder(nn.Module):
if next_wid is None: if next_wid is None:
# Failed adding an actual children; v is a spurious node # Failed adding an actual children; v is a spurious node
# and we mark it. # and we mark it.
vdata['fail'] = cuda(torch.tensor([1])) mol_tree_graph.ndata['fail'][v] = cuda(torch.tensor([1]))
backtrack = True backtrack = True
else: else:
next_wid = cuda(torch.tensor([next_wid])) next_wid = cuda(torch.tensor([next_wid]))
vdata['wid'] = next_wid mol_tree_graph.ndata['wid'][v] = next_wid
vdata['x'] = self.embedding(next_wid) mol_tree_graph.ndata['x'][v] = self.embedding(next_wid)
mol_tree.nodes_dict[v] = next_node_dict mol_tree.nodes_dict[v] = next_node_dict
all_nodes[v] = next_node_dict all_nodes[v] = next_node_dict
stack.append((v, next_slots)) stack.append((v, next_slots))
mol_tree.add_edge(v, u) mol_tree_graph.add_edges(v, u)
vu = new_edge_id vu = new_edge_id
new_edge_id += 1 new_edge_id += 1
mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x'] mol_tree_graph.edata['dst_x'][uv] = mol_tree_graph.ndata['x'][v]
mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x'] mol_tree_graph.edata['src_x'][vu] = mol_tree_graph.ndata['x'][v]
mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x'] mol_tree_graph.edata['dst_x'][vu] = mol_tree_graph.ndata['x'][u]
# DGL doesn't dynamically maintain a line graph. # DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True) mol_tree_graph_lg = line_graph(mol_tree_graph, backtracking=False, shared=True)
mol_tree_lg.apply_nodes( mol_tree_graph_lg.apply_nodes(
self.dec_tree_edge_update.update_r, self.dec_tree_edge_update.update_r,
uv uv
) )
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
if backtrack: if backtrack:
if len(stack) == 1: if len(stack) == 1:
break # At root, terminate break # At root, terminate
pu, _ = stack[-2] pu, _ = stack[-2]
u_pu = mol_tree.edge_id(u, pu) u_pu = mol_tree_graph.edge_id(u, pu)
mol_tree_lg.pull( mol_tree_graph_lg.pull(
u_pu, u_pu,
dec_tree_edge_msg, dec_tree_edge_msg,
dec_tree_edge_reduce, dec_tree_edge_reduce,
self.dec_tree_edge_update, self.dec_tree_edge_update,
) )
mol_tree.pull( mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(
pu, pu,
dec_tree_node_msg, dec_tree_node_msg,
dec_tree_node_reduce, dec_tree_node_reduce,
) )
stack.pop() stack.pop()
effective_nodes = mol_tree.filter_nodes(lambda nodes: nodes.data['fail'] != 1) effective_nodes = mol_tree_graph.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
effective_nodes, _ = torch.sort(effective_nodes) effective_nodes, _ = torch.sort(effective_nodes)
return mol_tree, all_nodes, effective_nodes return mol_tree, all_nodes, effective_nodes
import torch import torch
import torch.nn as nn import torch.nn as nn
from .nnutils import GRUUpdate, cuda from .nnutils import GRUUpdate, cuda, line_graph, tocpu
from dgl import batch, bfs_edges_generator from dgl import batch, bfs_edges_generator
import dgl.function as DGLF import dgl.function as DGLF
import numpy as np import numpy as np
...@@ -8,7 +8,11 @@ import numpy as np ...@@ -8,7 +8,11 @@ import numpy as np
MAX_NB = 8 MAX_NB = 8
def level_order(forest, roots): def level_order(forest, roots):
forest = tocpu(forest)
edges = bfs_edges_generator(forest, roots) edges = bfs_edges_generator(forest, roots)
if len(edges) == 0:
# no edges in the tree; do not perform loopy BP
return
_, leaves = forest.find_edges(edges[-1]) _, leaves = forest.find_edges(edges[-1])
edges_back = bfs_edges_generator(forest, roots, reverse=True) edges_back = bfs_edges_generator(forest, roots, reverse=True)
yield from reversed(edges_back) yield from reversed(edges_back)
...@@ -53,14 +57,14 @@ class DGLJTNNEncoder(nn.Module): ...@@ -53,14 +57,14 @@ class DGLJTNNEncoder(nn.Module):
mol_tree_batch = batch(mol_trees) mol_tree_batch = batch(mol_trees)
# Build line graph to prepare for belief propagation # Build line graph to prepare for belief propagation
mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False, shared=True) mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False, shared=True)
return self.run(mol_tree_batch, mol_tree_batch_lg) return self.run(mol_tree_batch, mol_tree_batch_lg)
def run(self, mol_tree_batch, mol_tree_batch_lg): def run(self, mol_tree_batch, mol_tree_batch_lg):
# Since tree roots are designated to 0. In the batched graph we can # Since tree roots are designated to 0. In the batched graph we can
# simply find the corresponding node ID by looking at node_offset # simply find the corresponding node ID by looking at node_offset
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes) node_offset = np.cumsum(np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0))
root_ids = node_offset[:-1] root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes() n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges() n_edges = mol_tree_batch.number_of_edges()
...@@ -68,6 +72,7 @@ class DGLJTNNEncoder(nn.Module): ...@@ -68,6 +72,7 @@ class DGLJTNNEncoder(nn.Module):
# Assign structure embeddings to tree nodes # Assign structure embeddings to tree nodes
mol_tree_batch.ndata.update({ mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']), 'x': self.embedding(mol_tree_batch.ndata['wid']),
'm': cuda(torch.zeros(n_nodes, self.hidden_size)),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)), 'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
}) })
...@@ -95,16 +100,18 @@ class DGLJTNNEncoder(nn.Module): ...@@ -95,16 +100,18 @@ class DGLJTNNEncoder(nn.Module):
# messages, and the uncomputed messages are zero vectors. Essentially, # messages, and the uncomputed messages are zero vectors. Essentially,
# we can always compute s_ij as the sum of incoming m_ij, no matter # we can always compute s_ij as the sum of incoming m_ij, no matter
# if m_ij is actually computed or not. # if m_ij is actually computed or not.
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
for eid in level_order(mol_tree_batch, root_ids): for eid in level_order(mol_tree_batch, root_ids):
#eid = mol_tree_batch.edge_ids(u, v) #eid = mol_tree_batch.edge_ids(u, v)
mol_tree_batch_lg.pull( mol_tree_batch_lg.pull(
eid, eid.to(mol_tree_batch_lg.device),
enc_tree_msg, enc_tree_msg,
enc_tree_reduce, enc_tree_reduce,
self.enc_tree_update, self.enc_tree_update,
) )
# Readout # Readout
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
mol_tree_batch.update_all( mol_tree_batch.update_all(
enc_tree_gather_msg, enc_tree_gather_msg,
enc_tree_gather_reduce, enc_tree_gather_reduce,
......
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .nnutils import cuda, move_dgl_to_cuda from .nnutils import cuda
from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \ from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \
attach_mols_nx, decode_stereo attach_mols_nx, decode_stereo
from .jtnn_enc import DGLJTNNEncoder from .jtnn_enc import DGLJTNNEncoder
...@@ -44,24 +44,24 @@ class DGLJTNNVAE(nn.Module): ...@@ -44,24 +44,24 @@ class DGLJTNNVAE(nn.Module):
@staticmethod @staticmethod
def move_to_cuda(mol_batch): def move_to_cuda(mol_batch):
for t in mol_batch['mol_trees']: for i in range(len(mol_batch['mol_trees'])):
move_dgl_to_cuda(t) mol_batch['mol_trees'][i].graph = cuda(mol_batch['mol_trees'][i].graph)
move_dgl_to_cuda(mol_batch['mol_graph_batch']) mol_batch['mol_graph_batch'] = cuda(mol_batch['mol_graph_batch'])
if 'cand_graph_batch' in mol_batch: if 'cand_graph_batch' in mol_batch:
move_dgl_to_cuda(mol_batch['cand_graph_batch']) mol_batch['cand_graph_batch'] = cuda(mol_batch['cand_graph_batch'])
if mol_batch.get('stereo_cand_graph_batch') is not None: if mol_batch.get('stereo_cand_graph_batch') is not None:
move_dgl_to_cuda(mol_batch['stereo_cand_graph_batch']) mol_batch['stereo_cand_graph_batch'] = cuda(mol_batch['stereo_cand_graph_batch'])
def encode(self, mol_batch): def encode(self, mol_batch):
mol_graphs = mol_batch['mol_graph_batch'] mol_graphs = mol_batch['mol_graph_batch']
mol_vec = self.mpn(mol_graphs) mol_vec = self.mpn(mol_graphs)
mol_tree_batch, tree_vec = self.jtnn(mol_batch['mol_trees']) mol_tree_batch, tree_vec = self.jtnn([t.graph for t in mol_batch['mol_trees']])
self.n_nodes_total += mol_graphs.number_of_nodes() self.n_nodes_total += mol_graphs.number_of_nodes()
self.n_edges_total += mol_graphs.number_of_edges() self.n_edges_total += mol_graphs.number_of_edges()
self.n_tree_nodes_total += sum(t.number_of_nodes() for t in mol_batch['mol_trees']) self.n_tree_nodes_total += sum(t.graph.number_of_nodes() for t in mol_batch['mol_trees'])
self.n_passes += 1 self.n_passes += 1
return mol_tree_batch, tree_vec, mol_vec return mol_tree_batch, tree_vec, mol_vec
...@@ -93,7 +93,7 @@ class DGLJTNNVAE(nn.Module): ...@@ -93,7 +93,7 @@ class DGLJTNNVAE(nn.Module):
tree_vec, mol_vec, z_mean, z_log_var = self.sample(tree_vec, mol_vec, e1, e2) tree_vec, mol_vec, z_mean, z_log_var = self.sample(tree_vec, mol_vec, e1, e2)
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_trees, tree_vec) word_loss, topo_loss, word_acc, topo_acc = self.decoder([t.graph for t in mol_trees], tree_vec)
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec) assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec) stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
...@@ -103,9 +103,9 @@ class DGLJTNNVAE(nn.Module): ...@@ -103,9 +103,9 @@ class DGLJTNNVAE(nn.Module):
def assm(self, mol_batch, mol_tree_batch, mol_vec): def assm(self, mol_batch, mol_tree_batch, mol_vec):
cands = [mol_batch['cand_graph_batch'], cands = [mol_batch['cand_graph_batch'],
mol_batch['tree_mess_src_e'], cuda(mol_batch['tree_mess_src_e']),
mol_batch['tree_mess_tgt_e'], cuda(mol_batch['tree_mess_tgt_e']),
mol_batch['tree_mess_tgt_n']] cuda(mol_batch['tree_mess_tgt_n'])]
cand_vec = self.jtmpn(cands, mol_tree_batch) cand_vec = self.jtmpn(cands, mol_tree_batch)
cand_vec = self.G_mean(cand_vec) cand_vec = self.G_mean(cand_vec)
...@@ -179,12 +179,11 @@ class DGLJTNNVAE(nn.Module): ...@@ -179,12 +179,11 @@ class DGLJTNNVAE(nn.Module):
node['idx'] = i node['idx'] = i
node['nid'] = i + 1 node['nid'] = i + 1
node['is_leaf'] = True node['is_leaf'] = True
if mol_tree.in_degree(node_id) > 1: if mol_tree.graph.in_degrees(node_id) > 1:
node['is_leaf'] = False node['is_leaf'] = False
set_atommap(node['mol'], node['nid']) set_atommap(node['mol'], node['nid'])
mol_tree_sg = mol_tree.subgraph(effective_nodes) mol_tree_sg = mol_tree.graph.subgraph(effective_nodes.int().to(tree_vec.device))
mol_tree_sg.copy_from_parent()
mol_tree_msg, _ = self.jtnn([mol_tree_sg]) mol_tree_msg, _ = self.jtnn([mol_tree_sg])
mol_tree_msg = unbatch(mol_tree_msg)[0] mol_tree_msg = unbatch(mol_tree_msg)[0]
mol_tree_msg.nodes_dict = nodes_dict mol_tree_msg.nodes_dict = nodes_dict
...@@ -210,7 +209,7 @@ class DGLJTNNVAE(nn.Module): ...@@ -210,7 +209,7 @@ class DGLJTNNVAE(nn.Module):
stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands] stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
stereo_cand_graphs, atom_x, bond_x = \ stereo_cand_graphs, atom_x, bond_x = \
zip(*stereo_graphs) zip(*stereo_graphs)
stereo_cand_graphs = batch(stereo_cand_graphs) stereo_cand_graphs = cuda(batch(stereo_cand_graphs))
atom_x = cuda(torch.cat(atom_x)) atom_x = cuda(torch.cat(atom_x))
bond_x = cuda(torch.cat(bond_x)) bond_x = cuda(torch.cat(bond_x))
stereo_cand_graphs.ndata['x'] = atom_x stereo_cand_graphs.ndata['x'] = atom_x
...@@ -248,9 +247,8 @@ class DGLJTNNVAE(nn.Module): ...@@ -248,9 +247,8 @@ class DGLJTNNVAE(nn.Module):
cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols] cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
cand_graphs, atom_x, bond_x, tree_mess_src_edges, \ cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec( tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(cands)
cands) cand_graphs = batch([g.to(mol_vec.device) for g in cand_graphs])
cand_graphs = batch(cand_graphs)
atom_x = cuda(atom_x) atom_x = cuda(atom_x)
bond_x = cuda(bond_x) bond_x = cuda(bond_x)
cand_graphs.ndata['x'] = atom_x cand_graphs.ndata['x'] = atom_x
......
from dgl import DGLGraph import dgl
import rdkit.Chem as Chem import rdkit.Chem as Chem
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \ from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo set_atommap, enum_assemble_nx, decode_stereo
import numpy as np import numpy as np
class DGLMolTree(DGLGraph): class DGLMolTree(object):
def __init__(self, smiles): def __init__(self, smiles):
DGLGraph.__init__(self)
self.nodes_dict = {} self.nodes_dict = {}
if smiles is None: if smiles is None:
self.graph = dgl.graph(([], []))
return return
self.smiles = smiles self.smiles = smiles
...@@ -34,7 +34,6 @@ class DGLMolTree(DGLGraph): ...@@ -34,7 +34,6 @@ class DGLMolTree(DGLGraph):
) )
if min(c) == 0: if min(c) == 0:
root = i root = i
self.add_nodes(len(cliques))
# The clique with atom ID 0 becomes root # The clique with atom ID 0 becomes root
if root > 0: if root > 0:
...@@ -51,16 +50,16 @@ class DGLMolTree(DGLGraph): ...@@ -51,16 +50,16 @@ class DGLMolTree(DGLGraph):
dst[2 * i] = y dst[2 * i] = y
src[2 * i + 1] = y src[2 * i + 1] = y
dst[2 * i + 1] = x dst[2 * i + 1] = x
self.add_edges(src, dst) self.graph = dgl.graph((src, dst), num_nodes=len(cliques))
for i in self.nodes_dict: for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1 self.nodes_dict[i]['nid'] = i + 1
if self.out_degree(i) > 1: # Leaf node mol is not marked if self.graph.out_degrees(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid']) set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1) self.nodes_dict[i]['is_leaf'] = (self.graph.out_degrees(i) == 1)
def treesize(self): def treesize(self):
return self.number_of_nodes() return self.graph.number_of_nodes()
def _recover_node(self, i, original_mol): def _recover_node(self, i, original_mol):
node = self.nodes_dict[i] node = self.nodes_dict[i]
...@@ -71,7 +70,7 @@ class DGLMolTree(DGLGraph): ...@@ -71,7 +70,7 @@ class DGLMolTree(DGLGraph):
for cidx in node['clique']: for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid']) original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self.successors(i).numpy(): for j in self.graph.successors(i).numpy():
nei_node = self.nodes_dict[j] nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique']) clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark if nei_node['is_leaf']: # Leaf node, no need to mark
...@@ -93,10 +92,10 @@ class DGLMolTree(DGLGraph): ...@@ -93,10 +92,10 @@ class DGLMolTree(DGLGraph):
return node['label'] return node['label']
def _assemble_node(self, i): def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy() neighbors = [self.nodes_dict[j] for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1] if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True) neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.successors(i).numpy() singletons = [self.nodes_dict[j] for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1] if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors neighbors = singletons + neighbors
......
...@@ -3,8 +3,10 @@ import torch.nn as nn ...@@ -3,8 +3,10 @@ import torch.nn as nn
import rdkit.Chem as Chem import rdkit.Chem as Chem
import torch.nn.functional as F import torch.nn.functional as F
from .chemutils import get_mol from .chemutils import get_mol
from dgl import DGLGraph, mean_nodes import dgl
from dgl import mean_nodes
import dgl.function as DGLF import dgl.function as DGLF
from .nnutils import line_graph
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown'] 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
...@@ -41,11 +43,9 @@ def mol2dgl_single(smiles): ...@@ -41,11 +43,9 @@ def mol2dgl_single(smiles):
mol = get_mol(smiles) mol = get_mol(smiles)
n_atoms = mol.GetNumAtoms() n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds() n_bonds = mol.GetNumBonds()
graph = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()): for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx() assert i == atom.GetIdx()
atom_x.append(atom_features(atom)) atom_x.append(atom_features(atom))
graph.add_nodes(n_atoms)
bond_src = [] bond_src = []
bond_dst = [] bond_dst = []
...@@ -60,8 +60,7 @@ def mol2dgl_single(smiles): ...@@ -60,8 +60,7 @@ def mol2dgl_single(smiles):
bond_src.append(end_idx) bond_src.append(end_idx)
bond_dst.append(begin_idx) bond_dst.append(begin_idx)
bond_x.append(features) bond_x.append(features)
graph.add_edges(bond_src, bond_dst) graph = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
n_edges += n_bonds n_edges += n_bonds
return graph, torch.stack(atom_x), \ return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0) torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
...@@ -123,7 +122,7 @@ class DGLMPN(nn.Module): ...@@ -123,7 +122,7 @@ class DGLMPN(nn.Module):
def forward(self, mol_graph): def forward(self, mol_graph):
n_samples = mol_graph.batch_size n_samples = mol_graph.batch_size
mol_line_graph = mol_graph.line_graph(backtracking=False, shared=True) mol_line_graph = line_graph(mol_graph, backtracking=False, shared=True)
n_nodes = mol_graph.number_of_nodes() n_nodes = mol_graph.number_of_nodes()
n_edges = mol_graph.number_of_edges() n_edges = mol_graph.number_of_edges()
...@@ -170,6 +169,7 @@ class DGLMPN(nn.Module): ...@@ -170,6 +169,7 @@ class DGLMPN(nn.Module):
self.loopy_bp_updater, self.loopy_bp_updater,
) )
mol_graph.edata.update(mol_line_graph.ndata)
mol_graph.update_all( mol_graph.update_all(
mpn_gather_msg, mpn_gather_msg,
mpn_gather_reduce, mpn_gather_reduce,
......
import torch import torch
import torch.nn as nn import torch.nn as nn
import os import os
import dgl
def cuda(x):
    """Move *x* to the CUDA device when one is available and NOCUDA is unset.

    Works for both ``torch.Tensor`` and ``DGLGraph`` objects, since both
    support ``.to(device)``.

    Parameters
    ----------
    x : torch.Tensor or dgl.DGLGraph
        Object to move.

    Returns
    -------
    Same type as *x*: on the CUDA device when available (and the ``NOCUDA``
    environment variable is not set), otherwise *x* unchanged.
    """
    if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
        # .to() works for both DGLGraph and tensor.
        return x.to(torch.device('cuda'))
    else:
        # BUG FIX: the parameter was renamed ``tensor`` -> ``x`` but this
        # branch still returned the stale name ``tensor``, raising a
        # NameError on every CPU-only run.
        return x
...@@ -42,7 +43,15 @@ class GRUUpdate(nn.Module): ...@@ -42,7 +43,15 @@ class GRUUpdate(nn.Module):
dic.update(self.update_r(node, zm=dic)) dic.update(self.update_r(node, zm=dic))
return dic return dic
def tocpu(g):
    """Return a structural copy of graph ``g`` whose edge index lives on CPU.

    Only the topology (edge list and node count) is copied; node and edge
    feature frames are NOT carried over to the new graph.
    """
    u, v = g.edges()
    return dgl.graph((u.cpu(), v.cpu()), num_nodes=g.number_of_nodes())
def line_graph(g, backtracking=True, shared=False):
    """Build the line graph of ``g`` and expose ``g``'s edge features as its
    node features.

    Replacement for the ``DGLGraph.line_graph`` method removed in DGL 0.5:
    ``dgl.line_graph`` no longer shares feature frames with the input graph,
    so the edge data is copied over explicitly.

    Parameters
    ----------
    g : dgl.DGLGraph
        Input graph.
    backtracking : bool
        Whether the line graph keeps "backtracking" pairs (edge u->v followed
        by v->u). Forwarded to ``dgl.line_graph``.
    shared : bool
        Forwarded to ``dgl.line_graph``.

    Returns
    -------
    dgl.DGLGraph
        The line graph, with ``g.edata`` available in its ``ndata``.
    """
    # Removed commented-out CPU round-trip (tocpu / .to(g.device)) that was
    # left behind as dead code.
    lg = dgl.line_graph(g, backtracking, shared)
    lg.ndata.update(g.edata)
    return lg
...@@ -8,6 +8,7 @@ import math, random, sys ...@@ -8,6 +8,7 @@ import math, random, sys
from optparse import OptionParser from optparse import OptionParser
from collections import deque from collections import deque
import rdkit import rdkit
import tqdm
from jtnn import * from jtnn import *
...@@ -69,7 +70,7 @@ def train(): ...@@ -69,7 +70,7 @@ def train():
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
num_workers=4, num_workers=0,
collate_fn=JTNNCollator(vocab, True), collate_fn=JTNNCollator(vocab, True),
drop_last=True, drop_last=True,
worker_init_fn=worker_init_fn) worker_init_fn=worker_init_fn)
...@@ -77,7 +78,7 @@ def train(): ...@@ -77,7 +78,7 @@ def train():
for epoch in range(MAX_EPOCH): for epoch in range(MAX_EPOCH):
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0 word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
for it, batch in enumerate(dataloader): for it, batch in tqdm.tqdm(enumerate(dataloader), total=2000):
model.zero_grad() model.zero_grad()
try: try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta) loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
......
...@@ -28,7 +28,13 @@ def _new_object(cls): ...@@ -28,7 +28,13 @@ def _new_object(cls):
class ObjectBase(_ObjectBase): class ObjectBase(_ObjectBase):
"""ObjectBase is the base class of all DGL CAPI object.""" """ObjectBase is the base class of all DGL CAPI object.
The core attribute is ``handle``, which is a C raw pointer. It must be initialized
via ``__init_handle_by_constructor__``.
Note that the same handle **CANNOT** be shared across multiple ObjectBase instances.
"""
def __dir__(self): def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)() plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint() size = ctypes.c_uint()
......
...@@ -268,9 +268,6 @@ class DGLHeteroGraph(object): ...@@ -268,9 +268,6 @@ class DGLHeteroGraph(object):
self._etype2canonical[ety] = self._canonical_etypes[i] self._etype2canonical[ety] = self._canonical_etypes[i]
self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)} self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)}
# Cached metagraph in networkx
self._nx_metagraph = None
# node and edge frame # node and edge frame
if node_frames is None: if node_frames is None:
node_frames = [None] * len(self._ntypes) node_frames = [None] * len(self._ntypes)
...@@ -286,27 +283,12 @@ class DGLHeteroGraph(object): ...@@ -286,27 +283,12 @@ class DGLHeteroGraph(object):
for i, frame in enumerate(edge_frames)] for i, frame in enumerate(edge_frames)]
self._edge_frames = edge_frames self._edge_frames = edge_frames
def __getstate__(self):
metainfo = (self._ntypes, self._etypes, self._canonical_etypes,
self._srctypes_invmap, self._dsttypes_invmap,
self._is_unibipartite, self._etype2canonical, self._etypes_invmap)
return (self._graph, metainfo,
self._node_frames, self._edge_frames,
self._batch_num_nodes, self._batch_num_edges)
def __setstate__(self, state): def __setstate__(self, state):
# Compatibility check # Compatibility check
# TODO: version the storage # TODO: version the storage
if isinstance(state, tuple) and len(state) == 6: if isinstance(state, dict):
# DGL >= 0.5 # Since 0.5 we use the default __dict__ method
#TODO(minjie): too many states in python; should clean up and lower to C self.__dict__.update(state)
self._nx_metagraph = None
(self._graph, metainfo, self._node_frames, self._edge_frames,
self._batch_num_nodes, self._batch_num_edges) = state
(self._ntypes, self._etypes, self._canonical_etypes,
self._srctypes_invmap, self._dsttypes_invmap,
self._is_unibipartite, self._etype2canonical,
self._etypes_invmap) = metainfo
elif isinstance(state, tuple) and len(state) == 5: elif isinstance(state, tuple) and len(state) == 5:
# DGL == 0.4.3 # DGL == 0.4.3
dgl_warning("The object is pickled with DGL == 0.4.3. " dgl_warning("The object is pickled with DGL == 0.4.3. "
...@@ -337,7 +319,7 @@ class DGLHeteroGraph(object): ...@@ -337,7 +319,7 @@ class DGLHeteroGraph(object):
for i in range(len(self.ntypes))} for i in range(len(self.ntypes))}
nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i) nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i)
for i in range(len(self.etypes))} for i in range(len(self.etypes))}
meta = str(self.metagraph.edges(keys=True)) meta = str(self.metagraph().edges(keys=True))
return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta) return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
def __copy__(self): def __copy__(self):
...@@ -345,20 +327,7 @@ class DGLHeteroGraph(object): ...@@ -345,20 +327,7 @@ class DGLHeteroGraph(object):
#TODO(minjie): too many states in python; should clean up and lower to C #TODO(minjie): too many states in python; should clean up and lower to C
cls = type(self) cls = type(self)
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj._graph = self._graph obj.__dict__.update(self.__dict__)
obj._batch_num_nodes = self._batch_num_nodes
obj._batch_num_edges = self._batch_num_edges
obj._ntypes = self._ntypes
obj._etypes = self._etypes
obj._canonical_etypes = self._canonical_etypes
obj._srctypes_invmap = self._srctypes_invmap
obj._dsttypes_invmap = self._dsttypes_invmap
obj._is_unibipartite = self._is_unibipartite
obj._etype2canonical = self._etype2canonical
obj._etypes_invmap = self._etypes_invmap
obj._nx_metagraph = self._nx_metagraph
obj._node_frames = self._node_frames
obj._edge_frames = self._edge_frames
return obj return obj
################################################################# #################################################################
...@@ -975,7 +944,6 @@ class DGLHeteroGraph(object): ...@@ -975,7 +944,6 @@ class DGLHeteroGraph(object):
else: else:
return self.ntypes return self.ntypes
@property
def metagraph(self): def metagraph(self):
"""Return the metagraph as networkx.MultiDiGraph. """Return the metagraph as networkx.MultiDiGraph.
...@@ -992,7 +960,7 @@ class DGLHeteroGraph(object): ...@@ -992,7 +960,7 @@ class DGLHeteroGraph(object):
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> meta_g = g.metagraph >>> meta_g = g.metagraph()
The metagraph then has two nodes and two edges. The metagraph then has two nodes and two edges.
...@@ -1005,13 +973,12 @@ class DGLHeteroGraph(object): ...@@ -1005,13 +973,12 @@ class DGLHeteroGraph(object):
>>> meta_g.number_of_edges() >>> meta_g.number_of_edges()
2 2
""" """
if self._nx_metagraph is None:
nx_graph = self._graph.metagraph.to_networkx() nx_graph = self._graph.metagraph.to_networkx()
self._nx_metagraph = nx.MultiDiGraph() nx_metagraph = nx.MultiDiGraph()
for u_v in nx_graph.edges: for u_v in nx_graph.edges:
srctype, etype, dsttype = self.canonical_etypes[nx_graph.edges[u_v]['id']] srctype, etype, dsttype = self.canonical_etypes[nx_graph.edges[u_v]['id']]
self._nx_metagraph.add_edge(srctype, dsttype, etype) nx_metagraph.add_edge(srctype, dsttype, etype)
return self._nx_metagraph return nx_metagraph
def to_canonical_etype(self, etype): def to_canonical_etype(self, etype):
"""Convert edge type to canonical etype: (srctype, etype, dsttype). """Convert edge type to canonical etype: (srctype, etype, dsttype).
...@@ -5282,7 +5249,7 @@ class DGLBlock(DGLHeteroGraph): ...@@ -5282,7 +5249,7 @@ class DGLBlock(DGLHeteroGraph):
for ntype in self.dsttypes} for ntype in self.dsttypes}
nedge_dict = {etype : self.number_of_edges(etype) nedge_dict = {etype : self.number_of_edges(etype)
for etype in self.canonical_etypes} for etype in self.canonical_etypes}
meta = str(self.metagraph.edges(keys=True)) meta = str(self.metagraph().edges(keys=True))
return ret.format( return ret.format(
srcnode=nsrcnode_dict, dstnode=ndstnode_dict, edge=nedge_dict, meta=meta) srcnode=nsrcnode_dict, dstnode=ndstnode_dict, edge=nedge_dict, meta=meta)
......
...@@ -211,7 +211,7 @@ class PinSAGESampler(RandomWalkNeighborSampler): ...@@ -211,7 +211,7 @@ class PinSAGESampler(RandomWalkNeighborSampler):
""" """
def __init__(self, G, ntype, other_type, random_walk_length, random_walk_restart_prob, def __init__(self, G, ntype, other_type, random_walk_length, random_walk_restart_prob,
num_random_walks, num_neighbors, weight_column='weights'): num_random_walks, num_neighbors, weight_column='weights'):
metagraph = G.metagraph metagraph = G.metagraph()
fw_etype = list(metagraph[ntype][other_type])[0] fw_etype = list(metagraph[ntype][other_type])[0]
bw_etype = list(metagraph[other_type][ntype])[0] bw_etype = list(metagraph[other_type][ntype])[0]
super().__init__(G, random_walk_length, super().__init__(G, random_walk_length,
......
...@@ -33,7 +33,9 @@ CSRMatrix CSRTranspose(CSRMatrix csr) { ...@@ -33,7 +33,9 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
const int32_t* indices_ptr = static_cast<int32_t*>(indices->data); const int32_t* indices_ptr = static_cast<int32_t*>(indices->data);
const void* data_ptr = data->data; const void* data_ptr = data->data;
NDArray t_indptr = aten::NewIdArray(csr.num_cols + 1, ctx, bits); // (BarclayII) csr2csc doesn't seem to clear the content of cscColPtr if nnz == 0.
// We need to do it ourselves.
NDArray t_indptr = aten::Full(0, csr.num_cols + 1, bits, ctx);
NDArray t_indices = aten::NewIdArray(nnz, ctx, bits); NDArray t_indices = aten::NewIdArray(nnz, ctx, bits);
NDArray t_data = aten::NewIdArray(nnz, ctx, bits); NDArray t_data = aten::NewIdArray(nnz, ctx, bits);
int32_t* t_indptr_ptr = static_cast<int32_t*>(t_indptr->data); int32_t* t_indptr_ptr = static_cast<int32_t*>(t_indptr->data);
......
...@@ -224,7 +224,7 @@ def test_query(idtype): ...@@ -224,7 +224,7 @@ def test_query(idtype):
assert set(canonical_etypes) == set(g.canonical_etypes) assert set(canonical_etypes) == set(g.canonical_etypes)
# metagraph # metagraph
mg = g.metagraph mg = g.metagraph()
assert set(g.ntypes) == set(mg.nodes) assert set(g.ntypes) == set(mg.nodes)
etype_triplets = [(u, v, e) for u, v, e in mg.edges(keys=True)] etype_triplets = [(u, v, e) for u, v, e in mg.edges(keys=True)]
assert set([ assert set([
......
...@@ -34,8 +34,8 @@ def _assert_is_identical_hetero(g, g2): ...@@ -34,8 +34,8 @@ def _assert_is_identical_hetero(g, g2):
assert g.canonical_etypes == g2.canonical_etypes assert g.canonical_etypes == g2.canonical_etypes
# check if two metagraphs are identical # check if two metagraphs are identical
for edges, features in g.metagraph.edges(keys=True).items(): for edges, features in g.metagraph().edges(keys=True).items():
assert g2.metagraph.edges(keys=True)[edges] == features assert g2.metagraph().edges(keys=True)[edges] == features
# check if node ID spaces and feature spaces are equal # check if node ID spaces and feature spaces are equal
for ntype in g.ntypes: for ntype in g.ntypes:
......
...@@ -35,8 +35,8 @@ def _assert_is_identical_hetero(g, g2): ...@@ -35,8 +35,8 @@ def _assert_is_identical_hetero(g, g2):
assert g.canonical_etypes == g2.canonical_etypes assert g.canonical_etypes == g2.canonical_etypes
# check if two metagraphs are identical # check if two metagraphs are identical
for edges, features in g.metagraph.edges(keys=True).items(): for edges, features in g.metagraph().edges(keys=True).items():
assert g2.metagraph.edges(keys=True)[edges] == features assert g2.metagraph().edges(keys=True)[edges] == features
# check if node ID spaces and feature spaces are equal # check if node ID spaces and feature spaces are equal
for ntype in g.ntypes: for ntype in g.ntypes:
......
...@@ -17,8 +17,8 @@ def check_graph_equal(g1, g2, *, ...@@ -17,8 +17,8 @@ def check_graph_equal(g1, g2, *,
assert g1.batch_size == g2.batch_size assert g1.batch_size == g2.batch_size
# check if two metagraphs are identical # check if two metagraphs are identical
for edges, features in g1.metagraph.edges(keys=True).items(): for edges, features in g1.metagraph().edges(keys=True).items():
assert g2.metagraph.edges(keys=True)[edges] == features assert g2.metagraph().edges(keys=True)[edges] == features
for nty in g1.ntypes: for nty in g1.ntypes:
assert g1.number_of_nodes(nty) == g2.number_of_nodes(nty) assert g1.number_of_nodes(nty) == g2.number_of_nodes(nty)
......
...@@ -234,7 +234,7 @@ def plot_graph(nxg): ...@@ -234,7 +234,7 @@ def plot_graph(nxg):
ag.layout('dot') ag.layout('dot')
ag.draw('graph.png') ag.draw('graph.png')
plot_graph(G.metagraph) plot_graph(G.metagraph())
############################################################################### ###############################################################################
# Learning tasks associated with heterographs # Learning tasks associated with heterographs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment