Commit 12d70630 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by Minjie Wang
Browse files

[Hotfix] fixing zero shaped tensor problems for PyTorch 1.0.0 in JTNN example (#371)

parent dedfd908
...@@ -9,6 +9,8 @@ from .mol_tree import Vocab ...@@ -9,6 +9,8 @@ from .mol_tree import Vocab
from .mpn import mol2dgl_single as mol2dgl_enc from .mpn import mol2dgl_single as mol2dgl_enc
from .jtmpn import mol2dgl_single as mol2dgl_dec from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtmpn import ATOM_FDIM as ATOM_FDIM_DEC
from .jtmpn import BOND_FDIM as BOND_FDIM_DEC
_url = 'https://www.dropbox.com/s/4ypr0e0abcbsvoh/jtnn.zip?dl=1' _url = 'https://www.dropbox.com/s/4ypr0e0abcbsvoh/jtnn.zip?dl=1'
...@@ -82,11 +84,11 @@ class JTNNDataset(Dataset): ...@@ -82,11 +84,11 @@ class JTNNDataset(Dataset):
tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands) tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
else: else:
cand_graphs = [] cand_graphs = []
atom_x_dec = torch.zeros(0, atom_x_enc.shape[1]) atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
bond_x_dec = torch.zeros(0, bond_x_enc.shape[1]) bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
tree_mess_src_e = torch.zeros(0, 2).long() tree_mess_src_e = torch.zeros(0, 2).long()
tree_mess_tgt_e = torch.zeros(0, 2).long() tree_mess_tgt_e = torch.zeros(0, 2).long()
tree_mess_tgt_n = torch.zeros(0, 2).long() tree_mess_tgt_n = torch.zeros(0).long()
# prebuild the stereoisomers # prebuild the stereoisomers
cands = mol_tree.stereo_cands cands = mol_tree.stereo_cands
......
...@@ -199,8 +199,12 @@ class DGLJTNNDecoder(nn.Module): ...@@ -199,8 +199,12 @@ class DGLJTNNDecoder(nn.Module):
p_inputs.append(torch.cat([x, h, tree_vec_set], 1)) p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
# Only newly generated nodes are needed for label prediction # Only newly generated nodes are needed for label prediction
# NOTE: The following works since the uncomputed messages are zeros. # NOTE: The following works since the uncomputed messages are zeros.
q_inputs.append(torch.cat([h, tree_vec_set], 1)[is_new])
q_targets.append(wid[is_new]) q_input = torch.cat([h, tree_vec_set], 1)[is_new]
q_target = wid[is_new]
if q_input.shape[0] > 0:
q_inputs.append(q_input)
q_targets.append(q_target)
p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long()) p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long())
# Batch compute the stop/label prediction losses # Batch compute the stop/label prediction losses
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment