Unverified Commit cbbbbde7 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] fix multiple bugs in JTNN example (#2220)

* [Bug] fix multiple bugs in JTNN example

* remove debug code
parent d628f5a2
......@@ -195,6 +195,7 @@ class DGLJTMPN(nn.Module):
cand_graphs.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
)
cand_line_graph.ndata.update(cand_graphs.edata)
bond_features = cand_line_graph.ndata['x']
source_features = cand_line_graph.ndata['src_x']
......
......@@ -143,7 +143,7 @@ class DGLJTNNDecoder(nn.Module):
# Predict root
mol_tree_batch.pull(root_ids, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
mol_tree_batch.apply_nodes(dec_tree_node_update)
mol_tree_batch.apply_nodes(dec_tree_node_update, v=root_ids)
# Extract hidden states and store them for stop/label prediction
h = mol_tree_batch.nodes[root_ids].data['h']
x = mol_tree_batch.nodes[root_ids].data['x']
......@@ -170,12 +170,12 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update)
mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update, v=eid)
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
mol_tree_batch.apply_nodes(dec_tree_node_update)
mol_tree_batch.apply_nodes(dec_tree_node_update, v=v)
# Extract
n_repr = mol_tree_batch.nodes[v].data
......@@ -262,7 +262,6 @@ class DGLJTNNDecoder(nn.Module):
# Predict stop
p_input = torch.cat([x, h, mol_vec], 1)
p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
p_score[:] = 0
backtrack = (p_score.item() < 0.5)
if not backtrack:
......@@ -302,7 +301,7 @@ class DGLJTNNDecoder(nn.Module):
uv,
DGLF.copy_u('rm', 'rm'),
DGLF.sum('rm', 'accum_rm'))
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update.update_zm)
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update.update_zm, v=uv)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
......@@ -358,7 +357,7 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update)
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update, v=u_pu)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(pu, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
stack.pop()
......
......@@ -99,8 +99,8 @@ class DGLJTNNEncoder(nn.Module):
for eid in level_order(mol_tree_batch, root_ids):
eid = eid.to(mol_tree_batch_lg.device)
mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'rm'))
mol_tree_batch_lg.apply_nodes(self.enc_tree_update)
mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_batch_lg.apply_nodes(self.enc_tree_update, v=eid)
# Readout
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
......
......@@ -136,6 +136,7 @@ class DGLMPN(nn.Module):
mol_graph.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
)
mol_line_graph.ndata.update(mol_graph.edata)
e_repr = mol_line_graph.ndata
bond_features = e_repr['x']
......
......@@ -70,7 +70,7 @@ def train():
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
num_workers=4,
collate_fn=JTNNCollator(vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment