"tests/git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "ba3036d196a1c700bf1ee91537b90d0b66c3e724"
Unverified commit 704bcaf6 authored by Hongzhi (Steve), Chen, committed by GitHub
parent 6bc82161
@@ -7,12 +7,12 @@ import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from dgl.data import CiteseerGraphDataset, CoraGraphDataset
from model import JKNet
from sklearn.model_selection import train_test_split
from tqdm import trange


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
...
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv, JumpingKnowledge
...
import dgl
import torch
from dgl.data.utils import (
    _get_dgl_url,
    download,
    extract_archive,
    get_download_dir,
)
from .jtmpn import (
    ATOM_FDIM as ATOM_FDIM_DEC,
    BOND_FDIM as BOND_FDIM_DEC,
    mol2dgl_single as mol2dgl_dec,
)
from .mol_tree import Vocab
from .mol_tree_nx import DGLMolTree
from .mpn import mol2dgl_single as mol2dgl_enc
...
import os

import dgl
import dgl.function as DGLF
import rdkit.Chem as Chem
import torch
import torch.nn as nn
from dgl import line_graph, mean_nodes

from .nnutils import cuda
ELEM_LIST = [
    "C",
    "N",
    "O",
    "S",
    "F",
    "Si",
    "P",
    "Cl",
    "Br",
    "Mg",
    "Na",
    "Ca",
    "Fe",
    "Al",
    "I",
    "B",
    "K",
    "Se",
    "Zn",
    "H",
    "Cu",
    "Mn",
    "unknown",
]

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
MAX_NB = 10
PAPER = os.getenv("PAPER", False)
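# (Clarifying note added for this write-up, not part of the commit:)
# ELEM_LIST holds 23 symbols, so ATOM_FDIM = 23 + 6 (degree one-hot)
# + 5 (formal-charge one-hot) + 1 (aromatic flag) = 35, matching
# atom_features() below; BOND_FDIM = 5 matches the five entries built
# in bond_features().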
def onek_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]
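# Quick illustration added for clarity (not part of the commit): symbols
# outside the allowable set fall into the trailing "unknown" slot.
assert onek_encoding_unk("C", ["C", "N", "unknown"]) == [True, False, False]
assert onek_encoding_unk("Xe", ["C", "N", "unknown"]) == [False, False, True]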
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
def atom_features(atom):
    return torch.Tensor(
        onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
        + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
        + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
        + [atom.GetIsAromatic()]
    )
def bond_features(bond):
    bt = bond.GetBondType()
    return torch.Tensor(
        [
            bt == Chem.rdchem.BondType.SINGLE,
            bt == Chem.rdchem.BondType.DOUBLE,
            bt == Chem.rdchem.BondType.TRIPLE,
            bt == Chem.rdchem.BondType.AROMATIC,
            bond.IsInRing(),
        ]
    )
def mol2dgl_single(cand_batch):
    cand_graphs = []
    tree_mess_source_edges = []  # map these edges from trees to...
    tree_mess_target_edges = []  # these edges on candidate graphs
    tree_mess_target_nodes = []
    n_nodes = 0
    n_edges = 0
@@ -50,8 +89,8 @@ def mol2dgl_single(cand_batch):
        n_bonds = mol.GetNumBonds()

        ctr_node = mol_tree.nodes_dict[ctr_node_id]
        ctr_bid = ctr_node["idx"]
        mol_tree_graph = getattr(mol_tree, "graph", mol_tree)

        for i, atom in enumerate(mol.GetAtoms()):
            assert i == atom.GetIdx()
@@ -75,15 +114,19 @@ def mol2dgl_single(cand_batch):
            x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
            # Tree node ID in the batch
            x_bid = mol_tree.nodes_dict[x_nid - 1]["idx"] if x_nid > 0 else -1
            y_bid = mol_tree.nodes_dict[y_nid - 1]["idx"] if y_nid > 0 else -1
            if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
                if mol_tree_graph.has_edges_between(x_bid, y_bid):
                    tree_mess_target_edges.append(
                        (begin_idx + n_nodes, end_idx + n_nodes)
                    )
                    tree_mess_source_edges.append((x_bid, y_bid))
                    tree_mess_target_nodes.append(end_idx + n_nodes)
                if mol_tree_graph.has_edges_between(y_bid, x_bid):
                    tree_mess_target_edges.append(
                        (end_idx + n_nodes, begin_idx + n_nodes)
                    )
                    tree_mess_source_edges.append((y_bid, x_bid))
                    tree_mess_target_nodes.append(begin_idx + n_nodes)
@@ -91,11 +134,14 @@ def mol2dgl_single(cand_batch):
        g = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
        cand_graphs.append(g)

    return (
        cand_graphs,
        torch.stack(atom_x),
        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0),
        torch.LongTensor(tree_mess_source_edges),
        torch.LongTensor(tree_mess_target_edges),
        torch.LongTensor(tree_mess_target_nodes),
    )
class LoopyBPUpdate(nn.Module):
@@ -106,28 +152,28 @@ class LoopyBPUpdate(nn.Module):
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, node):
        msg_input = node.data["msg_input"]
        msg_delta = self.W_h(node.data["accum_msg"] + node.data["alpha"])
        msg = torch.relu(msg_input + msg_delta)
        return {"msg": msg}
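# (Clarifying note added for this write-up, not part of the commit:) one
# loopy-BP step computes msg = ReLU(msg_input + W_h(accum_msg + alpha)),
# i.e. the usual MPN edge-message update with the junction-tree messages
# (alpha) injected into the accumulated incoming messages.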
if PAPER:
    mpn_gather_msg = [
        DGLF.copy_e(edge="msg", out="msg"),
        DGLF.copy_e(edge="alpha", out="alpha"),
    ]
else:
    mpn_gather_msg = DGLF.copy_e(edge="msg", out="msg")


if PAPER:
    mpn_gather_reduce = [
        DGLF.sum(msg="msg", out="m"),
        DGLF.sum(msg="alpha", out="accum_alpha"),
    ]
else:
    mpn_gather_reduce = DGLF.sum(msg="msg", out="m")
class GatherUpdate(nn.Module):
@@ -139,12 +185,12 @@ class GatherUpdate(nn.Module):
    def forward(self, node):
        if PAPER:
            # m = node['m']
            m = node.data["m"] + node.data["accum_alpha"]
        else:
            m = node.data["m"] + node.data["alpha"]
        return {
            "h": torch.relu(self.W_o(torch.cat([node.data["x"], m], 1))),
        }
@@ -166,20 +212,32 @@ class DGLJTMPN(nn.Module):
        self.n_passes = 0

    def forward(self, cand_batch, mol_tree_batch):
        (
            cand_graphs,
            tree_mess_src_edges,
            tree_mess_tgt_edges,
            tree_mess_tgt_nodes,
        ) = cand_batch
        n_samples = len(cand_graphs)

        cand_line_graph = line_graph(
            cand_graphs, backtracking=False, shared=True
        )

        n_nodes = cand_graphs.number_of_nodes()
        n_edges = cand_graphs.number_of_edges()

        cand_graphs = self.run(
            cand_graphs,
            cand_line_graph,
            tree_mess_src_edges,
            tree_mess_tgt_edges,
            tree_mess_tgt_nodes,
            mol_tree_batch,
        )

        g_repr = mean_nodes(cand_graphs, "h")

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
@@ -188,33 +246,45 @@ class DGLJTMPN(nn.Module):
        return g_repr
    def run(
        self,
        cand_graphs,
        cand_line_graph,
        tree_mess_src_edges,
        tree_mess_tgt_edges,
        tree_mess_tgt_nodes,
        mol_tree_batch,
    ):
        n_nodes = cand_graphs.number_of_nodes()

        cand_graphs.apply_edges(
            func=lambda edges: {"src_x": edges.src["x"]},
        )

        cand_line_graph.ndata.update(cand_graphs.edata)

        bond_features = cand_line_graph.ndata["x"]
        source_features = cand_line_graph.ndata["src_x"]
        features = torch.cat([source_features, bond_features], 1)
        msg_input = self.W_i(features)
        cand_line_graph.ndata.update(
            {
                "msg_input": msg_input,
                "msg": torch.relu(msg_input),
                "accum_msg": torch.zeros_like(msg_input),
            }
        )
        zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
        cand_graphs.ndata.update(
            {
                "m": zero_node_state.clone(),
                "h": zero_node_state.clone(),
            }
        )

        cand_graphs.edata["alpha"] = cuda(
            torch.zeros(cand_graphs.number_of_edges(), self.hidden_size)
        )
        cand_graphs.ndata["alpha"] = zero_node_state

        if tree_mess_src_edges.shape[0] > 0:
            if PAPER:
                src_u, src_v = tree_mess_src_edges.unbind(1)
@@ -222,33 +292,39 @@ class DGLJTMPN(nn.Module):
                src_u = src_u.to(mol_tree_batch.device)
                src_v = src_v.to(mol_tree_batch.device)
                eid = mol_tree_batch.edge_ids(src_u, src_v)
                alpha = mol_tree_batch.edata["m"][eid]
                cand_graphs.edges[tgt_u, tgt_v].data["alpha"] = alpha
            else:
                src_u, src_v = tree_mess_src_edges.unbind(1)
                src_u = src_u.to(mol_tree_batch.device)
                src_v = src_v.to(mol_tree_batch.device)
                eid = mol_tree_batch.edge_ids(src_u, src_v)
                alpha = mol_tree_batch.edata["m"][eid]
                node_idx = tree_mess_tgt_nodes.to(
                    device=zero_node_state.device
                )[:, None].expand_as(alpha)
                node_alpha = zero_node_state.clone().scatter_add(
                    0, node_idx, alpha
                )
                cand_graphs.ndata["alpha"] = node_alpha

                cand_graphs.apply_edges(
                    func=lambda edges: {"alpha": edges.src["alpha"]},
                )

        cand_line_graph.ndata.update(cand_graphs.edata)
        for i in range(self.depth - 1):
            cand_line_graph.update_all(
                DGLF.copy_u("msg", "msg"), DGLF.sum("msg", "accum_msg")
            )
            cand_line_graph.apply_nodes(self.loopy_bp_updater)

        cand_graphs.edata.update(cand_line_graph.ndata)

        cand_graphs.update_all(DGLF.copy_e("msg", "msg"), DGLF.sum("msg", "m"))
        if PAPER:
            cand_graphs.update_all(
                DGLF.copy_e("alpha", "alpha"), DGLF.sum("alpha", "accum_alpha")
            )
        cand_graphs.apply_nodes(self.gather_updater)

        return cand_graphs
import dgl.function as DGLF
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import batch, dfs_labeled_edges_generator, line_graph

from .chemutils import enum_assemble_nx, get_mol
from .mol_tree_nx import DGLMolTree
from .nnutils import cuda, GRUUpdate, tocpu

MAX_NB = 8
MAX_DECODE_LEN = 100
@@ -21,60 +22,70 @@ def dfs_order(forest, roots):
    # using find_edges().
    yield e ^ l, l


dec_tree_node_msg = DGLF.copy_e(edge="m", out="m")
dec_tree_node_reduce = DGLF.sum(msg="m", out="h")


def dec_tree_node_update(nodes):
    return {"new": nodes.data["new"].clone().zero_()}
def have_slots(fa_slots, ch_slots):
    if len(fa_slots) > 2 and len(ch_slots) > 2:
        return True
    matches = []
    for i, s1 in enumerate(fa_slots):
        a1, c1, h1 = s1
        for j, s2 in enumerate(ch_slots):
            a2, c2, h2 = s2
            if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
                matches.append((i, j))

    if len(matches) == 0:
        return False

    fa_match, ch_match = list(zip(*matches))
    if (
        len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2
    ):  # never remove atom from ring
        fa_slots.pop(fa_match[0])
    if (
        len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2
    ):  # never remove atom from ring
        ch_slots.pop(ch_match[0])

    return True
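# (Clarifying note added for this write-up, not part of the commit:) each
# slot here is assumed to be an (atom symbol, formal charge, hydrogen count)
# triple; two slots match when symbol and charge agree, with the extra
# h1 + h2 >= 4 check guarding attachment on carbon.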
def can_assemble(mol_tree, u, v_node_dict):
    u_node_dict = mol_tree.nodes_dict[u]
    u_neighbors = mol_tree.graph.successors(u)
    u_neighbors_node_dict = [
        mol_tree.nodes_dict[_u]
        for _u in u_neighbors
        if _u in mol_tree.nodes_dict
    ]
    neis = u_neighbors_node_dict + [v_node_dict]
    for i, nei in enumerate(neis):
        nei["nid"] = i

    neighbors = [nei for nei in neis if nei["mol"].GetNumAtoms() > 1]
    neighbors = sorted(
        neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
    )
    singletons = [nei for nei in neis if nei["mol"].GetNumAtoms() == 1]
    neighbors = singletons + neighbors
    cands = enum_assemble_nx(u_node_dict, neighbors)
    return len(cands) > 0
def create_node_dict(smiles, clique=[]):
    return dict(
        smiles=smiles,
        mol=get_mol(smiles),
        clique=clique,
    )
class DGLJTNNDecoder(nn.Module):
@@ -98,41 +109,54 @@ class DGLJTNNDecoder(nn.Module):
        self.U_s = nn.Linear(hidden_size, 1)

    def forward(self, mol_trees, tree_vec):
        """
        The training procedure which computes the prediction loss given the
        ground truth tree
        """
        mol_tree_batch = batch(mol_trees)
        mol_tree_batch_lg = line_graph(
            mol_tree_batch, backtracking=False, shared=True
        )
        n_trees = len(mol_trees)

        return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
    def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
        node_offset = np.cumsum(
            np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0)
        )
        root_ids = node_offset[:-1]
        n_nodes = mol_tree_batch.number_of_nodes()
        n_edges = mol_tree_batch.number_of_edges()

        mol_tree_batch.ndata.update(
            {
                "x": self.embedding(mol_tree_batch.ndata["wid"]),
                "h": cuda(torch.zeros(n_nodes, self.hidden_size)),
                "new": cuda(
                    torch.ones(n_nodes).bool()
                ),  # whether it's newly generated node
            }
        )

        mol_tree_batch.edata.update(
            {
                "s": cuda(torch.zeros(n_edges, self.hidden_size)),
                "m": cuda(torch.zeros(n_edges, self.hidden_size)),
                "r": cuda(torch.zeros(n_edges, self.hidden_size)),
                "z": cuda(torch.zeros(n_edges, self.hidden_size)),
                "src_x": cuda(torch.zeros(n_edges, self.hidden_size)),
                "dst_x": cuda(torch.zeros(n_edges, self.hidden_size)),
                "rm": cuda(torch.zeros(n_edges, self.hidden_size)),
                "accum_rm": cuda(torch.zeros(n_edges, self.hidden_size)),
            }
        )

        mol_tree_batch.apply_edges(
            func=lambda edges: {
                "src_x": edges.src["x"],
                "dst_x": edges.dst["x"],
            },
        )

        # input tensors for stop prediction (p) and label prediction (q)
@@ -142,16 +166,16 @@ class DGLJTNNDecoder(nn.Module):
        q_targets = []

        # Predict root
        mol_tree_batch.pull(root_ids, DGLF.copy_e("m", "m"), DGLF.sum("m", "h"))
        mol_tree_batch.apply_nodes(dec_tree_node_update, v=root_ids)

        # Extract hidden states and store them for stop/label prediction
        h = mol_tree_batch.nodes[root_ids].data["h"]
        x = mol_tree_batch.nodes[root_ids].data["x"]
        p_inputs.append(torch.cat([x, h, tree_vec], 1))
        # If the out degree is 0 we don't generate any edges at all
        root_out_degrees = mol_tree_batch.out_degrees(root_ids)
        q_inputs.append(torch.cat([h, tree_vec], 1))
        q_targets.append(mol_tree_batch.nodes[root_ids].data["wid"])
        # Traverse the tree and predict on children
        for eid, p in dfs_order(mol_tree_batch, root_ids):
@@ -160,29 +184,35 @@ class DGLJTNNDecoder(nn.Module):
            u, v = mol_tree_batch.find_edges(eid)

            p_target_list = torch.zeros_like(root_out_degrees)
            p_target_list[root_out_degrees > 0] = 1 - p
            p_target_list = p_target_list[root_out_degrees >= 0]
            p_targets.append(torch.tensor(p_target_list))

            root_out_degrees -= (root_out_degrees == 0).long()
            root_out_degrees -= torch.tensor(
                np.isin(root_ids, v.cpu().numpy())
            ).to(root_out_degrees)

            mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
            mol_tree_batch_lg.pull(
                eid, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
            )
            mol_tree_batch_lg.pull(
                eid, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
            )
            mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update, v=eid)
            mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)

            is_new = mol_tree_batch.nodes[v].data["new"]
            mol_tree_batch.pull(v, DGLF.copy_e("m", "m"), DGLF.sum("m", "h"))
            mol_tree_batch.apply_nodes(dec_tree_node_update, v=v)

            # Extract
            n_repr = mol_tree_batch.nodes[v].data
            h = n_repr["h"]
            x = n_repr["x"]
            tree_vec_set = tree_vec[root_out_degrees >= 0]
            wid = n_repr["wid"]
            p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
            # Only newly generated nodes are needed for label prediction
            # NOTE: The following works since the uncomputed messages are zeros.
@@ -192,10 +222,13 @@ class DGLJTNNDecoder(nn.Module):
            if q_input.shape[0] > 0:
                q_inputs.append(q_input)
                q_targets.append(q_target)
        p_targets.append(
            torch.zeros(
                (root_out_degrees == 0).sum(),
                device=root_out_degrees.device,
                dtype=torch.int64,
            )
        )

        # Batch compute the stop/label prediction losses
        p_inputs = torch.cat(p_inputs, 0)
@@ -206,9 +239,12 @@ class DGLJTNNDecoder(nn.Module):
        q = self.W_o(torch.relu(self.W(q_inputs)))
        p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]

        p_loss = (
            F.binary_cross_entropy_with_logits(
                p, p_targets.float(), size_average=False
            )
            / n_trees
        )
        q_loss = F.cross_entropy(q, q_targets, size_average=False) / n_trees
        p_acc = ((p > 0).long() == p_targets).sum().float() / p_targets.shape[0]
        q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]
@@ -237,13 +273,14 @@ class DGLJTNNDecoder(nn.Module):
        _, root_wid = torch.max(root_score, 1)
        root_wid = root_wid.view(1)

        mol_tree_graph.add_nodes(1)  # root
        mol_tree_graph.ndata["wid"] = root_wid
        mol_tree_graph.ndata["x"] = self.embedding(root_wid)
        mol_tree_graph.ndata["h"] = init_hidden
        mol_tree_graph.ndata["fail"] = cuda(torch.tensor([0]))
        mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
            self.vocab.get_smiles(root_wid)
        )

        stack, trace = [], []
        stack.append((0, self.vocab.get_slots(root_wid)))
@@ -256,13 +293,13 @@ class DGLJTNNDecoder(nn.Module):
        for step in range(MAX_DECODE_LEN):
            u, u_slots = stack[-1]
            x = mol_tree_graph.ndata["x"][u : u + 1]
            h = mol_tree_graph.ndata["h"][u : u + 1]

            # Predict stop
            p_input = torch.cat([x, h, mol_vec], 1)
            p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
            backtrack = p_score.item() < 0.5
            if not backtrack:
                # Predict next clique. Note that the prediction may fail due
@@ -273,49 +310,61 @@ class DGLJTNNDecoder(nn.Module):
                mol_tree_graph.add_edges(u, v)
                uv = new_edge_id
                new_edge_id += 1

                if first:
                    mol_tree_graph.edata.update(
                        {
                            "s": cuda(torch.zeros(1, self.hidden_size)),
                            "m": cuda(torch.zeros(1, self.hidden_size)),
                            "r": cuda(torch.zeros(1, self.hidden_size)),
                            "z": cuda(torch.zeros(1, self.hidden_size)),
                            "src_x": cuda(torch.zeros(1, self.hidden_size)),
                            "dst_x": cuda(torch.zeros(1, self.hidden_size)),
                            "rm": cuda(torch.zeros(1, self.hidden_size)),
                            "accum_rm": cuda(torch.zeros(1, self.hidden_size)),
                        }
                    )
                    first = False

                mol_tree_graph.edata["src_x"][uv] = mol_tree_graph.ndata["x"][u]
                # keeping dst_x 0 is fine as h on new edge doesn't depend on that.

                # DGL doesn't dynamically maintain a line graph.
                mol_tree_graph_lg = line_graph(
                    mol_tree_graph, backtracking=False, shared=True
                )
                mol_tree_graph_lg.pull(
                    uv, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
                )
                mol_tree_graph_lg.pull(
                    uv, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
                )
                mol_tree_graph_lg.apply_nodes(
                    self.dec_tree_edge_update.update_zm, v=uv
                )
                mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
                mol_tree_graph.pull(
                    v, DGLF.copy_e("m", "m"), DGLF.sum("m", "h")
                )

                h_v = mol_tree_graph.ndata["h"][v : v + 1]
                q_input = torch.cat([h_v, mol_vec], 1)
                q_score = torch.softmax(
                    self.W_o(torch.relu(self.W(q_input))), -1
                )
                _, sort_wid = torch.sort(q_score, 1, descending=True)
                sort_wid = sort_wid.squeeze()

                next_wid = None
                for wid in sort_wid.tolist()[:5]:
                    slots = self.vocab.get_slots(wid)
                    cand_node_dict = create_node_dict(
                        self.vocab.get_smiles(wid)
                    )
                    if have_slots(u_slots, slots) and can_assemble(
                        mol_tree, u, cand_node_dict
                    ):
                        next_wid = wid
                        next_slots = slots
                        next_node_dict = cand_node_dict
@@ -324,44 +373,59 @@ class DGLJTNNDecoder(nn.Module):
                if next_wid is None:
                    # Failed adding an actual children; v is a spurious node
                    # and we mark it.
                    mol_tree_graph.ndata["fail"][v] = cuda(torch.tensor([1]))
                    backtrack = True
                else:
                    next_wid = cuda(torch.tensor([next_wid]))
                    mol_tree_graph.ndata["wid"][v] = next_wid
                    mol_tree_graph.ndata["x"][v] = self.embedding(next_wid)
                    mol_tree.nodes_dict[v] = next_node_dict
                    all_nodes[v] = next_node_dict
                    stack.append((v, next_slots))
                    mol_tree_graph.add_edges(v, u)
                    vu = new_edge_id
                    new_edge_id += 1
                    mol_tree_graph.edata["dst_x"][uv] = mol_tree_graph.ndata[
                        "x"
                    ][v]
                    mol_tree_graph.edata["src_x"][vu] = mol_tree_graph.ndata[
                        "x"
                    ][v]
                    mol_tree_graph.edata["dst_x"][vu] = mol_tree_graph.ndata[
                        "x"
                    ][u]

                    # DGL doesn't dynamically maintain a line graph.
                    mol_tree_graph_lg = line_graph(
                        mol_tree_graph, backtracking=False, shared=True
                    )
                    mol_tree_graph_lg.apply_nodes(
                        self.dec_tree_edge_update.update_r, uv
                    )
                    mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)

            if backtrack:
                if len(stack) == 1:
                    break  # At root, terminate

                pu, _ = stack[-2]
                u_pu = mol_tree_graph.edge_ids(u, pu)

                mol_tree_graph_lg.pull(
                    u_pu, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
                )
                mol_tree_graph_lg.pull(
                    u_pu, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
                )
                mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update, v=u_pu)
                mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
                mol_tree_graph.pull(
                    pu, DGLF.copy_e("m", "m"), DGLF.sum("m", "h")
                )
                stack.pop()

        effective_nodes = mol_tree_graph.filter_nodes(
            lambda nodes: nodes.data["fail"] != 1
        )
        effective_nodes, _ = torch.sort(effective_nodes)
        return mol_tree, all_nodes, effective_nodes
import dgl.function as DGLF
import numpy as np
import torch
import torch.nn as nn
from dgl import batch, bfs_edges_generator, line_graph

from .nnutils import cuda, GRUUpdate, tocpu

MAX_NB = 8
...
@@ -14,12 +14,10 @@ from .chemutils import (
    enum_assemble_nx,
    set_atommap,
)
from .jtmpn import DGLJTMPN, mol2dgl_single as mol2dgl_dec
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .mpn import DGLMPN, mol2dgl_single as mol2dgl_enc
from .nnutils import cuda
...
import dgl
import numpy as np
import rdkit.Chem as Chem

from .chemutils import (
    decode_stereo,
    enum_assemble_nx,
...
import dgl
import dgl.function as DGLF
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import line_graph, mean_nodes

from .chemutils import get_mol
...
import os

import dgl
import torch
import torch.nn as nn


def cuda(x):
    if torch.cuda.is_available() and not os.getenv("NOCUDA", None):
...
import argparse

import dgl
import torch
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.nn import LabelPropagation
...
@@ -20,18 +20,18 @@
import warnings
from time import time

import dgl
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as ss
import torch
from dgl import function as fn
from lda_model import LatentDirichletAllocation as LDAModel
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import LatentDirichletAllocation, NMF
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

n_samples = 2000
n_features = 1000
n_components = 10
...
@@ -23,12 +23,12 @@ import io
import os
import warnings

import dgl
import numpy as np
import scipy as sp
import torch

try:
    from functools import cached_property
except ImportError:
...
import copy
import itertools

import dgl
import dgl.function as fn
import networkx as nx
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class GNNModule(nn.Module):
    def __init__(self, in_feats, out_feats, radius):
def __init__(self, in_feats, out_feats, radius): def __init__(self, in_feats, out_feats, radius):
...@@ -15,14 +17,22 @@ class GNNModule(nn.Module): ...@@ -15,14 +17,22 @@ class GNNModule(nn.Module):
self.radius = radius self.radius = radius
new_linear = lambda: nn.Linear(in_feats, out_feats) new_linear = lambda: nn.Linear(in_feats, out_feats)
new_linear_list = lambda: nn.ModuleList([new_linear() for i in range(radius)]) new_linear_list = lambda: nn.ModuleList(
[new_linear() for i in range(radius)]
)
self.theta_x, self.theta_deg, self.theta_y = \ self.theta_x, self.theta_deg, self.theta_y = (
new_linear(), new_linear(), new_linear() new_linear(),
new_linear(),
new_linear(),
)
self.theta_list = new_linear_list() self.theta_list = new_linear_list()
self.gamma_y, self.gamma_deg, self.gamma_x = \ self.gamma_y, self.gamma_deg, self.gamma_x = (
new_linear(), new_linear(), new_linear() new_linear(),
new_linear(),
new_linear(),
)
self.gamma_list = new_linear_list() self.gamma_list = new_linear_list()
self.bn_x = nn.BatchNorm1d(out_feats) self.bn_x = nn.BatchNorm1d(out_feats)
@@ -30,43 +40,61 @@ class GNNModule(nn.Module):
    def aggregate(self, g, z):
        z_list = []
        g.ndata["z"] = z
        g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
        z_list.append(g.ndata["z"])
        for i in range(self.radius - 1):
            for j in range(2**i):
                g.update_all(
                    fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z")
                )
            z_list.append(g.ndata["z"])
        return z_list
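    # (Clarifying note added for this write-up, not part of the commit:)
    # each update_all above is one sparse adjacency multiplication, and the
    # inner 2**i loop doubles the walk length, so z_list[k] aggregates a
    # neighborhood of roughly 2**k hops -- the multi-scale features that
    # forward() combines through theta_list/gamma_list.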
    def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
        pmpd_x = F.embedding(pm_pd, x)

        sum_x = sum(
            theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))
        )

        g.edata["y"] = y
        g.update_all(fn.copy_e(e="y", out="m"), fn.sum("m", "pmpd_y"))
        pmpd_y = g.ndata.pop("pmpd_y")

        x = (
            self.theta_x(x)
            + self.theta_deg(deg_g * x)
            + sum_x
            + self.theta_y(pmpd_y)
        )
        n = self.out_feats // 2
        x = th.cat([x[:, :n], F.relu(x[:, n:])], 1)
        x = self.bn_x(x)

        sum_y = sum(
            gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))
        )

        y = (
            self.gamma_y(y)
            + self.gamma_deg(deg_lg * y)
            + sum_y
            + self.gamma_x(pmpd_x)
        )
        y = th.cat([y[:, :n], F.relu(y[:, n:])], 1)
        y = self.bn_y(y)

        return x, y
class GNN(nn.Module):
    def __init__(self, feats, radius, n_classes):
        super(GNN, self).__init__()
        self.linear = nn.Linear(feats[-1], n_classes)
        self.module_list = nn.ModuleList(
            [GNNModule(m, n, radius) for m, n in zip(feats[:-1], feats[1:])]
        )

    def forward(self, g, lg, deg_g, deg_lg, pm_pd):
        x, y = deg_g, deg_lg
...
@@ -16,9 +16,9 @@ import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from dgl.data import SBMMixtureDataset
from torch.utils.data import DataLoader

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, help="Batch size", default=1)
...
import os

import torch as th
import torch.nn as nn
import tqdm
@@ -20,35 +21,37 @@ class PBar(object):
class AminerDataset(object):
    """
    Download Aminer Dataset from Amazon S3 bucket.
    """

    def __init__(self, path):
        self.url = "https://data.dgl.ai/dataset/aminer.zip"

        if not os.path.exists(os.path.join(path, "aminer.txt")):
            print("File not found. Downloading from", self.url)
            self._download_and_extract(path, "aminer.zip")
        self.fn = os.path.join(path, "aminer.txt")

    def _download_and_extract(self, path, filename):
        import shutil, zipfile, zlib
        import urllib.request

        from tqdm import tqdm

        fn = os.path.join(path, filename)

        with PBar() as pb:
            urllib.request.urlretrieve(self.url, fn, pb)
        print("Download finished. Unzipping the file...")
        with zipfile.ZipFile(fn) as zf:
            zf.extractall(path)
        print("Unzip finished.")
class CustomDataset(object):
    """
    Custom dataset generated by sampler.py (e.g. NetDBIS)
    """

    def __init__(self, path):
        self.fn = path
import argparse

import torch
import torch.optim as optim
from download import AminerDataset, CustomDataset
from model import SkipGramModel
from reading_data import DataReader, Metapath2vecDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
class Metapath2VecTrainer:
    def __init__(self, args):
@@ -18,8 +19,13 @@ class Metapath2VecTrainer:
        dataset = CustomDataset(args.path)
        self.data = DataReader(dataset, args.min_count, args.care_type)
        dataset = Metapath2vecDataset(self.data, args.window_size)
        self.dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=dataset.collate,
        )

        self.output_file_name = args.output_file
        self.emb_size = len(self.data.word2id)
@@ -35,15 +41,17 @@ class Metapath2VecTrainer:
            self.skip_gram_model.cuda()

    def train(self):
        optimizer = optim.SparseAdam(
            list(self.skip_gram_model.parameters()), lr=self.initial_lr
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, len(self.dataloader)
        )

        for iteration in range(self.iterations):
            print("\n\n\nIteration: " + str(iteration + 1))
            running_loss = 0.0
            for i, sample_batched in enumerate(tqdm(self.dataloader)):
                if len(sample_batched[0]) > 1:
                    pos_u = sample_batched[0].to(self.device)
                    pos_v = sample_batched[1].to(self.device)
@@ -59,23 +67,40 @@ class Metapath2VecTrainer:
                if i > 0 and i % 500 == 0:
                    print(" Loss: " + str(running_loss))

            self.skip_gram_model.save_embedding(
                self.data.id2word, self.output_file_name
            )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Metapath2vec")
    # parser.add_argument('--input_file', type=str, help="input_file")
    parser.add_argument(
        "--aminer", action="store_true", help="Use AMiner dataset"
    )
    parser.add_argument("--path", type=str, help="input_path")
    parser.add_argument("--output_file", type=str, help="output_file")
    parser.add_argument(
        "--dim", default=128, type=int, help="embedding dimensions"
    )
    parser.add_argument(
        "--window_size", default=7, type=int, help="context window size"
    )
    parser.add_argument("--iterations", default=5, type=int, help="iterations")
    parser.add_argument("--batch_size", default=50, type=int, help="batch size")
    parser.add_argument(
        "--care_type",
        default=0,
        type=int,
        help="if 1, heterogeneous negative sampling, else normal negative sampling",
    )
    parser.add_argument(
        "--initial_lr", default=0.025, type=float, help="learning rate"
    )
    parser.add_argument("--min_count", default=5, type=int, help="min count")
    parser.add_argument(
        "--num_workers", default=16, type=int, help="number of workers"
    )
    args = parser.parse_args()
    m2v = Metapath2VecTrainer(args)
    m2v.train()
@@ -10,7 +10,6 @@ from torch.nn import init

class SkipGramModel(nn.Module):
    def __init__(self, emb_size, emb_dimension):
        super(SkipGramModel, self).__init__()
        self.emb_size = emb_size
def save_embedding(self, id2word, file_name): def save_embedding(self, id2word, file_name):
embedding = self.u_embeddings.weight.cpu().data.numpy() embedding = self.u_embeddings.weight.cpu().data.numpy()
with open(file_name, 'w') as f: with open(file_name, "w") as f:
f.write('%d %d\n' % (len(id2word), self.emb_dimension)) f.write("%d %d\n" % (len(id2word), self.emb_dimension))
for wid, w in id2word.items(): for wid, w in id2word.items():
e = ' '.join(map(lambda x: str(x), embedding[wid])) e = " ".join(map(lambda x: str(x), embedding[wid]))
f.write('%s %s\n' % (w, e)) f.write("%s %s\n" % (w, e))
\ No newline at end of file
import numpy as np
import torch
from download import AminerDataset
from torch.utils.data import Dataset

np.random.seed(12345)
class DataReader:
    NEGATIVE_TABLE_SIZE = 1e8

    def __init__(self, dataset, min_count, care_type):
        self.negatives = []
        self.discards = []
        self.negpos = 0
@@ -35,7 +36,11 @@ class DataReader:
                    word_frequency[word] = word_frequency.get(word, 0) + 1

                    if self.token_count % 1000000 == 0:
                        print(
                            "Read "
                            + str(int(self.token_count / 1000000))
                            + "M words."
                        )

        wid = 0
        for w, c in word_frequency.items():
@@ -71,15 +76,18 @@ class DataReader:
    def getNegatives(self, target, size):  # TODO check equality with target
        if self.care_type == 0:
            response = self.negatives[self.negpos : self.negpos + size]
            self.negpos = (self.negpos + size) % len(self.negatives)
            if len(response) != size:
                return np.concatenate(
                    (response, self.negatives[0 : self.negpos])
                )
            return response
# -----------------------------------------------------------------------------------------------------------------


class Metapath2vecDataset(Dataset):
    def __init__(self, data, window_size):
        # read in data, window_size and input filename
@@ -103,25 +111,44 @@ class Metapath2vecDataset(Dataset):
                words = line.split()

                if len(words) > 1:
                    word_ids = [
                        self.data.word2id[w]
                        for w in words
                        if w in self.data.word2id
                        and np.random.rand()
                        < self.data.discards[self.data.word2id[w]]
                    ]

                    pair_catch = []
                    for i, u in enumerate(word_ids):
                        for j, v in enumerate(
                            word_ids[
                                max(i - self.window_size, 0) : i
                                + self.window_size
                            ]
                        ):
                            assert u < self.data.word_count
                            assert v < self.data.word_count
                            if i == j:
                                continue
                            pair_catch.append(
                                (u, v, self.data.getNegatives(v, 5))
                            )
                    return pair_catch
    @staticmethod
    def collate(batches):
        all_u = [u for batch in batches for u, _, _ in batch if len(batch) > 0]
        all_v = [v for batch in batches for _, v, _ in batch if len(batch) > 0]
        all_neg_v = [
            neg_v
            for batch in batches
            for _, _, neg_v in batch
            if len(batch) > 0
        ]

        return (
            torch.LongTensor(all_u),
            torch.LongTensor(all_v),
            torch.LongTensor(all_neg_v),
        )
import os
import random
import sys
import time

import dgl
import numpy as np
import tqdm

num_walks_per_node = 1000
walk_length = 100
path = sys.argv[1]
def construct_graph():
    paper_ids = []
    paper_names = []
@@ -31,7 +33,7 @@ def construct_graph():
    while True:
        w = f_4.readline()
        if not w:
            break
        w = w.strip().split()
        identity = int(w[0])
        conf_ids.append(identity)
@@ -39,10 +41,10 @@ def construct_graph():
    while True:
        v = f_5.readline()
        if not v:
            break
        v = v.strip().split()
        identity = int(v[0])
        paper_name = "p" + "".join(v[1:])
        paper_ids.append(identity)
        paper_names.append(paper_name)
    f_3.close()
@@ -60,41 +62,49 @@ def construct_graph():
    f_1 = open(os.path.join(path, "paper_author.txt"), "r")
    f_2 = open(os.path.join(path, "paper_conf.txt"), "r")
    for x in f_1:
        x = x.split("\t")
        x[0] = int(x[0])
        x[1] = int(x[1].strip("\n"))
        paper_author_src.append(paper_ids_invmap[x[0]])
        paper_author_dst.append(author_ids_invmap[x[1]])
    for y in f_2:
        y = y.split("\t")
        y[0] = int(y[0])
        y[1] = int(y[1].strip("\n"))
        paper_conf_src.append(paper_ids_invmap[y[0]])
        paper_conf_dst.append(conf_ids_invmap[y[1]])
    f_1.close()
    f_2.close()

    hg = dgl.heterograph(
        {
            ("paper", "pa", "author"): (paper_author_src, paper_author_dst),
            ("author", "ap", "paper"): (paper_author_dst, paper_author_src),
            ("paper", "pc", "conf"): (paper_conf_src, paper_conf_dst),
            ("conf", "cp", "paper"): (paper_conf_dst, paper_conf_src),
        }
    )
    return hg, author_names, conf_names, paper_names
#"conference - paper - Author - paper - conference" metapath sampling
# "conference - paper - Author - paper - conference" metapath sampling
def generate_metapath(): def generate_metapath():
output_path = open(os.path.join(path, "output_path.txt"), "w") output_path = open(os.path.join(path, "output_path.txt"), "w")
count = 0 count = 0
hg, author_names, conf_names, paper_names = construct_graph() hg, author_names, conf_names, paper_names = construct_graph()
for conf_idx in tqdm.trange(hg.number_of_nodes('conf')): for conf_idx in tqdm.trange(hg.number_of_nodes("conf")):
traces, _ = dgl.sampling.random_walk( traces, _ = dgl.sampling.random_walk(
hg, [conf_idx] * num_walks_per_node, metapath=['cp', 'pa', 'ap', 'pc'] * walk_length) hg,
[conf_idx] * num_walks_per_node,
metapath=["cp", "pa", "ap", "pc"] * walk_length,
)
for tr in traces: for tr in traces:
outline = ' '.join( outline = " ".join(
(conf_names if i % 4 == 0 else author_names)[tr[i]] (conf_names if i % 4 == 0 else author_names)[tr[i]]
for i in range(0, len(tr), 2)) # skip paper for i in range(0, len(tr), 2)
) # skip paper
print(outline, file=output_path) print(outline, file=output_path)
output_path.close() output_path.close()
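# Minimal sketch added for clarity (toy graph and toy_* names assumed; not
# part of the commit): on a heterograph with the same edge types, the
# dgl.sampling.random_walk call above yields one node-ID trace per seed,
# alternating node types along the cp -> pa -> ap -> pc metapath.
toy_g = dgl.heterograph(
    {
        ("conf", "cp", "paper"): ([0], [0]),
        ("paper", "pa", "author"): ([0], [0]),
        ("author", "ap", "paper"): ([0], [0]),
        ("paper", "pc", "conf"): ([0], [0]),
    }
)
toy_traces, toy_types = dgl.sampling.random_walk(
    toy_g, [0], metapath=["cp", "pa", "ap", "pc"]
)
# toy_traces has shape (1, 5): conf -> paper -> author -> paper -> conf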
...