"vscode:/vscode.git/clone" did not exist on "2e6c8ec80333e822a3715fd314c7c760e7f832d2"
Unverified commit 704bcaf6, authored by Hongzhi (Steve), Chen and committed by GitHub
parent 6bc82161
@@ -7,12 +7,12 @@ import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from dgl.data import CiteseerGraphDataset, CoraGraphDataset
from model import JKNet
from sklearn.model_selection import train_test_split
from tqdm import trange
from dgl.data import CiteseerGraphDataset, CoraGraphDataset
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
......
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn import GraphConv, JumpingKnowledge
......
import torch
from torch.utils.data import Dataset
import dgl
import torch
from dgl.data.utils import (
_get_dgl_url,
download,
extract_archive,
get_download_dir,
)
from torch.utils.data import Dataset
from .jtmpn import ATOM_FDIM as ATOM_FDIM_DEC
from .jtmpn import BOND_FDIM as BOND_FDIM_DEC
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtmpn import (
ATOM_FDIM as ATOM_FDIM_DEC,
BOND_FDIM as BOND_FDIM_DEC,
mol2dgl_single as mol2dgl_dec,
)
from .mol_tree import Vocab
from .mol_tree_nx import DGLMolTree
from .mpn import mol2dgl_single as mol2dgl_enc
......
import os
import dgl
import dgl.function as DGLF
import rdkit.Chem as Chem
import torch
import torch.nn as nn
from dgl import line_graph, mean_nodes
from .nnutils import cuda
import rdkit.Chem as Chem
import dgl
from dgl import mean_nodes, line_graph
import dgl.function as DGLF
import os
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ELEM_LIST = [
"C",
"N",
"O",
"S",
"F",
"Si",
"P",
"Cl",
"Br",
"Mg",
"Na",
"Ca",
"Fe",
"Al",
"I",
"B",
"K",
"Se",
"Zn",
"H",
"Cu",
"Mn",
"unknown",
]
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
BOND_FDIM = 5
MAX_NB = 10
PAPER = os.getenv('PAPER', False)
PAPER = os.getenv("PAPER", False)
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
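A quick usage sketch of onek_encoding_unk: anything outside the allowable set falls back to its last entry, which is why ELEM_LIST ends in "unknown".

onek_encoding_unk("N", ["C", "N", "O"])  # [False, True, False]
onek_encoding_unk("Xe", ELEM_LIST)[-1]   # True, mapped to the "unknown" slot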
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. chiral atoms, E-Z, cis-trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with the highest score.
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
+ [atom.GetIsAromatic()]))
return torch.Tensor(
onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+ [atom.GetIsAromatic()]
)
def bond_features(bond):
bt = bond.GetBondType()
return (torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]))
return torch.Tensor(
[
bt == Chem.rdchem.BondType.SINGLE,
bt == Chem.rdchem.BondType.DOUBLE,
bt == Chem.rdchem.BondType.TRIPLE,
bt == Chem.rdchem.BondType.AROMATIC,
bond.IsInRing(),
]
)
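As a dimension sanity check, assuming RDKit is available: atom_features concatenates 23 element indicators, 6 degree indicators, 5 formal-charge indicators, and one aromaticity flag, matching ATOM_FDIM = 23 + 6 + 5 + 1 = 35, while bond_features emits exactly the BOND_FDIM = 5 entries listed above.

import rdkit.Chem as Chem
mol = Chem.MolFromSmiles("c1ccccc1O")  # toy molecule for the check
assert atom_features(mol.GetAtomWithIdx(0)).shape[0] == ATOM_FDIM
assert bond_features(mol.GetBondWithIdx(0)).shape[0] == BOND_FDIM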
def mol2dgl_single(cand_batch):
cand_graphs = []
tree_mess_source_edges = [] # map these edges from trees to...
tree_mess_target_edges = [] # these edges on candidate graphs
tree_mess_target_nodes = []
n_nodes = 0
n_edges = 0
@@ -50,8 +89,8 @@ def mol2dgl_single(cand_batch):
n_bonds = mol.GetNumBonds()
ctr_node = mol_tree.nodes_dict[ctr_node_id]
ctr_bid = ctr_node['idx']
mol_tree_graph = getattr(mol_tree, 'graph', mol_tree)
ctr_bid = ctr_node["idx"]
mol_tree_graph = getattr(mol_tree, "graph", mol_tree)
for i, atom in enumerate(mol.GetAtoms()):
assert i == atom.GetIdx()
@@ -75,15 +114,19 @@ def mol2dgl_single(cand_batch):
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
# Tree node ID in the batch
x_bid = mol_tree.nodes_dict[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]['idx'] if y_nid > 0 else -1
x_bid = mol_tree.nodes_dict[x_nid - 1]["idx"] if x_nid > 0 else -1
y_bid = mol_tree.nodes_dict[y_nid - 1]["idx"] if y_nid > 0 else -1
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if mol_tree_graph.has_edges_between(x_bid, y_bid):
tree_mess_target_edges.append((begin_idx + n_nodes, end_idx + n_nodes))
tree_mess_target_edges.append(
(begin_idx + n_nodes, end_idx + n_nodes)
)
tree_mess_source_edges.append((x_bid, y_bid))
tree_mess_target_nodes.append(end_idx + n_nodes)
if mol_tree_graph.has_edges_between(y_bid, x_bid):
tree_mess_target_edges.append((end_idx + n_nodes, begin_idx + n_nodes))
tree_mess_target_edges.append(
(end_idx + n_nodes, begin_idx + n_nodes)
)
tree_mess_source_edges.append((y_bid, x_bid))
tree_mess_target_nodes.append(begin_idx + n_nodes)
@@ -91,11 +134,14 @@ def mol2dgl_single(cand_batch):
g = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
cand_graphs.append(g)
return cand_graphs, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0), \
torch.LongTensor(tree_mess_source_edges), \
torch.LongTensor(tree_mess_target_edges), \
torch.LongTensor(tree_mess_target_nodes)
return (
cand_graphs,
torch.stack(atom_x),
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0),
torch.LongTensor(tree_mess_source_edges),
torch.LongTensor(tree_mess_target_edges),
torch.LongTensor(tree_mess_target_nodes),
)
class LoopyBPUpdate(nn.Module):
@@ -106,28 +152,28 @@ class LoopyBPUpdate(nn.Module):
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, node):
msg_input = node.data['msg_input']
msg_delta = self.W_h(node.data['accum_msg'] + node.data['alpha'])
msg_input = node.data["msg_input"]
msg_delta = self.W_h(node.data["accum_msg"] + node.data["alpha"])
msg = torch.relu(msg_input + msg_delta)
return {'msg': msg}
return {"msg": msg}
if PAPER:
mpn_gather_msg = [
DGLF.copy_e(edge='msg', out='msg'),
DGLF.copy_e(edge='alpha', out='alpha')
DGLF.copy_e(edge="msg", out="msg"),
DGLF.copy_e(edge="alpha", out="alpha"),
]
else:
mpn_gather_msg = DGLF.copy_e(edge='msg', out='msg')
mpn_gather_msg = DGLF.copy_e(edge="msg", out="msg")
if PAPER:
mpn_gather_reduce = [
DGLF.sum(msg='msg', out='m'),
DGLF.sum(msg='alpha', out='accum_alpha'),
DGLF.sum(msg="msg", out="m"),
DGLF.sum(msg="alpha", out="accum_alpha"),
]
else:
mpn_gather_reduce = DGLF.sum(msg='msg', out='m')
mpn_gather_reduce = DGLF.sum(msg="msg", out="m")
class GatherUpdate(nn.Module):
@@ -139,12 +185,12 @@ class GatherUpdate(nn.Module):
def forward(self, node):
if PAPER:
#m = node['m']
m = node.data['m'] + node.data['accum_alpha']
# m = node['m']
m = node.data["m"] + node.data["accum_alpha"]
else:
m = node.data['m'] + node.data['alpha']
m = node.data["m"] + node.data["alpha"]
return {
'h': torch.relu(self.W_o(torch.cat([node.data['x'], m], 1))),
"h": torch.relu(self.W_o(torch.cat([node.data["x"], m], 1))),
}
@@ -166,20 +212,32 @@ class DGLJTMPN(nn.Module):
self.n_passes = 0
def forward(self, cand_batch, mol_tree_batch):
cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = cand_batch
(
cand_graphs,
tree_mess_src_edges,
tree_mess_tgt_edges,
tree_mess_tgt_nodes,
) = cand_batch
n_samples = len(cand_graphs)
cand_line_graph = line_graph(cand_graphs, backtracking=False, shared=True)
cand_line_graph = line_graph(
cand_graphs, backtracking=False, shared=True
)
n_nodes = cand_graphs.number_of_nodes()
n_edges = cand_graphs.number_of_edges()
cand_graphs = self.run(
cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch)
cand_graphs,
cand_line_graph,
tree_mess_src_edges,
tree_mess_tgt_edges,
tree_mess_tgt_nodes,
mol_tree_batch,
)
g_repr = mean_nodes(cand_graphs, 'h')
g_repr = mean_nodes(cand_graphs, "h")
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
@@ -188,33 +246,45 @@ class DGLJTMPN(nn.Module):
return g_repr
def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
tree_mess_tgt_nodes, mol_tree_batch):
def run(
self,
cand_graphs,
cand_line_graph,
tree_mess_src_edges,
tree_mess_tgt_edges,
tree_mess_tgt_nodes,
mol_tree_batch,
):
n_nodes = cand_graphs.number_of_nodes()
cand_graphs.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
func=lambda edges: {"src_x": edges.src["x"]},
)
cand_line_graph.ndata.update(cand_graphs.edata)
bond_features = cand_line_graph.ndata['x']
source_features = cand_line_graph.ndata['src_x']
bond_features = cand_line_graph.ndata["x"]
source_features = cand_line_graph.ndata["src_x"]
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
cand_line_graph.ndata.update({
'msg_input': msg_input,
'msg': torch.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
cand_line_graph.ndata.update(
{
"msg_input": msg_input,
"msg": torch.relu(msg_input),
"accum_msg": torch.zeros_like(msg_input),
}
)
zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()
cand_graphs.ndata.update({
'm': zero_node_state.clone(),
'h': zero_node_state.clone(),
})
cand_graphs.edata['alpha'] = \
cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
cand_graphs.ndata['alpha'] = zero_node_state
cand_graphs.ndata.update(
{
"m": zero_node_state.clone(),
"h": zero_node_state.clone(),
}
)
cand_graphs.edata["alpha"] = cuda(
torch.zeros(cand_graphs.number_of_edges(), self.hidden_size)
)
cand_graphs.ndata["alpha"] = zero_node_state
if tree_mess_src_edges.shape[0] > 0:
if PAPER:
src_u, src_v = tree_mess_src_edges.unbind(1)
@@ -222,33 +292,39 @@ class DGLJTMPN(nn.Module):
src_u = src_u.to(mol_tree_batch.device)
src_v = src_v.to(mol_tree_batch.device)
eid = mol_tree_batch.edge_ids(src_u, src_v)
alpha = mol_tree_batch.edata['m'][eid]
cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
alpha = mol_tree_batch.edata["m"][eid]
cand_graphs.edges[tgt_u, tgt_v].data["alpha"] = alpha
else:
src_u, src_v = tree_mess_src_edges.unbind(1)
src_u = src_u.to(mol_tree_batch.device)
src_v = src_v.to(mol_tree_batch.device)
eid = mol_tree_batch.edge_ids(src_u, src_v)
alpha = mol_tree_batch.edata['m'][eid]
node_idx = (tree_mess_tgt_nodes
.to(device=zero_node_state.device)[:, None]
.expand_as(alpha))
node_alpha = zero_node_state.clone().scatter_add(0, node_idx, alpha)
cand_graphs.ndata['alpha'] = node_alpha
alpha = mol_tree_batch.edata["m"][eid]
node_idx = tree_mess_tgt_nodes.to(
device=zero_node_state.device
)[:, None].expand_as(alpha)
node_alpha = zero_node_state.clone().scatter_add(
0, node_idx, alpha
)
cand_graphs.ndata["alpha"] = node_alpha
cand_graphs.apply_edges(
func=lambda edges: {'alpha': edges.src['alpha']},
func=lambda edges: {"alpha": edges.src["alpha"]},
)
cand_line_graph.ndata.update(cand_graphs.edata)
for i in range(self.depth - 1):
cand_line_graph.update_all(DGLF.copy_u('msg', 'msg'), DGLF.sum('msg', 'accum_msg'))
cand_line_graph.update_all(
DGLF.copy_u("msg", "msg"), DGLF.sum("msg", "accum_msg")
)
cand_line_graph.apply_nodes(self.loopy_bp_updater)
cand_graphs.edata.update(cand_line_graph.ndata)
cand_graphs.update_all(DGLF.copy_e('msg', 'msg'), DGLF.sum('msg', 'm'))
cand_graphs.update_all(DGLF.copy_e("msg", "msg"), DGLF.sum("msg", "m"))
if PAPER:
cand_graphs.update_all(DGLF.copy_e('alpha', 'alpha'), DGLF.sum('alpha', 'accum_alpha'))
cand_graphs.update_all(
DGLF.copy_e("alpha", "alpha"), DGLF.sum("alpha", "accum_alpha")
)
cand_graphs.apply_nodes(self.gather_updater)
return cand_graphs
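The scatter_add used in run above accumulates each incoming tree-message row into its target node's slot; a toy illustration of the pattern, with hypothetical shapes:

import torch
zero = torch.zeros(4, 3)
alpha = torch.ones(2, 3)
idx = torch.tensor([1, 1])[:, None].expand_as(alpha)
print(zero.scatter_add(0, idx, alpha))  # row 1 accumulates to [2., 2., 2.]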
import dgl.function as DGLF
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mol_tree_nx import DGLMolTree
from .chemutils import enum_assemble_nx, get_mol
from .nnutils import GRUUpdate, cuda, tocpu
from dgl import batch, dfs_labeled_edges_generator, line_graph
import dgl.function as DGLF
import numpy as np
from .chemutils import enum_assemble_nx, get_mol
from .mol_tree_nx import DGLMolTree
from .nnutils import cuda, GRUUpdate, tocpu
MAX_NB = 8
MAX_DECODE_LEN = 100
@@ -21,60 +22,70 @@ def dfs_order(forest, roots):
# using find_edges().
yield e ^ l, l
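For context, the e ^ l trick relies on tree edges being created in reciprocal pairs, so ids 2k and 2k + 1 point in opposite directions and XOR with the backtracking label (1) flips between them; a minimal sketch under that pairing assumption:

import dgl
g = dgl.graph(([0, 1, 0, 2], [1, 0, 2, 0]))  # edges 0/1 and 2/3 are reciprocal pairs
for eids, labels in dgl.dfs_labeled_edges_generator(g, [0], has_reverse_edge=True):
    for e, l in zip(eids.tolist(), labels.tolist()):
        print(e ^ l, l)  # e ^ 1 selects the opposite-direction edge when backtracking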
dec_tree_node_msg = DGLF.copy_e(edge='m', out='m')
dec_tree_node_reduce = DGLF.sum(msg='m', out='h')
dec_tree_node_msg = DGLF.copy_e(edge="m", out="m")
dec_tree_node_reduce = DGLF.sum(msg="m", out="h")
def dec_tree_node_update(nodes):
return {'new': nodes.data['new'].clone().zero_()}
return {"new": nodes.data["new"].clone().zero_()}
def have_slots(fa_slots, ch_slots):
if len(fa_slots) > 2 and len(ch_slots) > 2:
return True
matches = []
for i,s1 in enumerate(fa_slots):
a1,c1,h1 = s1
for j,s2 in enumerate(ch_slots):
a2,c2,h2 = s2
for i, s1 in enumerate(fa_slots):
a1, c1, h1 = s1
for j, s2 in enumerate(ch_slots):
a2, c2, h2 = s2
if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
matches.append( (i,j) )
matches.append((i, j))
if len(matches) == 0: return False
if len(matches) == 0:
return False
fa_match,ch_match = list(zip(*matches))
if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2: #never remove atom from ring
fa_match, ch_match = list(zip(*matches))
if (
len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2
): # never remove atom from ring
fa_slots.pop(fa_match[0])
if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2: #never remove atom from ring
if (
len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2
): # never remove atom from ring
ch_slots.pop(ch_match[0])
return True
def can_assemble(mol_tree, u, v_node_dict):
u_node_dict = mol_tree.nodes_dict[u]
u_neighbors = mol_tree.graph.successors(u)
u_neighbors_node_dict = [
mol_tree.nodes_dict[_u]
for _u in u_neighbors
if _u in mol_tree.nodes_dict
]
mol_tree.nodes_dict[_u]
for _u in u_neighbors
if _u in mol_tree.nodes_dict
]
neis = u_neighbors_node_dict + [v_node_dict]
for i,nei in enumerate(neis):
nei['nid'] = i
neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x:x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
for i, nei in enumerate(neis):
nei["nid"] = i
neighbors = [nei for nei in neis if nei["mol"].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
)
singletons = [nei for nei in neis if nei["mol"].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble_nx(u_node_dict, neighbors)
return len(cands) > 0
def create_node_dict(smiles, clique=[]):
return dict(
smiles=smiles,
mol=get_mol(smiles),
clique=clique,
)
smiles=smiles,
mol=get_mol(smiles),
clique=clique,
)
class DGLJTNNDecoder(nn.Module):
@@ -98,41 +109,54 @@ class DGLJTNNDecoder(nn.Module):
self.U_s = nn.Linear(hidden_size, 1)
def forward(self, mol_trees, tree_vec):
'''
"""
The training procedure which computes the prediction loss given the
ground truth tree
'''
"""
mol_tree_batch = batch(mol_trees)
mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False, shared=True)
mol_tree_batch_lg = line_graph(
mol_tree_batch, backtracking=False, shared=True
)
n_trees = len(mol_trees)
return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
node_offset = np.cumsum(np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0))
node_offset = np.cumsum(
np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0)
)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
'new': cuda(torch.ones(n_nodes).bool()), # whether it's a newly generated node
})
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
mol_tree_batch.ndata.update(
{
"x": self.embedding(mol_tree_batch.ndata["wid"]),
"h": cuda(torch.zeros(n_nodes, self.hidden_size)),
"new": cuda(
torch.ones(n_nodes).bool()
), # whether it's a newly generated node
}
)
mol_tree_batch.edata.update(
{
"s": cuda(torch.zeros(n_edges, self.hidden_size)),
"m": cuda(torch.zeros(n_edges, self.hidden_size)),
"r": cuda(torch.zeros(n_edges, self.hidden_size)),
"z": cuda(torch.zeros(n_edges, self.hidden_size)),
"src_x": cuda(torch.zeros(n_edges, self.hidden_size)),
"dst_x": cuda(torch.zeros(n_edges, self.hidden_size)),
"rm": cuda(torch.zeros(n_edges, self.hidden_size)),
"accum_rm": cuda(torch.zeros(n_edges, self.hidden_size)),
}
)
mol_tree_batch.apply_edges(
func=lambda edges: {'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
func=lambda edges: {
"src_x": edges.src["x"],
"dst_x": edges.dst["x"],
},
)
# input tensors for stop prediction (p) and label prediction (q)
@@ -142,16 +166,16 @@ class DGLJTNNDecoder(nn.Module):
q_targets = []
# Predict root
mol_tree_batch.pull(root_ids, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
mol_tree_batch.pull(root_ids, DGLF.copy_e("m", "m"), DGLF.sum("m", "h"))
mol_tree_batch.apply_nodes(dec_tree_node_update, v=root_ids)
# Extract hidden states and store them for stop/label prediction
h = mol_tree_batch.nodes[root_ids].data['h']
x = mol_tree_batch.nodes[root_ids].data['x']
h = mol_tree_batch.nodes[root_ids].data["h"]
x = mol_tree_batch.nodes[root_ids].data["x"]
p_inputs.append(torch.cat([x, h, tree_vec], 1))
# If the out degree is 0, we don't generate any edges at all
root_out_degrees = mol_tree_batch.out_degrees(root_ids)
q_inputs.append(torch.cat([h, tree_vec], 1))
q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])
q_targets.append(mol_tree_batch.nodes[root_ids].data["wid"])
# Traverse the tree and predict on children
for eid, p in dfs_order(mol_tree_batch, root_ids):
@@ -160,29 +184,35 @@ class DGLJTNNDecoder(nn.Module):
u, v = mol_tree_batch.find_edges(eid)
p_target_list = torch.zeros_like(root_out_degrees)
p_target_list[root_out_degrees > 0] = (1 - p)
p_target_list[root_out_degrees > 0] = 1 - p
p_target_list = p_target_list[root_out_degrees >= 0]
p_targets.append(torch.tensor(p_target_list))
root_out_degrees -= (root_out_degrees == 0).long()
root_out_degrees -= torch.tensor(np.isin(root_ids, v.cpu().numpy())).to(root_out_degrees)
root_out_degrees -= torch.tensor(
np.isin(root_ids, v.cpu().numpy())
).to(root_out_degrees)
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_batch_lg.pull(
eid, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
)
mol_tree_batch_lg.pull(
eid, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
)
mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update, v=eid)
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
is_new = mol_tree_batch.nodes[v].data['new']
mol_tree_batch.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
is_new = mol_tree_batch.nodes[v].data["new"]
mol_tree_batch.pull(v, DGLF.copy_e("m", "m"), DGLF.sum("m", "h"))
mol_tree_batch.apply_nodes(dec_tree_node_update, v=v)
# Extract
n_repr = mol_tree_batch.nodes[v].data
h = n_repr['h']
x = n_repr['x']
h = n_repr["h"]
x = n_repr["x"]
tree_vec_set = tree_vec[root_out_degrees >= 0]
wid = n_repr['wid']
wid = n_repr["wid"]
p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
# Only newly generated nodes are needed for label prediction
# NOTE: The following works since the uncomputed messages are zeros.
@@ -192,10 +222,13 @@ class DGLJTNNDecoder(nn.Module):
if q_input.shape[0] > 0:
q_inputs.append(q_input)
q_targets.append(q_target)
p_targets.append(torch.zeros(
(root_out_degrees == 0).sum(),
device=root_out_degrees.device,
dtype=torch.int64))
p_targets.append(
torch.zeros(
(root_out_degrees == 0).sum(),
device=root_out_degrees.device,
dtype=torch.int64,
)
)
# Batch compute the stop/label prediction losses
p_inputs = torch.cat(p_inputs, 0)
@@ -206,9 +239,12 @@ class DGLJTNNDecoder(nn.Module):
q = self.W_o(torch.relu(self.W(q_inputs)))
p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]
p_loss = F.binary_cross_entropy_with_logits(
p, p_targets.float(), size_average=False
) / n_trees
p_loss = (
F.binary_cross_entropy_with_logits(
p, p_targets.float(), size_average=False
)
/ n_trees
)
q_loss = F.cross_entropy(q, q_targets, size_average=False) / n_trees
p_acc = ((p > 0).long() == p_targets).sum().float() / p_targets.shape[0]
q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]
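One caveat: size_average has been deprecated in PyTorch since 0.4 in favor of the reduction argument, so on recent versions the equivalent spelling would be:

p_loss = F.binary_cross_entropy_with_logits(p, p_targets.float(), reduction="sum") / n_trees
q_loss = F.cross_entropy(q, q_targets, reduction="sum") / n_trees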
@@ -237,13 +273,14 @@ class DGLJTNNDecoder(nn.Module):
_, root_wid = torch.max(root_score, 1)
root_wid = root_wid.view(1)
mol_tree_graph.add_nodes(1) # root
mol_tree_graph.ndata['wid'] = root_wid
mol_tree_graph.ndata['x'] = self.embedding(root_wid)
mol_tree_graph.ndata['h'] = init_hidden
mol_tree_graph.ndata['fail'] = cuda(torch.tensor([0]))
mol_tree_graph.add_nodes(1) # root
mol_tree_graph.ndata["wid"] = root_wid
mol_tree_graph.ndata["x"] = self.embedding(root_wid)
mol_tree_graph.ndata["h"] = init_hidden
mol_tree_graph.ndata["fail"] = cuda(torch.tensor([0]))
mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
self.vocab.get_smiles(root_wid))
self.vocab.get_smiles(root_wid)
)
stack, trace = [], []
stack.append((0, self.vocab.get_slots(root_wid)))
@@ -256,13 +293,13 @@ class DGLJTNNDecoder(nn.Module):
for step in range(MAX_DECODE_LEN):
u, u_slots = stack[-1]
x = mol_tree_graph.ndata['x'][u:u+1]
h = mol_tree_graph.ndata['h'][u:u+1]
x = mol_tree_graph.ndata["x"][u : u + 1]
h = mol_tree_graph.ndata["h"][u : u + 1]
# Predict stop
p_input = torch.cat([x, h, mol_vec], 1)
p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
backtrack = (p_score.item() < 0.5)
backtrack = p_score.item() < 0.5
if not backtrack:
# Predict next clique. Note that the prediction may fail due
@@ -273,49 +310,61 @@ class DGLJTNNDecoder(nn.Module):
mol_tree_graph.add_edges(u, v)
uv = new_edge_id
new_edge_id += 1
if first:
mol_tree_graph.edata.update({
's': cuda(torch.zeros(1, self.hidden_size)),
'm': cuda(torch.zeros(1, self.hidden_size)),
'r': cuda(torch.zeros(1, self.hidden_size)),
'z': cuda(torch.zeros(1, self.hidden_size)),
'src_x': cuda(torch.zeros(1, self.hidden_size)),
'dst_x': cuda(torch.zeros(1, self.hidden_size)),
'rm': cuda(torch.zeros(1, self.hidden_size)),
'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
})
mol_tree_graph.edata.update(
{
"s": cuda(torch.zeros(1, self.hidden_size)),
"m": cuda(torch.zeros(1, self.hidden_size)),
"r": cuda(torch.zeros(1, self.hidden_size)),
"z": cuda(torch.zeros(1, self.hidden_size)),
"src_x": cuda(torch.zeros(1, self.hidden_size)),
"dst_x": cuda(torch.zeros(1, self.hidden_size)),
"rm": cuda(torch.zeros(1, self.hidden_size)),
"accum_rm": cuda(torch.zeros(1, self.hidden_size)),
}
)
first = False
mol_tree_graph.edata['src_x'][uv] = mol_tree_graph.ndata['x'][u]
mol_tree_graph.edata["src_x"][uv] = mol_tree_graph.ndata["x"][u]
# Keeping dst_x 0 is fine, as h on the new edge doesn't depend on it.
# DGL doesn't dynamically maintain a line graph.
mol_tree_graph_lg = line_graph(mol_tree_graph, backtracking=False, shared=True)
mol_tree_graph_lg = line_graph(
mol_tree_graph, backtracking=False, shared=True
)
mol_tree_graph_lg.pull(
uv,
DGLF.copy_u('m', 'm'),
DGLF.sum('m', 's'))
uv, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
)
mol_tree_graph_lg.pull(
uv,
DGLF.copy_u('rm', 'rm'),
DGLF.sum('rm', 'accum_rm'))
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update.update_zm, v=uv)
uv, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
)
mol_tree_graph_lg.apply_nodes(
self.dec_tree_edge_update.update_zm, v=uv
)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(v, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
mol_tree_graph.pull(
v, DGLF.copy_e("m", "m"), DGLF.sum("m", "h")
)
h_v = mol_tree_graph.ndata['h'][v:v+1]
h_v = mol_tree_graph.ndata["h"][v : v + 1]
q_input = torch.cat([h_v, mol_vec], 1)
q_score = torch.softmax(self.W_o(torch.relu(self.W(q_input))), -1)
q_score = torch.softmax(
self.W_o(torch.relu(self.W(q_input))), -1
)
_, sort_wid = torch.sort(q_score, 1, descending=True)
sort_wid = sort_wid.squeeze()
next_wid = None
for wid in sort_wid.tolist()[:5]:
slots = self.vocab.get_slots(wid)
cand_node_dict = create_node_dict(self.vocab.get_smiles(wid))
if (have_slots(u_slots, slots) and can_assemble(mol_tree, u, cand_node_dict)):
cand_node_dict = create_node_dict(
self.vocab.get_smiles(wid)
)
if have_slots(u_slots, slots) and can_assemble(
mol_tree, u, cand_node_dict
):
next_wid = wid
next_slots = slots
next_node_dict = cand_node_dict
@@ -324,44 +373,59 @@ class DGLJTNNDecoder(nn.Module):
if next_wid is None:
# Failed to add an actual child; v is a spurious node
# and we mark it.
mol_tree_graph.ndata['fail'][v] = cuda(torch.tensor([1]))
mol_tree_graph.ndata["fail"][v] = cuda(torch.tensor([1]))
backtrack = True
else:
next_wid = cuda(torch.tensor([next_wid]))
mol_tree_graph.ndata['wid'][v] = next_wid
mol_tree_graph.ndata['x'][v] = self.embedding(next_wid)
mol_tree_graph.ndata["wid"][v] = next_wid
mol_tree_graph.ndata["x"][v] = self.embedding(next_wid)
mol_tree.nodes_dict[v] = next_node_dict
all_nodes[v] = next_node_dict
stack.append((v, next_slots))
mol_tree_graph.add_edges(v, u)
vu = new_edge_id
new_edge_id += 1
mol_tree_graph.edata['dst_x'][uv] = mol_tree_graph.ndata['x'][v]
mol_tree_graph.edata['src_x'][vu] = mol_tree_graph.ndata['x'][v]
mol_tree_graph.edata['dst_x'][vu] = mol_tree_graph.ndata['x'][u]
mol_tree_graph.edata["dst_x"][uv] = mol_tree_graph.ndata[
"x"
][v]
mol_tree_graph.edata["src_x"][vu] = mol_tree_graph.ndata[
"x"
][v]
mol_tree_graph.edata["dst_x"][vu] = mol_tree_graph.ndata[
"x"
][u]
# DGL doesn't dynamically maintain a line graph.
mol_tree_graph_lg = line_graph(mol_tree_graph, backtracking=False, shared=True)
mol_tree_graph_lg = line_graph(
mol_tree_graph, backtracking=False, shared=True
)
mol_tree_graph_lg.apply_nodes(
self.dec_tree_edge_update.update_r,
uv
)
self.dec_tree_edge_update.update_r, uv
)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
if backtrack:
if len(stack) == 1:
break # At root, terminate
pu, _ = stack[-2]
u_pu = mol_tree_graph.edge_ids(u, pu)
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_graph_lg.pull(u_pu, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_graph_lg.pull(
u_pu, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
)
mol_tree_graph_lg.pull(
u_pu, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
)
mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update, v=u_pu)
mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)
mol_tree_graph.pull(pu, DGLF.copy_e('m', 'm'), DGLF.sum('m', 'h'))
mol_tree_graph.pull(
pu, DGLF.copy_e("m", "m"), DGLF.sum("m", "h")
)
stack.pop()
effective_nodes = mol_tree_graph.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
effective_nodes = mol_tree_graph.filter_nodes(
lambda nodes: nodes.data["fail"] != 1
)
effective_nodes, _ = torch.sort(effective_nodes)
return mol_tree, all_nodes, effective_nodes
import dgl.function as DGLF
import numpy as np
import torch
import torch.nn as nn
import dgl.function as DGLF
from dgl import batch, bfs_edges_generator, line_graph
from .nnutils import GRUUpdate, cuda, tocpu
from .nnutils import cuda, GRUUpdate, tocpu
MAX_NB = 8
......
@@ -14,12 +14,10 @@ from .chemutils import (
enum_assemble_nx,
set_atommap,
)
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtmpn import DGLJTMPN, mol2dgl_single as mol2dgl_dec
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .mpn import DGLMPN, mol2dgl_single as mol2dgl_enc
from .nnutils import cuda
......
import dgl
import numpy as np
import rdkit.Chem as Chem
import dgl
from .chemutils import (
decode_stereo,
enum_assemble_nx,
......
import dgl
import dgl.function as DGLF
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as DGLF
from dgl import line_graph, mean_nodes
from .chemutils import get_mol
......
import os
import dgl
import torch
import torch.nn as nn
import dgl
def cuda(x):
if torch.cuda.is_available() and not os.getenv("NOCUDA", None):
......
import argparse
import torch
import dgl
import torch
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.nn import LabelPropagation
......
@@ -20,18 +20,18 @@
import warnings
from time import time
import dgl
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as ss
import torch
from dgl import function as fn
from lda_model import LatentDirichletAllocation as LDAModel
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.decomposition import LatentDirichletAllocation, NMF
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
import dgl
from dgl import function as fn
n_samples = 2000
n_features = 1000
n_components = 10
......
@@ -23,12 +23,12 @@ import io
import os
import warnings
import dgl
import numpy as np
import scipy as sp
import torch
import dgl
try:
from functools import cached_property
except ImportError:
......
import copy
import itertools
import dgl
import dgl.function as fn
import networkx as nx
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GNNModule(nn.Module):
def __init__(self, in_feats, out_feats, radius):
@@ -15,14 +17,22 @@ class GNNModule(nn.Module):
self.radius = radius
new_linear = lambda: nn.Linear(in_feats, out_feats)
new_linear_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])
new_linear_list = lambda: nn.ModuleList(
[new_linear() for i in range(radius)]
)
self.theta_x, self.theta_deg, self.theta_y = \
new_linear(), new_linear(), new_linear()
self.theta_x, self.theta_deg, self.theta_y = (
new_linear(),
new_linear(),
new_linear(),
)
self.theta_list = new_linear_list()
self.gamma_y, self.gamma_deg, self.gamma_x = \
new_linear(), new_linear(), new_linear()
self.gamma_y, self.gamma_deg, self.gamma_x = (
new_linear(),
new_linear(),
new_linear(),
)
self.gamma_list = new_linear_list()
self.bn_x = nn.BatchNorm1d(out_feats)
@@ -30,43 +40,61 @@
def aggregate(self, g, z):
z_list = []
g.ndata['z'] = z
g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.ndata['z'])
g.ndata["z"] = z
g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
z_list.append(g.ndata["z"])
for i in range(self.radius - 1):
for j in range(2 ** i):
g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.ndata['z'])
for j in range(2**i):
g.update_all(
fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z")
)
z_list.append(g.ndata["z"])
return z_list
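The nested loop doubles the propagation distance each round, so z_list collects neighborhood sums over 1, 2, 4, ..., 2**(radius - 1) hops; a self-contained sketch of the same pattern on a toy path graph:

import dgl
import dgl.function as fn
import torch as th
g = dgl.add_reverse_edges(dgl.graph(([0, 1, 2], [1, 2, 3])))  # 4-node path
g.ndata["z"] = th.eye(4)
z_list = []
g.update_all(fn.copy_u("z", "m"), fn.sum("m", "z"))
z_list.append(g.ndata["z"])  # 1-hop sums
for i in range(2):  # radius = 3
    for _ in range(2**i):
        g.update_all(fn.copy_u("z", "m"), fn.sum("m", "z"))
    z_list.append(g.ndata["z"])  # 2-hop, then 4-hop aggregates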
def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
pmpd_x = F.embedding(pm_pd, x)
sum_x = sum(theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x)))
sum_x = sum(
theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))
)
g.edata['y'] = y
g.update_all(fn.copy_e(e='y', out='m'), fn.sum('m', 'pmpd_y'))
pmpd_y = g.ndata.pop('pmpd_y')
g.edata["y"] = y
g.update_all(fn.copy_e(e="y", out="m"), fn.sum("m", "pmpd_y"))
pmpd_y = g.ndata.pop("pmpd_y")
x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
x = (
self.theta_x(x)
+ self.theta_deg(deg_g * x)
+ sum_x
+ self.theta_y(pmpd_y)
)
n = self.out_feats // 2
x = th.cat([x[:, :n], F.relu(x[:, n:])], 1)
x = self.bn_x(x)
sum_y = sum(gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y)))
sum_y = sum(
gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))
)
y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum_y + self.gamma_x(pmpd_x)
y = (
self.gamma_y(y)
+ self.gamma_deg(deg_lg * y)
+ sum_y
+ self.gamma_x(pmpd_x)
)
y = th.cat([y[:, :n], F.relu(y[:, n:])], 1)
y = self.bn_y(y)
return x, y
class GNN(nn.Module):
def __init__(self, feats, radius, n_classes):
super(GNN, self).__init__()
self.linear = nn.Linear(feats[-1], n_classes)
self.module_list = nn.ModuleList([GNNModule(m, n, radius)
for m, n in zip(feats[:-1], feats[1:])])
self.module_list = nn.ModuleList(
[GNNModule(m, n, radius) for m, n in zip(feats[:-1], feats[1:])]
)
def forward(self, g, lg, deg_g, deg_lg, pm_pd):
x, y = deg_g, deg_lg
......
@@ -16,9 +16,9 @@ import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.data import SBMMixtureDataset
from torch.utils.data import DataLoader
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, help="Batch size", default=1)
......
import os
import torch as th
import torch.nn as nn
import tqdm
@@ -20,35 +21,37 @@ class PBar(object):
class AminerDataset(object):
"""
Download Aminer Dataset from Amazon S3 bucket.
"""
def __init__(self, path):
self.url = 'https://data.dgl.ai/dataset/aminer.zip'
def __init__(self, path):
self.url = "https://data.dgl.ai/dataset/aminer.zip"
if not os.path.exists(os.path.join(path, 'aminer.txt')):
print('File not found. Downloading from', self.url)
self._download_and_extract(path, 'aminer.zip')
self.fn = os.path.join(path, 'aminer.txt')
if not os.path.exists(os.path.join(path, "aminer.txt")):
print("File not found. Downloading from", self.url)
self._download_and_extract(path, "aminer.zip")
self.fn = os.path.join(path, "aminer.txt")
def _download_and_extract(self, path, filename):
import shutil, zipfile, zlib
from tqdm import tqdm
import urllib.request
from tqdm import tqdm
fn = os.path.join(path, filename)
with PBar() as pb:
urllib.request.urlretrieve(self.url, fn, pb)
print('Download finished. Unzipping the file...')
print("Download finished. Unzipping the file...")
with zipfile.ZipFile(fn) as zf:
zf.extractall(path)
print('Unzip finished.')
print("Unzip finished.")
class CustomDataset(object):
"""
Custom dataset generated by sampler.py (e.g. NetDBIS)
"""
def __init__(self, path):
self.fn = path
import torch
import argparse
import torch
import torch.optim as optim
from download import AminerDataset, CustomDataset
from model import SkipGramModel
from reading_data import DataReader, Metapath2vecDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from reading_data import DataReader, Metapath2vecDataset
from model import SkipGramModel
from download import AminerDataset, CustomDataset
class Metapath2VecTrainer:
def __init__(self, args):
@@ -18,8 +19,13 @@ class Metapath2VecTrainer:
dataset = CustomDataset(args.path)
self.data = DataReader(dataset, args.min_count, args.care_type)
dataset = Metapath2vecDataset(self.data, args.window_size)
self.dataloader = DataLoader(dataset, batch_size=args.batch_size,
shuffle=True, num_workers=args.num_workers, collate_fn=dataset.collate)
self.dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
collate_fn=dataset.collate,
)
self.output_file_name = args.output_file
self.emb_size = len(self.data.word2id)
@@ -35,15 +41,17 @@ class Metapath2VecTrainer:
self.skip_gram_model.cuda()
def train(self):
optimizer = optim.SparseAdam(list(self.skip_gram_model.parameters()), lr=self.initial_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(self.dataloader))
optimizer = optim.SparseAdam(
list(self.skip_gram_model.parameters()), lr=self.initial_lr
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, len(self.dataloader)
)
for iteration in range(self.iterations):
print("\n\n\nIteration: " + str(iteration + 1))
running_loss = 0.0
for i, sample_batched in enumerate(tqdm(self.dataloader)):
if len(sample_batched[0]) > 1:
pos_u = sample_batched[0].to(self.device)
pos_v = sample_batched[1].to(self.device)
@@ -59,23 +67,40 @@ class Metapath2VecTrainer:
if i > 0 and i % 500 == 0:
print(" Loss: " + str(running_loss))
self.skip_gram_model.save_embedding(self.data.id2word, self.output_file_name)
self.skip_gram_model.save_embedding(
self.data.id2word, self.output_file_name
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Metapath2vec")
#parser.add_argument('--input_file', type=str, help="input_file")
parser.add_argument('--aminer', action='store_true', help='Use AMiner dataset')
parser.add_argument('--path', type=str, help="input_path")
parser.add_argument('--output_file', type=str, help='output_file')
parser.add_argument('--dim', default=128, type=int, help="embedding dimensions")
parser.add_argument('--window_size', default=7, type=int, help="context window size")
parser.add_argument('--iterations', default=5, type=int, help="iterations")
parser.add_argument('--batch_size', default=50, type=int, help="batch size")
parser.add_argument('--care_type', default=0, type=int, help="if 1, heterogeneous negative sampling, else normal negative sampling")
parser.add_argument('--initial_lr', default=0.025, type=float, help="learning rate")
parser.add_argument('--min_count', default=5, type=int, help="min count")
parser.add_argument('--num_workers', default=16, type=int, help="number of workers")
# parser.add_argument('--input_file', type=str, help="input_file")
parser.add_argument(
"--aminer", action="store_true", help="Use AMiner dataset"
)
parser.add_argument("--path", type=str, help="input_path")
parser.add_argument("--output_file", type=str, help="output_file")
parser.add_argument(
"--dim", default=128, type=int, help="embedding dimensions"
)
parser.add_argument(
"--window_size", default=7, type=int, help="context window size"
)
parser.add_argument("--iterations", default=5, type=int, help="iterations")
parser.add_argument("--batch_size", default=50, type=int, help="batch size")
parser.add_argument(
"--care_type",
default=0,
type=int,
help="if 1, heterogeneous negative sampling, else normal negative sampling",
)
parser.add_argument(
"--initial_lr", default=0.025, type=float, help="learning rate"
)
parser.add_argument("--min_count", default=5, type=int, help="min count")
parser.add_argument(
"--num_workers", default=16, type=int, help="number of workers"
)
args = parser.parse_args()
m2v = Metapath2VecTrainer(args)
m2v.train()
@@ -10,7 +10,6 @@ from torch.nn import init
class SkipGramModel(nn.Module):
def __init__(self, emb_size, emb_dimension):
super(SkipGramModel, self).__init__()
self.emb_size = emb_size
@@ -39,8 +38,8 @@ class SkipGramModel(nn.Module):
def save_embedding(self, id2word, file_name):
embedding = self.u_embeddings.weight.cpu().data.numpy()
with open(file_name, 'w') as f:
f.write('%d %d\n' % (len(id2word), self.emb_dimension))
with open(file_name, "w") as f:
f.write("%d %d\n" % (len(id2word), self.emb_dimension))
for wid, w in id2word.items():
e = ' '.join(map(lambda x: str(x), embedding[wid]))
f.write('%s %s\n' % (w, e))
\ No newline at end of file
e = " ".join(map(lambda x: str(x), embedding[wid]))
f.write("%s %s\n" % (w, e))
import numpy as np
import torch
from torch.utils.data import Dataset
from download import AminerDataset
from torch.utils.data import Dataset
np.random.seed(12345)
class DataReader:
NEGATIVE_TABLE_SIZE = 1e8
def __init__(self, dataset, min_count, care_type):
self.negatives = []
self.discards = []
self.negpos = 0
@@ -35,7 +36,11 @@ class DataReader:
word_frequency[word] = word_frequency.get(word, 0) + 1
if self.token_count % 1000000 == 0:
print("Read " + str(int(self.token_count / 1000000)) + "M words.")
print(
"Read "
+ str(int(self.token_count / 1000000))
+ "M words."
)
wid = 0
for w, c in word_frequency.items():
@@ -71,15 +76,18 @@ class DataReader:
def getNegatives(self, target, size): # TODO check equality with target
if self.care_type == 0:
response = self.negatives[self.negpos:self.negpos + size]
response = self.negatives[self.negpos : self.negpos + size]
self.negpos = (self.negpos + size) % len(self.negatives)
if len(response) != size:
return np.concatenate((response, self.negatives[0:self.negpos]))
return np.concatenate(
(response, self.negatives[0 : self.negpos])
)
return response
# -----------------------------------------------------------------------------------------------------------------
class Metapath2vecDataset(Dataset):
def __init__(self, data, window_size):
# read in data, window_size and input filename
@@ -103,25 +111,44 @@ class Metapath2vecDataset(Dataset):
words = line.split()
if len(words) > 1:
word_ids = [self.data.word2id[w] for w in words if
w in self.data.word2id and np.random.rand() < self.data.discards[self.data.word2id[w]]]
word_ids = [
self.data.word2id[w]
for w in words
if w in self.data.word2id
and np.random.rand()
< self.data.discards[self.data.word2id[w]]
]
pair_catch = []
for i, u in enumerate(word_ids):
for j, v in enumerate(
word_ids[max(i - self.window_size, 0):i + self.window_size]):
word_ids[
max(i - self.window_size, 0) : i
+ self.window_size
]
):
assert u < self.data.word_count
assert v < self.data.word_count
if i == j:
continue
pair_catch.append((u, v, self.data.getNegatives(v,5)))
pair_catch.append(
(u, v, self.data.getNegatives(v, 5))
)
return pair_catch
@staticmethod
def collate(batches):
all_u = [u for batch in batches for u, _, _ in batch if len(batch) > 0]
all_v = [v for batch in batches for _, v, _ in batch if len(batch) > 0]
all_neg_v = [neg_v for batch in batches for _, _, neg_v in batch if len(batch) > 0]
return torch.LongTensor(all_u), torch.LongTensor(all_v), torch.LongTensor(all_neg_v)
all_neg_v = [
neg_v
for batch in batches
for _, _, neg_v in batch
if len(batch) > 0
]
return (
torch.LongTensor(all_u),
torch.LongTensor(all_v),
torch.LongTensor(all_neg_v),
)
import numpy as np
import os
import random
import sys
import time
import tqdm
import dgl
import sys
import os
import numpy as np
import tqdm
num_walks_per_node = 1000
walk_length = 100
path = sys.argv[1]
def construct_graph():
paper_ids = []
paper_names = []
@@ -31,7 +33,7 @@ def construct_graph():
while True:
w = f_4.readline()
if not w:
break;
break
w = w.strip().split()
identity = int(w[0])
conf_ids.append(identity)
@@ -39,10 +41,10 @@
while True:
v = f_5.readline()
if not v:
break;
break
v = v.strip().split()
identity = int(v[0])
paper_name = 'p' + ''.join(v[1:])
paper_name = "p" + "".join(v[1:])
paper_ids.append(identity)
paper_names.append(paper_name)
f_3.close()
@@ -60,41 +62,49 @@ def construct_graph():
f_1 = open(os.path.join(path, "paper_author.txt"), "r")
f_2 = open(os.path.join(path, "paper_conf.txt"), "r")
for x in f_1:
x = x.split('\t')
x = x.split("\t")
x[0] = int(x[0])
x[1] = int(x[1].strip('\n'))
x[1] = int(x[1].strip("\n"))
paper_author_src.append(paper_ids_invmap[x[0]])
paper_author_dst.append(author_ids_invmap[x[1]])
for y in f_2:
y = y.split('\t')
y = y.split("\t")
y[0] = int(y[0])
y[1] = int(y[1].strip('\n'))
y[1] = int(y[1].strip("\n"))
paper_conf_src.append(paper_ids_invmap[y[0]])
paper_conf_dst.append(conf_ids_invmap[y[1]])
f_1.close()
f_2.close()
hg = dgl.heterograph({
('paper', 'pa', 'author') : (paper_author_src, paper_author_dst),
('author', 'ap', 'paper') : (paper_author_dst, paper_author_src),
('paper', 'pc', 'conf') : (paper_conf_src, paper_conf_dst),
('conf', 'cp', 'paper') : (paper_conf_dst, paper_conf_src)})
hg = dgl.heterograph(
{
("paper", "pa", "author"): (paper_author_src, paper_author_dst),
("author", "ap", "paper"): (paper_author_dst, paper_author_src),
("paper", "pc", "conf"): (paper_conf_src, paper_conf_dst),
("conf", "cp", "paper"): (paper_conf_dst, paper_conf_src),
}
)
return hg, author_names, conf_names, paper_names
#"conference - paper - Author - paper - conference" metapath sampling
# "conference - paper - Author - paper - conference" metapath sampling
def generate_metapath():
output_path = open(os.path.join(path, "output_path.txt"), "w")
count = 0
hg, author_names, conf_names, paper_names = construct_graph()
for conf_idx in tqdm.trange(hg.number_of_nodes('conf')):
for conf_idx in tqdm.trange(hg.number_of_nodes("conf")):
traces, _ = dgl.sampling.random_walk(
hg, [conf_idx] * num_walks_per_node, metapath=['cp', 'pa', 'ap', 'pc'] * walk_length)
hg,
[conf_idx] * num_walks_per_node,
metapath=["cp", "pa", "ap", "pc"] * walk_length,
)
for tr in traces:
outline = ' '.join(
(conf_names if i % 4 == 0 else author_names)[tr[i]]
for i in range(0, len(tr), 2)) # skip paper
outline = " ".join(
(conf_names if i % 4 == 0 else author_names)[tr[i]]
for i in range(0, len(tr), 2)
) # skip paper
print(outline, file=output_path)
output_path.close()
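For reference, dgl.sampling.random_walk follows the metapath's edge types in order (here the pattern repeated walk_length times); a minimal sketch on a toy heterograph with one node of each type:

import dgl
hg = dgl.heterograph(
    {
        ("conf", "cp", "paper"): ([0], [0]),
        ("paper", "pa", "author"): ([0], [0]),
        ("author", "ap", "paper"): ([0], [0]),
        ("paper", "pc", "conf"): ([0], [0]),
    }
)
traces, types = dgl.sampling.random_walk(hg, [0], metapath=["cp", "pa", "ap", "pc"])
print(traces)  # node ids along conf -> paper -> author -> paper -> conf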
......