"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "05b0f1ea2f9dc012dbc19deabca7fa653db9a1ac"
Unverified Commit cbbbbde7 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] fix multiple bugs in JTNN example (#2220)

* [Bug] fix multiple bugs in JTNN example

* remove debug code
parent d628f5a2
...@@ -195,6 +195,7 @@ class DGLJTMPN(nn.Module): ...@@ -195,6 +195,7 @@ class DGLJTMPN(nn.Module):
cand_graphs.apply_edges( cand_graphs.apply_edges(
func=lambda edges: {'src_x': edges.src['x']}, func=lambda edges: {'src_x': edges.src['x']},
) )
cand_line_graph.ndata.update(cand_graphs.edata)
bond_features = cand_line_graph.ndata['x'] bond_features = cand_line_graph.ndata['x']
source_features = cand_line_graph.ndata['src_x'] source_features = cand_line_graph.ndata['src_x']
......
...@@ -143,7 +143,7 @@ class DGLJTNNDecoder(nn.Module): ...@@ -143,7 +143,7 @@ class DGLJTNNDecoder(nn.Module):
# Predict root # Predict root
mol_tree_batch.pull(root_ids, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h')) mol_tree_batch.pull(root_ids, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
mol_tree_batch.apply_nodes(dec_tree_node_update) mol_tree_batch.apply_nodes(dec_tree_node_update, v=root_ids)
# Extract hidden states and store them for stop/label prediction # Extract hidden states and store them for stop/label prediction
h = mol_tree_batch.nodes[root_ids].data['h'] h = mol_tree_batch.nodes[root_ids].data['h']
x = mol_tree_batch.nodes[root_ids].data['x'] x = mol_tree_batch.nodes[root_ids].data['x']
...@@ -170,12 +170,12 @@ class DGLJTNNDecoder(nn.Module): ...@@ -170,12 +170,12 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata) mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's')) mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm')) mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update) mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update, v=eid)
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata) mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
is_new = mol_tree_batch.nodes[v].data['new'] is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h')) mol_tree_batch.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
mol_tree_batch.apply_nodes(dec_tree_node_update) mol_tree_batch.apply_nodes(dec_tree_node_update, v=v)
# Extract # Extract
n_repr = mol_tree_batch.nodes[v].data n_repr = mol_tree_batch.nodes[v].data
...@@ -262,7 +262,6 @@ class DGLJTNNDecoder(nn.Module): ...@@ -262,7 +262,6 @@ class DGLJTNNDecoder(nn.Module):
# Predict stop # Predict stop
p_input = torch.cat([x, h, mol_vec], 1) p_input = torch.cat([x, h, mol_vec], 1)
p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input)))) p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
p_score[:] = 0
backtrack = (p_score.item() < 0.5) backtrack = (p_score.item() < 0.5)
if not backtrack: if not backtrack:
...@@ -302,7 +301,7 @@ class DGLJTNNDecoder(nn.Module): ...@@ -302,7 +301,7 @@ class DGLJTNNDecoder(nn.Module):
uv, uv,
DGLF.copy_u('rm', 'rm'), DGLF.copy_u('rm', 'rm'),
DGLF.sum('rm', 'accum_rm')) DGLF.sum('rm', 'accum_rm'))
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update.update_zm) mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update.update_zm, v=uv)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata) mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h')) mol_tree_graph.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
...@@ -358,7 +357,7 @@ class DGLJTNNDecoder(nn.Module): ...@@ -358,7 +357,7 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's')) mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm')) mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update) mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update, v=u_pu)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata) mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(pu, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h')) mol_tree_graph.pull(pu, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
stack.pop() stack.pop()
......
...@@ -99,8 +99,8 @@ class DGLJTNNEncoder(nn.Module): ...@@ -99,8 +99,8 @@ class DGLJTNNEncoder(nn.Module):
for eid in level_order(mol_tree_batch, root_ids): for eid in level_order(mol_tree_batch, root_ids):
eid = eid.to(mol_tree_batch_lg.device) eid = eid.to(mol_tree_batch_lg.device)
mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's')) mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'rm')) mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_batch_lg.apply_nodes(self.enc_tree_update) mol_tree_batch_lg.apply_nodes(self.enc_tree_update, v=eid)
# Readout # Readout
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata) mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
......
...@@ -136,6 +136,7 @@ class DGLMPN(nn.Module): ...@@ -136,6 +136,7 @@ class DGLMPN(nn.Module):
mol_graph.apply_edges( mol_graph.apply_edges(
func=lambda edges: {'src_x': edges.src['x']}, func=lambda edges: {'src_x': edges.src['x']},
) )
mol_line_graph.ndata.update(mol_graph.edata)
e_repr = mol_line_graph.ndata e_repr = mol_line_graph.ndata
bond_features = e_repr['x'] bond_features = e_repr['x']
......
...@@ -70,7 +70,7 @@ def train(): ...@@ -70,7 +70,7 @@ def train():
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
num_workers=0, num_workers=4,
collate_fn=JTNNCollator(vocab, True), collate_fn=JTNNCollator(vocab, True),
drop_last=True, drop_last=True,
worker_init_fn=worker_init_fn) worker_init_fn=worker_init_fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment