Unverified Commit 68a978d4 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Model] Fixes JTNN for 0.5 (#1879)

* jtnn and fixes

* make metagraph a method

* fix test

* fix
parent faa1dc56
......@@ -6,7 +6,7 @@ This is a direct modification from https://github.com/wengong-jin/icml18-jtnn
Dependencies
--------------
* PyTorch 0.4.1+
* RDKit=2018.09.3.0
* requests
How to run
......
......@@ -84,9 +84,9 @@ class JTNNDataset(Dataset):
cand_graphs = []
atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
tree_mess_src_e = torch.zeros(0, 2).long()
tree_mess_tgt_e = torch.zeros(0, 2).long()
tree_mess_tgt_n = torch.zeros(0).long()
tree_mess_src_e = torch.zeros(0, 2).int()
tree_mess_tgt_e = torch.zeros(0, 2).int()
tree_mess_tgt_n = torch.zeros(0).int()
# prebuild the stereoisomers
cands = mol_tree.stereo_cands
......@@ -143,7 +143,7 @@ class JTNNCollator(object):
mol_trees = _unpack_field(examples, 'mol_tree')
wid = _unpack_field(examples, 'wid')
for _wid, mol_tree in zip(wid, mol_trees):
mol_tree.ndata['wid'] = torch.LongTensor(_wid)
mol_tree.graph.ndata['wid'] = torch.LongTensor(_wid)
# TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs
......@@ -176,7 +176,7 @@ class JTNNCollator(object):
tree_mess_src_e[i] += n_tree_nodes
tree_mess_tgt_n[i] += n_graph_nodes
n_graph_nodes += sum(g.number_of_nodes() for g in cand_graphs[i])
n_tree_nodes += mol_trees[i].number_of_nodes()
n_tree_nodes += mol_trees[i].graph.number_of_nodes()
cand_batch_idx.extend([i] * len(cand_graphs[i]))
tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)
tree_mess_src_e = torch.cat(tree_mess_src_e)
......
import torch
import torch.nn as nn
from .nnutils import cuda
from .nnutils import cuda, line_graph
import rdkit.Chem as Chem
from dgl import DGLGraph, mean_nodes
import dgl
from dgl import mean_nodes
import dgl.function as DGLF
import os
......@@ -50,12 +51,11 @@ def mol2dgl_single(cand_batch):
ctr_node = mol_tree.nodes_dict[ctr_node_id]
ctr_bid = ctr_node['idx']
g = DGLGraph()
mol_tree_graph = getattr(mol_tree, 'graph', mol_tree)
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
g.add_nodes(n_atoms)
bond_src = []
bond_dst = []
......@@ -78,24 +78,24 @@ def mol2dgl_single(cand_batch):
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree.has_edge_between(x_bid, y_bid):
if mol_tree_graph.has_edges_between(x_bid, y_bid):
tree_mess_target_edges.append((begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree.has_edge_between(y_bid, x_bid):
if mol_tree_graph.has_edges_between(y_bid, x_bid):
tree_mess_target_edges.append((end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
n_nodes += n_atoms
g.add_edges(bond_src, bond_dst)
g = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
torch.IntTensor(tree_mess_source_edges), \
torch.IntTensor(tree_mess_target_edges), \
torch.IntTensor(tree_mess_target_nodes)
mpn_loopy_bp_msg = DGLF.copy_src(src='msg', out='msg')
......@@ -174,7 +174,7 @@ class DGLJTMPN(nn.Module):
n_samples = len(cand_graphs)
cand_line_graph = cand_graphs.line_graph(backtracking=False, shared=True)
cand_line_graph = line_graph(cand_graphs, backtracking=False, shared=True)
n_nodes = cand_graphs.number_of_nodes()
n_edges = cand_graphs.number_of_edges()
......@@ -222,20 +222,27 @@ class DGLJTMPN(nn.Module):
if PAPER:
src_u, src_v = tree_mess_src_edges.unbind(1)
tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
src_u = src_u.to(mol_tree_batch.device)
src_v = src_v.to(mol_tree_batch.device)
eid = mol_tree_batch.edge_ids(src_u.int(), src_v.int()).long()
alpha = mol_tree_batch.edata['m'][eid]
cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
else:
src_u, src_v = tree_mess_src_edges.unbind(1)
alpha = mol_tree_batch.edges[src_u, src_v].data['m']
src_u = src_u.to(mol_tree_batch.device)
src_v = src_v.to(mol_tree_batch.device)
eid = mol_tree_batch.edge_ids(src_u.int(), src_v.int()).long()
alpha = mol_tree_batch.edata['m'][eid]
node_idx = (tree_mess_tgt_nodes
.to(device=zero_node_state.device)[:, None]
.expand_as(alpha))
node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
node_alpha = zero_node_state.clone().scatter_add(0, node_idx.long(), alpha)
cand_graphs.ndata['alpha'] = node_alpha
cand_graphs.apply_edges(
func=lambda edges: {'alpha': edges.src['alpha']},
)
cand_line_graph.ndata.update(cand_graphs.edata)
for i in range(self.depth - 1):
cand_line_graph.update_all(
mpn_loopy_bp_msg,
......@@ -243,6 +250,7 @@ class DGLJTMPN(nn.Module):
self.loopy_bp_updater,
)
cand_graphs.edata.update(cand_line_graph.ndata)
cand_graphs.update_all(
mpn_gather_msg,
mpn_gather_reduce,
......
......@@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .mol_tree_nx import DGLMolTree
from .chemutils import enum_assemble_nx, get_mol
from .nnutils import GRUUpdate, cuda
from .nnutils import GRUUpdate, cuda, line_graph, tocpu
from dgl import batch, dfs_labeled_edges_generator
import dgl.function as DGLF
import numpy as np
......@@ -13,6 +13,7 @@ MAX_DECODE_LEN = 100
def dfs_order(forest, roots):
forest = tocpu(forest)
edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
for e, l in zip(*edges):
# I exploited the fact that the reverse edge ID equal to 1 xor forward
......@@ -55,7 +56,7 @@ def have_slots(fa_slots, ch_slots):
def can_assemble(mol_tree, u, v_node_dict):
u_node_dict = mol_tree.nodes_dict[u]
u_neighbors = mol_tree.successors(u)
u_neighbors = mol_tree.graph.successors(u)
u_neighbors_node_dict = [
mol_tree.nodes_dict[_u]
for _u in u_neighbors
......@@ -106,13 +107,13 @@ class DGLJTNNDecoder(nn.Module):
ground truth tree
'''
mol_tree_batch = batch(mol_trees)
mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False, shared=True)
mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False, shared=True)
n_trees = len(mol_trees)
return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
node_offset = np.cumsum(np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0))
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
......@@ -120,7 +121,7 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
'new': cuda(torch.ones(n_nodes).byte()), # whether it's newly generated node
'new': cuda(torch.ones(n_nodes).bool()), # whether it's newly generated node
})
mol_tree_batch.edata.update({
......@@ -162,22 +163,26 @@ class DGLJTNNDecoder(nn.Module):
# Traverse the tree and predict on children
for eid, p in dfs_order(mol_tree_batch, root_ids):
eid = eid.to(mol_tree_batch.device)
p = p.to(mol_tree_batch.device)
u, v = mol_tree_batch.find_edges(eid)
p_target_list = torch.zeros_like(root_out_degrees)
p_target_list[root_out_degrees > 0] = 1 - p
p_target_list[root_out_degrees > 0] = (1 - p).int()
p_target_list = p_target_list[root_out_degrees >= 0]
p_targets.append(torch.tensor(p_target_list))
root_out_degrees -= (root_out_degrees == 0).long()
root_out_degrees -= torch.tensor(np.isin(root_ids, v).astype('int64'))
root_out_degrees -= (root_out_degrees == 0).int()
root_out_degrees -= torch.tensor(np.isin(root_ids, v.cpu().numpy())).to(root_out_degrees)
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
mol_tree_batch_lg.pull(
eid,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull(
v,
......@@ -185,6 +190,7 @@ class DGLJTNNDecoder(nn.Module):
dec_tree_node_reduce,
dec_tree_node_update,
)
# Extract
n_repr = mol_tree_batch.nodes[v].data
h = n_repr['h']
......@@ -200,7 +206,10 @@ class DGLJTNNDecoder(nn.Module):
if q_input.shape[0] > 0:
q_inputs.append(q_input)
q_targets.append(q_target)
p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long())
p_targets.append(torch.zeros(
(root_out_degrees == 0).sum(),
device=root_out_degrees.device,
dtype=torch.int32))
# Batch compute the stop/label prediction losses
p_inputs = torch.cat(p_inputs, 0)
......@@ -231,6 +240,8 @@ class DGLJTNNDecoder(nn.Module):
assert mol_vec.shape[0] == 1
mol_tree = DGLMolTree(None)
mol_tree.graph = mol_tree.graph.to(mol_vec.device)
mol_tree_graph = mol_tree.graph
init_hidden = cuda(torch.zeros(1, self.hidden_size))
......@@ -240,11 +251,11 @@ class DGLJTNNDecoder(nn.Module):
_, root_wid = torch.max(root_score, 1)
root_wid = root_wid.view(1)
mol_tree.add_nodes(1) # root
mol_tree.nodes[0].data['wid'] = root_wid
mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
mol_tree.nodes[0].data['h'] = init_hidden
mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
mol_tree_graph.add_nodes(1) # root
mol_tree_graph.ndata['wid'] = root_wid
mol_tree_graph.ndata['x'] = self.embedding(root_wid)
mol_tree_graph.ndata['h'] = init_hidden
mol_tree_graph.ndata['fail'] = cuda(torch.tensor([0]))
mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
self.vocab.get_smiles(root_wid))
......@@ -259,27 +270,27 @@ class DGLJTNNDecoder(nn.Module):
for step in range(MAX_DECODE_LEN):
u, u_slots = stack[-1]
udata = mol_tree.nodes[u].data
x = udata['x']
h = udata['h']
x = mol_tree_graph.ndata['x'][u:u+1]
h = mol_tree_graph.ndata['h'][u:u+1]
# Predict stop
p_input = torch.cat([x, h, mol_vec], 1)
p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
p_score[:] = 0
backtrack = (p_score.item() < 0.5)
if not backtrack:
# Predict next clique. Note that the prediction may fail due
# to lack of assemblable components
mol_tree.add_nodes(1)
mol_tree_graph.add_nodes(1)
new_node_id += 1
v = new_node_id
mol_tree.add_edges(u, v)
mol_tree_graph.add_edges(u, v)
uv = new_edge_id
new_edge_id += 1
if first:
mol_tree.edata.update({
mol_tree_graph.edata.update({
's': cuda(torch.zeros(1, self.hidden_size)),
'm': cuda(torch.zeros(1, self.hidden_size)),
'r': cuda(torch.zeros(1, self.hidden_size)),
......@@ -291,26 +302,26 @@ class DGLJTNNDecoder(nn.Module):
})
first = False
mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
mol_tree_graph.edata['src_x'][uv] = mol_tree_graph.ndata['x'][u]
# keeping dst_x 0 is fine as h on new edge doesn't depend on that.
# DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)
mol_tree_graph_lg = line_graph(mol_tree_graph, backtracking=False, shared=True)
mol_tree_lg.pull(
mol_tree_graph_lg.pull(
uv,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update.update_zm,
)
mol_tree.pull(
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(
v,
dec_tree_node_msg,
dec_tree_node_reduce,
)
vdata = mol_tree.nodes[v].data
h_v = vdata['h']
h_v = mol_tree_graph.ndata['h'][v:v+1]
q_input = torch.cat([h_v, mol_vec], 1)
q_score = torch.softmax(self.W_o(torch.relu(self.W(q_input))), -1)
_, sort_wid = torch.sort(q_score, 1, descending=True)
......@@ -329,49 +340,51 @@ class DGLJTNNDecoder(nn.Module):
if next_wid is None:
# Failed adding an actual children; v is a spurious node
# and we mark it.
vdata['fail'] = cuda(torch.tensor([1]))
mol_tree_graph.ndata['fail'][v] = cuda(torch.tensor([1]))
backtrack = True
else:
next_wid = cuda(torch.tensor([next_wid]))
vdata['wid'] = next_wid
vdata['x'] = self.embedding(next_wid)
mol_tree_graph.ndata['wid'][v] = next_wid
mol_tree_graph.ndata['x'][v] = self.embedding(next_wid)
mol_tree.nodes_dict[v] = next_node_dict
all_nodes[v] = next_node_dict
stack.append((v, next_slots))
mol_tree.add_edge(v, u)
mol_tree_graph.add_edges(v, u)
vu = new_edge_id
new_edge_id += 1
mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']
mol_tree_graph.edata['dst_x'][uv] = mol_tree_graph.ndata['x'][v]
mol_tree_graph.edata['src_x'][vu] = mol_tree_graph.ndata['x'][v]
mol_tree_graph.edata['dst_x'][vu] = mol_tree_graph.ndata['x'][u]
# DGL doesn't dynamically maintain a line graph.
mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)
mol_tree_lg.apply_nodes(
mol_tree_graph_lg = line_graph(mol_tree_graph, backtracking=False, shared=True)
mol_tree_graph_lg.apply_nodes(
self.dec_tree_edge_update.update_r,
uv
)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
if backtrack:
if len(stack) == 1:
break # At root, terminate
pu, _ = stack[-2]
u_pu = mol_tree.edge_id(u, pu)
u_pu = mol_tree_graph.edge_id(u, pu)
mol_tree_lg.pull(
mol_tree_graph_lg.pull(
u_pu,
dec_tree_edge_msg,
dec_tree_edge_reduce,
self.dec_tree_edge_update,
)
mol_tree.pull(
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(
pu,
dec_tree_node_msg,
dec_tree_node_reduce,
)
stack.pop()
effective_nodes = mol_tree.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
effective_nodes = mol_tree_graph.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
effective_nodes, _ = torch.sort(effective_nodes)
return mol_tree, all_nodes, effective_nodes
import torch
import torch.nn as nn
from .nnutils import GRUUpdate, cuda
from .nnutils import GRUUpdate, cuda, line_graph, tocpu
from dgl import batch, bfs_edges_generator
import dgl.function as DGLF
import numpy as np
......@@ -8,7 +8,11 @@ import numpy as np
MAX_NB = 8
def level_order(forest, roots):
    """Yield batches of edge IDs of ``forest`` for bottom-up message passing.

    Parameters: ``forest`` is a DGL graph (a batch of trees); ``roots`` are
    the root node IDs, one per tree.  Yields nothing when the forest has no
    edges, so callers can skip loopy BP entirely in that case.
    """
    # DGL's BFS traversal requires a CPU graph, so copy the structure first.
    forest = tocpu(forest)
    edges = bfs_edges_generator(forest, roots)
    if len(edges) == 0:
        # no edges in the tree; do not perform loopy BP
        return
    # NOTE(review): ``leaves`` appears unused in the visible lines — it may be
    # consumed further down in the original function; confirm before removing.
    _, leaves = forest.find_edges(edges[-1])
    # Reverse BFS gives edges pointing toward the roots; reversing the level
    # order makes message passing start from the deepest level.
    edges_back = bfs_edges_generator(forest, roots, reverse=True)
    yield from reversed(edges_back)
......@@ -53,14 +57,14 @@ class DGLJTNNEncoder(nn.Module):
mol_tree_batch = batch(mol_trees)
# Build line graph to prepare for belief propagation
mol_tree_batch_lg = mol_tree_batch.line_graph(backtracking=False, shared=True)
mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False, shared=True)
return self.run(mol_tree_batch, mol_tree_batch_lg)
def run(self, mol_tree_batch, mol_tree_batch_lg):
# Since tree roots are designated to 0. In the batched graph we can
# simply find the corresponding node ID by looking at node_offset
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
node_offset = np.cumsum(np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0))
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
......@@ -68,6 +72,7 @@ class DGLJTNNEncoder(nn.Module):
# Assign structure embeddings to tree nodes
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'm': cuda(torch.zeros(n_nodes, self.hidden_size)),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
})
......@@ -95,16 +100,18 @@ class DGLJTNNEncoder(nn.Module):
# messages, and the uncomputed messages are zero vectors. Essentially,
# we can always compute s_ij as the sum of incoming m_ij, no matter
# if m_ij is actually computed or not.
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
for eid in level_order(mol_tree_batch, root_ids):
#eid = mol_tree_batch.edge_ids(u, v)
mol_tree_batch_lg.pull(
eid,
eid.to(mol_tree_batch_lg.device),
enc_tree_msg,
enc_tree_reduce,
self.enc_tree_update,
)
# Readout
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
mol_tree_batch.update_all(
enc_tree_gather_msg,
enc_tree_gather_reduce,
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nnutils import cuda, move_dgl_to_cuda
from .nnutils import cuda
from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \
attach_mols_nx, decode_stereo
from .jtnn_enc import DGLJTNNEncoder
......@@ -44,24 +44,24 @@ class DGLJTNNVAE(nn.Module):
@staticmethod
def move_to_cuda(mol_batch):
for t in mol_batch['mol_trees']:
move_dgl_to_cuda(t)
for i in range(len(mol_batch['mol_trees'])):
mol_batch['mol_trees'][i].graph = cuda(mol_batch['mol_trees'][i].graph)
move_dgl_to_cuda(mol_batch['mol_graph_batch'])
mol_batch['mol_graph_batch'] = cuda(mol_batch['mol_graph_batch'])
if 'cand_graph_batch' in mol_batch:
move_dgl_to_cuda(mol_batch['cand_graph_batch'])
mol_batch['cand_graph_batch'] = cuda(mol_batch['cand_graph_batch'])
if mol_batch.get('stereo_cand_graph_batch') is not None:
move_dgl_to_cuda(mol_batch['stereo_cand_graph_batch'])
mol_batch['stereo_cand_graph_batch'] = cuda(mol_batch['stereo_cand_graph_batch'])
def encode(self, mol_batch):
mol_graphs = mol_batch['mol_graph_batch']
mol_vec = self.mpn(mol_graphs)
mol_tree_batch, tree_vec = self.jtnn(mol_batch['mol_trees'])
mol_tree_batch, tree_vec = self.jtnn([t.graph for t in mol_batch['mol_trees']])
self.n_nodes_total += mol_graphs.number_of_nodes()
self.n_edges_total += mol_graphs.number_of_edges()
self.n_tree_nodes_total += sum(t.number_of_nodes() for t in mol_batch['mol_trees'])
self.n_tree_nodes_total += sum(t.graph.number_of_nodes() for t in mol_batch['mol_trees'])
self.n_passes += 1
return mol_tree_batch, tree_vec, mol_vec
......@@ -93,7 +93,7 @@ class DGLJTNNVAE(nn.Module):
tree_vec, mol_vec, z_mean, z_log_var = self.sample(tree_vec, mol_vec, e1, e2)
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_trees, tree_vec)
word_loss, topo_loss, word_acc, topo_acc = self.decoder([t.graph for t in mol_trees], tree_vec)
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
......@@ -103,9 +103,9 @@ class DGLJTNNVAE(nn.Module):
def assm(self, mol_batch, mol_tree_batch, mol_vec):
cands = [mol_batch['cand_graph_batch'],
mol_batch['tree_mess_src_e'],
mol_batch['tree_mess_tgt_e'],
mol_batch['tree_mess_tgt_n']]
cuda(mol_batch['tree_mess_src_e']),
cuda(mol_batch['tree_mess_tgt_e']),
cuda(mol_batch['tree_mess_tgt_n'])]
cand_vec = self.jtmpn(cands, mol_tree_batch)
cand_vec = self.G_mean(cand_vec)
......@@ -179,12 +179,11 @@ class DGLJTNNVAE(nn.Module):
node['idx'] = i
node['nid'] = i + 1
node['is_leaf'] = True
if mol_tree.in_degree(node_id) > 1:
if mol_tree.graph.in_degrees(node_id) > 1:
node['is_leaf'] = False
set_atommap(node['mol'], node['nid'])
mol_tree_sg = mol_tree.subgraph(effective_nodes)
mol_tree_sg.copy_from_parent()
mol_tree_sg = mol_tree.graph.subgraph(effective_nodes.int().to(tree_vec.device))
mol_tree_msg, _ = self.jtnn([mol_tree_sg])
mol_tree_msg = unbatch(mol_tree_msg)[0]
mol_tree_msg.nodes_dict = nodes_dict
......@@ -210,7 +209,7 @@ class DGLJTNNVAE(nn.Module):
stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
stereo_cand_graphs, atom_x, bond_x = \
zip(*stereo_graphs)
stereo_cand_graphs = batch(stereo_cand_graphs)
stereo_cand_graphs = cuda(batch(stereo_cand_graphs))
atom_x = cuda(torch.cat(atom_x))
bond_x = cuda(torch.cat(bond_x))
stereo_cand_graphs.ndata['x'] = atom_x
......@@ -248,9 +247,8 @@ class DGLJTNNVAE(nn.Module):
cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(
cands)
cand_graphs = batch(cand_graphs)
tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(cands)
cand_graphs = batch([g.to(mol_vec.device) for g in cand_graphs])
atom_x = cuda(atom_x)
bond_x = cuda(bond_x)
cand_graphs.ndata['x'] = atom_x
......
from dgl import DGLGraph
import dgl
import rdkit.Chem as Chem
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
import numpy as np
class DGLMolTree(DGLGraph):
class DGLMolTree(object):
def __init__(self, smiles):
DGLGraph.__init__(self)
self.nodes_dict = {}
if smiles is None:
self.graph = dgl.graph(([], []))
return
self.smiles = smiles
......@@ -34,7 +34,6 @@ class DGLMolTree(DGLGraph):
)
if min(c) == 0:
root = i
self.add_nodes(len(cliques))
# The clique with atom ID 0 becomes root
if root > 0:
......@@ -51,16 +50,16 @@ class DGLMolTree(DGLGraph):
dst[2 * i] = y
src[2 * i + 1] = y
dst[2 * i + 1] = x
self.add_edges(src, dst)
self.graph = dgl.graph((src, dst), num_nodes=len(cliques))
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
if self.out_degree(i) > 1: # Leaf node mol is not marked
if self.graph.out_degrees(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.out_degree(i) == 1)
self.nodes_dict[i]['is_leaf'] = (self.graph.out_degrees(i) == 1)
def treesize(self):
    """Return the number of nodes (cliques) in the junction tree.

    The diff artifact left two return statements here; the tree structure
    now lives in ``self.graph`` (a ``dgl.graph``) since ``DGLMolTree`` no
    longer subclasses ``DGLGraph``, so that is the authoritative count.
    """
    return self.graph.number_of_nodes()
def _recover_node(self, i, original_mol):
node = self.nodes_dict[i]
......@@ -71,7 +70,7 @@ class DGLMolTree(DGLGraph):
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
for j in self.successors(i).numpy():
for j in self.graph.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
......@@ -93,10 +92,10 @@ class DGLMolTree(DGLGraph):
return node['label']
def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.successors(i).numpy()
neighbors = [self.nodes_dict[j] for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.successors(i).numpy()
singletons = [self.nodes_dict[j] for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
......
......@@ -3,8 +3,10 @@ import torch.nn as nn
import rdkit.Chem as Chem
import torch.nn.functional as F
from .chemutils import get_mol
from dgl import DGLGraph, mean_nodes
import dgl
from dgl import mean_nodes
import dgl.function as DGLF
from .nnutils import line_graph
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
......@@ -41,11 +43,9 @@ def mol2dgl_single(smiles):
mol = get_mol(smiles)
n_atoms = mol.GetNumAtoms()
n_bonds = mol.GetNumBonds()
graph = DGLGraph()
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
atom_x.append(atom_features(atom))
graph.add_nodes(n_atoms)
bond_src = []
bond_dst = []
......@@ -60,8 +60,7 @@ def mol2dgl_single(smiles):
bond_src.append(end_idx)
bond_dst.append(begin_idx)
bond_x.append(features)
graph.add_edges(bond_src, bond_dst)
graph = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
n_edges += n_bonds
return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
......@@ -123,7 +122,7 @@ class DGLMPN(nn.Module):
def forward(self, mol_graph):
n_samples = mol_graph.batch_size
mol_line_graph = mol_graph.line_graph(backtracking=False, shared=True)
mol_line_graph = line_graph(mol_graph, backtracking=False, shared=True)
n_nodes = mol_graph.number_of_nodes()
n_edges = mol_graph.number_of_edges()
......@@ -170,6 +169,7 @@ class DGLMPN(nn.Module):
self.loopy_bp_updater,
)
mol_graph.edata.update(mol_line_graph.ndata)
mol_graph.update_all(
mpn_gather_msg,
mpn_gather_reduce,
......
import torch
import torch.nn as nn
import os
import dgl
def cuda(x):
    """Move ``x`` (a torch.Tensor or a DGLGraph) to the GPU when available.

    Returns ``x`` unchanged when CUDA is unavailable or when the ``NOCUDA``
    environment variable is set (useful for forcing CPU runs in tests).
    """
    if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
        # ``.to()`` works for both DGLGraph and torch.Tensor, unlike ``.cuda()``.
        return x.to(torch.device('cuda'))
    # Fall through: keep the object on its current device.
    return x
......@@ -42,7 +43,15 @@ class GRUUpdate(nn.Module):
dic.update(self.update_r(node, zm=dic))
return dic
def move_dgl_to_cuda(g):
    """Move every node and edge feature tensor of graph ``g`` to GPU, in place."""
    for key in list(g.ndata.keys()):
        g.ndata[key] = cuda(g.ndata[key])
    for key in list(g.edata.keys()):
        g.edata[key] = cuda(g.edata[key])
def tocpu(g):
    """Return a CPU copy of ``g``'s structure (same edges, same node count).

    Only the topology is copied; node/edge features are not carried over.
    """
    u, v = g.edges()
    return dgl.graph((u.cpu(), v.cpu()), num_nodes=g.number_of_nodes())
def line_graph(g, backtracking=True, shared=False):
    """Build the line graph of ``g`` and copy ``g``'s edge features onto
    the line graph's nodes (each line-graph node corresponds to an edge
    of ``g``)."""
    lg = dgl.line_graph(g, backtracking, shared)
    lg.ndata.update(g.edata)
    return lg
......@@ -8,6 +8,7 @@ import math, random, sys
from optparse import OptionParser
from collections import deque
import rdkit
import tqdm
from jtnn import *
......@@ -69,7 +70,7 @@ def train():
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
num_workers=0,
collate_fn=JTNNCollator(vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn)
......@@ -77,7 +78,7 @@ def train():
for epoch in range(MAX_EPOCH):
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
for it, batch in enumerate(dataloader):
for it, batch in tqdm.tqdm(enumerate(dataloader), total=2000):
model.zero_grad()
try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
......
......@@ -28,7 +28,13 @@ def _new_object(cls):
class ObjectBase(_ObjectBase):
"""ObjectBase is the base class of all DGL CAPI object."""
"""ObjectBase is the base class of all DGL CAPI object.
The core attribute is ``handle``, which is a C raw pointer. It must be initialized
via ``__init_handle_by_constructor__``.
Note that the same handle **CANNOT** be shared across multiple ObjectBase instances.
"""
def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
......
......@@ -268,9 +268,6 @@ class DGLHeteroGraph(object):
self._etype2canonical[ety] = self._canonical_etypes[i]
self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)}
# Cached metagraph in networkx
self._nx_metagraph = None
# node and edge frame
if node_frames is None:
node_frames = [None] * len(self._ntypes)
......@@ -286,27 +283,12 @@ class DGLHeteroGraph(object):
for i, frame in enumerate(edge_frames)]
self._edge_frames = edge_frames
def __getstate__(self):
metainfo = (self._ntypes, self._etypes, self._canonical_etypes,
self._srctypes_invmap, self._dsttypes_invmap,
self._is_unibipartite, self._etype2canonical, self._etypes_invmap)
return (self._graph, metainfo,
self._node_frames, self._edge_frames,
self._batch_num_nodes, self._batch_num_edges)
def __setstate__(self, state):
# Compatibility check
# TODO: version the storage
if isinstance(state, tuple) and len(state) == 6:
# DGL >= 0.5
#TODO(minjie): too many states in python; should clean up and lower to C
self._nx_metagraph = None
(self._graph, metainfo, self._node_frames, self._edge_frames,
self._batch_num_nodes, self._batch_num_edges) = state
(self._ntypes, self._etypes, self._canonical_etypes,
self._srctypes_invmap, self._dsttypes_invmap,
self._is_unibipartite, self._etype2canonical,
self._etypes_invmap) = metainfo
if isinstance(state, dict):
# Since 0.5 we use the default __dict__ method
self.__dict__.update(state)
elif isinstance(state, tuple) and len(state) == 5:
# DGL == 0.4.3
dgl_warning("The object is pickled with DGL == 0.4.3. "
......@@ -337,7 +319,7 @@ class DGLHeteroGraph(object):
for i in range(len(self.ntypes))}
nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i)
for i in range(len(self.etypes))}
meta = str(self.metagraph.edges(keys=True))
meta = str(self.metagraph().edges(keys=True))
return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
def __copy__(self):
......@@ -345,20 +327,7 @@ class DGLHeteroGraph(object):
#TODO(minjie): too many states in python; should clean up and lower to C
cls = type(self)
obj = cls.__new__(cls)
obj._graph = self._graph
obj._batch_num_nodes = self._batch_num_nodes
obj._batch_num_edges = self._batch_num_edges
obj._ntypes = self._ntypes
obj._etypes = self._etypes
obj._canonical_etypes = self._canonical_etypes
obj._srctypes_invmap = self._srctypes_invmap
obj._dsttypes_invmap = self._dsttypes_invmap
obj._is_unibipartite = self._is_unibipartite
obj._etype2canonical = self._etype2canonical
obj._etypes_invmap = self._etypes_invmap
obj._nx_metagraph = self._nx_metagraph
obj._node_frames = self._node_frames
obj._edge_frames = self._edge_frames
obj.__dict__.update(self.__dict__)
return obj
#################################################################
......@@ -975,7 +944,6 @@ class DGLHeteroGraph(object):
else:
return self.ntypes
def metagraph(self):
    """Return the metagraph as networkx.MultiDiGraph.

    The metagraph has one node per node type and one edge per canonical
    edge type ``(srctype, etype, dsttype)``.

    Examples
    --------
    >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
    >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
    >>> g = dgl.hetero_from_relations([follows_g, plays_g])
    >>> meta_g = g.metagraph()

    The metagraph then has two nodes and two edges.

    >>> meta_g.number_of_edges()
    2
    """
    # Rebuilt on every call.  The cached ``self._nx_metagraph`` attribute was
    # removed by this change so that pickling/copying needs no extra state.
    nx_graph = self._graph.metagraph.to_networkx()
    nx_metagraph = nx.MultiDiGraph()
    for u_v in nx_graph.edges:
        srctype, etype, dsttype = self.canonical_etypes[nx_graph.edges[u_v]['id']]
        nx_metagraph.add_edge(srctype, dsttype, etype)
    return nx_metagraph
def to_canonical_etype(self, etype):
"""Convert edge type to canonical etype: (srctype, etype, dsttype).
......@@ -5282,7 +5249,7 @@ class DGLBlock(DGLHeteroGraph):
for ntype in self.dsttypes}
nedge_dict = {etype : self.number_of_edges(etype)
for etype in self.canonical_etypes}
meta = str(self.metagraph.edges(keys=True))
meta = str(self.metagraph().edges(keys=True))
return ret.format(
srcnode=nsrcnode_dict, dstnode=ndstnode_dict, edge=nedge_dict, meta=meta)
......
......@@ -211,7 +211,7 @@ class PinSAGESampler(RandomWalkNeighborSampler):
"""
def __init__(self, G, ntype, other_type, random_walk_length, random_walk_restart_prob,
num_random_walks, num_neighbors, weight_column='weights'):
metagraph = G.metagraph
metagraph = G.metagraph()
fw_etype = list(metagraph[ntype][other_type])[0]
bw_etype = list(metagraph[other_type][ntype])[0]
super().__init__(G, random_walk_length,
......
......@@ -33,7 +33,9 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
const int32_t* indices_ptr = static_cast<int32_t*>(indices->data);
const void* data_ptr = data->data;
NDArray t_indptr = aten::NewIdArray(csr.num_cols + 1, ctx, bits);
// (BarclayII) csr2csc doesn't seem to clear the content of cscColPtr if nnz == 0.
// We need to do it ourselves.
NDArray t_indptr = aten::Full(0, csr.num_cols + 1, bits, ctx);
NDArray t_indices = aten::NewIdArray(nnz, ctx, bits);
NDArray t_data = aten::NewIdArray(nnz, ctx, bits);
int32_t* t_indptr_ptr = static_cast<int32_t*>(t_indptr->data);
......
......@@ -224,7 +224,7 @@ def test_query(idtype):
assert set(canonical_etypes) == set(g.canonical_etypes)
# metagraph
mg = g.metagraph
mg = g.metagraph()
assert set(g.ntypes) == set(mg.nodes)
etype_triplets = [(u, v, e) for u, v, e in mg.edges(keys=True)]
assert set([
......
......@@ -34,8 +34,8 @@ def _assert_is_identical_hetero(g, g2):
assert g.canonical_etypes == g2.canonical_etypes
# check if two metagraphs are identical
for edges, features in g.metagraph.edges(keys=True).items():
assert g2.metagraph.edges(keys=True)[edges] == features
for edges, features in g.metagraph().edges(keys=True).items():
assert g2.metagraph().edges(keys=True)[edges] == features
# check if node ID spaces and feature spaces are equal
for ntype in g.ntypes:
......
......@@ -35,8 +35,8 @@ def _assert_is_identical_hetero(g, g2):
assert g.canonical_etypes == g2.canonical_etypes
# check if two metagraphs are identical
for edges, features in g.metagraph.edges(keys=True).items():
assert g2.metagraph.edges(keys=True)[edges] == features
for edges, features in g.metagraph().edges(keys=True).items():
assert g2.metagraph().edges(keys=True)[edges] == features
# check if node ID spaces and feature spaces are equal
for ntype in g.ntypes:
......@@ -89,4 +89,4 @@ def test_copy_from_gpu():
if __name__ == "__main__":
test_single_process(F.int64)
test_multi_process(F.int32)
test_copy_from_gpu()
\ No newline at end of file
test_copy_from_gpu()
......@@ -17,8 +17,8 @@ def check_graph_equal(g1, g2, *,
assert g1.batch_size == g2.batch_size
# check if two metagraphs are identical
for edges, features in g1.metagraph.edges(keys=True).items():
assert g2.metagraph.edges(keys=True)[edges] == features
for edges, features in g1.metagraph().edges(keys=True).items():
assert g2.metagraph().edges(keys=True)[edges] == features
for nty in g1.ntypes:
assert g1.number_of_nodes(nty) == g2.number_of_nodes(nty)
......
......@@ -234,7 +234,7 @@ def plot_graph(nxg):
ag.layout('dot')
ag.draw('graph.png')
plot_graph(G.metagraph)
plot_graph(G.metagraph())
###############################################################################
# Learning tasks associated with heterographs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment