Unverified Commit 0b9df9d7, authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
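
For reference, a commit like this is normally produced by running the Black formatter over the tree; below is a minimal sketch of the same kind of rewrite through Black's Python API (the snippet and the line length are assumptions, not taken from this commit).

# Hypothetical sketch: reformat a string of source the way a Black auto-fix does.
# The line_length value is an assumption about the repo's Black configuration.
import black

src = "x = {  'a':37,'b':42,\n'c':927}\n"
print(black.format_str(src, mode=black.FileMode(line_length=80)))
# -> x = {"a": 37, "b": 42, "c": 927}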
import numpy as np
import rdkit.Chem as Chem

import dgl

from .chemutils import (
    decode_stereo,
    enum_assemble_nx,
    get_clique_mol,
    get_mol,
    get_smiles,
    set_atommap,
    tree_decomp,
)


class DGLMolTree(object):
    def __init__(self, smiles):
@@ -28,21 +38,23 @@ class DGLMolTree(object):
            cmol = get_clique_mol(self.mol, c)
            csmiles = get_smiles(cmol)
            self.nodes_dict[i] = dict(
                smiles=csmiles,
                mol=get_mol(csmiles),
                clique=c,
            )
            if min(c) == 0:
                root = i

        # The clique with atom ID 0 becomes root
        if root > 0:
            for attr in self.nodes_dict[0]:
                self.nodes_dict[0][attr], self.nodes_dict[root][attr] = (
                    self.nodes_dict[root][attr],
                    self.nodes_dict[0][attr],
                )

        src = np.zeros((len(edges) * 2,), dtype="int")
        dst = np.zeros((len(edges) * 2,), dtype="int")
        for i, (_x, _y) in enumerate(edges):
            x = 0 if _x == root else root if _x == 0 else _x
            y = 0 if _y == root else root if _y == 0 else _y
@@ -53,10 +65,12 @@ class DGLMolTree(object):
        self.graph = dgl.graph((src, dst), num_nodes=len(cliques))

        for i in self.nodes_dict:
            self.nodes_dict[i]["nid"] = i + 1
            if self.graph.out_degrees(i) > 1:  # Leaf node mol is not marked
                set_atommap(
                    self.nodes_dict[i]["mol"], self.nodes_dict[i]["nid"]
                )
            self.nodes_dict[i]["is_leaf"] = self.graph.out_degrees(i) == 1

    def treesize(self):
        return self.graph.number_of_nodes()
@@ -65,49 +79,65 @@ class DGLMolTree(object):
        node = self.nodes_dict[i]
        clique = []
        clique.extend(node["clique"])
        if not node["is_leaf"]:
            for cidx in node["clique"]:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node["nid"])

        for j in self.graph.successors(i).numpy():
            nei_node = self.nodes_dict[j]
            clique.extend(nei_node["clique"])
            if nei_node["is_leaf"]:  # Leaf node, no need to mark
                continue
            for cidx in nei_node["clique"]:
                # allow singleton node override the atom mapping
                if cidx not in node["clique"] or len(nei_node["clique"]) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node["nid"])

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        node["label"] = Chem.MolToSmiles(
            Chem.MolFromSmiles(get_smiles(label_mol))
        )
        node["label_mol"] = get_mol(node["label"])

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return node["label"]

    def _assemble_node(self, i):
        neighbors = [
            self.nodes_dict[j]
            for j in self.graph.successors(i).numpy()
            if self.nodes_dict[j]["mol"].GetNumAtoms() > 1
        ]
        neighbors = sorted(
            neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
        )
        singletons = [
            self.nodes_dict[j]
            for j in self.graph.successors(i).numpy()
            if self.nodes_dict[j]["mol"].GetNumAtoms() == 1
        ]
        neighbors = singletons + neighbors

        cands = enum_assemble_nx(self.nodes_dict[i], neighbors)

        if len(cands) > 0:
            (
                self.nodes_dict[i]["cands"],
                self.nodes_dict[i]["cand_mols"],
                _,
            ) = list(zip(*cands))
            self.nodes_dict[i]["cands"] = list(self.nodes_dict[i]["cands"])
            self.nodes_dict[i]["cand_mols"] = list(
                self.nodes_dict[i]["cand_mols"]
            )
        else:
            self.nodes_dict[i]["cands"] = []
            self.nodes_dict[i]["cand_mols"] = []

    def recover(self):
        for i in self.nodes_dict:
...
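
A minimal usage sketch for the class above (the SMILES string and the import path are assumptions for illustration, and it assumes the junction-tree decomposition succeeds for that molecule):

# Hypothetical usage of DGLMolTree; the module path below is an assumption.
from jtnn.mol_tree import DGLMolTree

tree = DGLMolTree("CC1=CC(=O)C=CC1=O")   # example molecule
print(tree.treesize())                   # number of junction-tree nodes
print(tree.nodes_dict[0]["smiles"])      # SMILES of the root clique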
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as DGLF
from dgl import line_graph, mean_nodes

from .chemutils import get_mol

ELEM_LIST = [
    "C",
    "N",
    "O",
    "S",
    "F",
    "Si",
    "P",
    "Cl",
    "Br",
    "Mg",
    "Na",
    "Ca",
    "Fe",
    "Al",
    "I",
    "B",
    "K",
    "Se",
    "Zn",
    "H",
    "Cu",
    "Mn",
    "unknown",
]

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
MAX_NB = 6


def onek_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]


def atom_features(atom):
    return torch.Tensor(
        onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
        + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
        + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
        + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
        + [atom.GetIsAromatic()]
    )


def bond_features(bond):
    bt = bond.GetBondType()
    stereo = int(bond.GetStereo())
    fbond = [
        bt == Chem.rdchem.BondType.SINGLE,
        bt == Chem.rdchem.BondType.DOUBLE,
        bt == Chem.rdchem.BondType.TRIPLE,
        bt == Chem.rdchem.BondType.AROMATIC,
        bond.IsInRing(),
    ]
    fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
    return torch.Tensor(fbond + fstereo)
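
A quick check of what the two feature helpers above return; the molecule is arbitrary:

# Hypothetical sanity check of atom_features / bond_features; "CCO" is arbitrary.
mol = Chem.MolFromSmiles("CCO")
print(atom_features(mol.GetAtomWithIdx(0)).shape)   # torch.Size([39]) == ATOM_FDIM
print(bond_features(mol.GetBondWithIdx(0)).shape)   # torch.Size([11]) == BOND_FDIM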
def mol2dgl_single(smiles):
    n_edges = 0
@@ -61,8 +98,11 @@ def mol2dgl_single(smiles):
            bond_x.append(features)

    graph = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)

    n_edges += n_bonds
    return (
        graph,
        torch.stack(atom_x),
        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0),
    )


class LoopyBPUpdate(nn.Module):
@@ -73,10 +113,10 @@ class LoopyBPUpdate(nn.Module):
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, nodes):
        msg_input = nodes.data["msg_input"]
        msg_delta = self.W_h(nodes.data["accum_msg"])
        msg = F.relu(msg_input + msg_delta)
        return {"msg": msg}


class GatherUpdate(nn.Module):
@@ -87,9 +127,9 @@ class GatherUpdate(nn.Module):
        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, nodes):
        m = nodes.data["m"]
        return {
            "h": F.relu(self.W_o(torch.cat([nodes.data["x"], m], 1))),
        }
@@ -121,7 +161,7 @@ class DGLMPN(nn.Module):
        mol_graph = self.run(mol_graph, mol_line_graph)

        # TODO: replace with unbatch or readout
        g_repr = mean_nodes(mol_graph, "h")

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
@@ -134,32 +174,38 @@ class DGLMPN(nn.Module):
        n_nodes = mol_graph.number_of_nodes()

        mol_graph.apply_edges(
            func=lambda edges: {"src_x": edges.src["x"]},
        )

        mol_line_graph.ndata.update(mol_graph.edata)

        e_repr = mol_line_graph.ndata
        bond_features = e_repr["x"]
        source_features = e_repr["src_x"]

        features = torch.cat([source_features, bond_features], 1)
        msg_input = self.W_i(features)
        mol_line_graph.ndata.update(
            {
                "msg_input": msg_input,
                "msg": F.relu(msg_input),
                "accum_msg": torch.zeros_like(msg_input),
            }
        )
        mol_graph.ndata.update(
            {
                "m": bond_features.new(n_nodes, self.hidden_size).zero_(),
                "h": bond_features.new(n_nodes, self.hidden_size).zero_(),
            }
        )

        for i in range(self.depth - 1):
            mol_line_graph.update_all(
                DGLF.copy_u("msg", "msg"), DGLF.sum("msg", "accum_msg")
            )
            mol_line_graph.apply_nodes(self.loopy_bp_updater)

        mol_graph.edata.update(mol_line_graph.ndata)
        mol_graph.update_all(DGLF.copy_e("msg", "msg"), DGLF.sum("msg", "m"))
        mol_graph.apply_nodes(self.gather_updater)

        return mol_graph
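
The message-passing loop above uses DGL's built-in message and reduce functions; here is a self-contained sketch of the same copy-and-sum pattern on a toy graph (the graph and feature names are made up):

# Hypothetical, self-contained illustration of the update_all pattern used above.
import torch
import dgl
import dgl.function as DGLF

g = dgl.graph(([0, 1, 2], [1, 2, 0]))   # toy 3-node cycle
g.ndata["h"] = torch.ones(3, 4)
g.update_all(DGLF.copy_u("h", "msg"), DGLF.sum("msg", "h_sum"))
print(g.ndata["h_sum"])                 # each node receives exactly one message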
import os

import torch
import torch.nn as nn

import dgl


def cuda(x):
    if torch.cuda.is_available() and not os.getenv("NOCUDA", None):
        return x.to(torch.device("cuda"))  # works for both DGLGraph and tensor
    else:
        return x
@@ -22,27 +24,28 @@ class GRUUpdate(nn.Module):
        self.W_h = nn.Linear(2 * hidden_size, hidden_size)

    def update_zm(self, node):
        src_x = node.data["src_x"]
        s = node.data["s"]
        rm = node.data["accum_rm"]
        z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
        m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
        m = (1 - z) * s + z * m
        return {"m": m, "z": z}

    def update_r(self, node, zm=None):
        dst_x = node.data["dst_x"]
        m = node.data["m"] if zm is None else zm["m"]
        r_1 = self.W_r(dst_x)
        r_2 = self.U_r(m)
        r = torch.sigmoid(r_1 + r_2)
        return {"r": r, "rm": r * m}

    def forward(self, node):
        dic = self.update_zm(node)
        dic.update(self.update_r(node, zm=dic))
        return dic


def tocpu(g):
    src, dst = g.edges()
    src = src.cpu()
...
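
For orientation, the gated update computed by GRUUpdate above corresponds to a GRU-style message update; a sketch in our own notation (not from the source), with x_src and x_dst the edge endpoint features, s the summed incoming messages, and accum_rm the summed reset-gated messages:

\begin{aligned}
z &= \sigma\big(W_z\,[x_{\mathrm{src}};\, s]\big), \qquad
\tilde{m} = \tanh\big(W_h\,[x_{\mathrm{src}};\, \mathrm{accum\_rm}]\big), \\
m &= (1 - z) \odot s + z \odot \tilde{m}, \qquad
r = \sigma\big(W_r\, x_{\mathrm{dst}} + U_r\, m\big), \quad \mathrm{rm} = r \odot m.
\end{aligned}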
import math
import random
import sys
from collections import deque
from optparse import OptionParser

import rdkit
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import tqdm
from jtnn import *
from torch.utils.data import DataLoader

torch.multiprocessing.set_sharing_strategy("file_system")


def worker_init_fn(id_):
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)


worker_init_fn(None)

parser = OptionParser()
parser.add_option(
    "-t", "--train", dest="train", default="train", help="Training file name"
)
parser.add_option(
    "-v", "--vocab", dest="vocab", default="vocab", help="Vocab file name"
)
parser.add_option("-s", "--save_dir", dest="save_path")
parser.add_option("-m", "--model", dest="model_path", default=None)
parser.add_option("-b", "--batch", dest="batch_size", default=40)
@@ -31,7 +39,7 @@ parser.add_option("-d", "--depth", dest="depth", default=3)
parser.add_option("-z", "--beta", dest="beta", default=1.0)
parser.add_option("-q", "--lr", dest="lr", default=1e-3)
parser.add_option("-T", "--test", dest="test", action="store_true")
opts, args = parser.parse_args()

dataset = JTNNDataset(data=opts.train, vocab=opts.vocab, training=True)
vocab = dataset.vocab
@@ -55,7 +63,10 @@ else:
            nn.init.xavier_normal(param)

model = cuda(model)
print(
    "Model #Params: %dK"
    % (sum([x.nelement() for x in model.parameters()]) / 1000,)
)

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
@@ -64,26 +75,28 @@ scheduler.step()
MAX_EPOCH = 100
PRINT_ITER = 20


def train():
    dataset.training = True
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=JTNNCollator(vocab, True),
        drop_last=True,
        worker_init_fn=worker_init_fn,
    )

    for epoch in range(MAX_EPOCH):
        word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0

        for it, batch in enumerate(tqdm.tqdm(dataloader)):
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
            except:
                print([t.smiles for t in batch["mol_trees"]])
                raise
            loss.backward()
            optimizer.step()
@@ -99,36 +112,51 @@ def train():
                assm_acc = assm_acc / PRINT_ITER * 100
                steo_acc = steo_acc / PRINT_ITER * 100

                print(
                    "KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f"
                    % (
                        kl_div,
                        word_acc,
                        topo_acc,
                        assm_acc,
                        steo_acc,
                        loss.item(),
                    )
                )
                word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
                sys.stdout.flush()

            if (it + 1) % 1500 == 0:  # Fast annealing
                scheduler.step()
                print("learning rate: %.6f" % scheduler.get_lr()[0])
                torch.save(
                    model.state_dict(),
                    opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1),
                )

        scheduler.step()
        print("learning rate: %.6f" % scheduler.get_lr()[0])
        torch.save(
            model.state_dict(), opts.save_path + "/model.iter-" + str(epoch)
        )


def test():
    dataset.training = False
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTNNCollator(vocab, False),
        drop_last=True,
        worker_init_fn=worker_init_fn,
    )

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    for it, batch in enumerate(dataloader):
        gt_smiles = batch["mol_trees"][0].smiles
        print(gt_smiles)
        model.move_to_cuda(batch)
        _, tree_vec, mol_vec = model.encode(batch)
@@ -136,21 +164,28 @@ def test():
        smiles = model.decode(tree_vec, mol_vec)
        print(smiles)


if __name__ == "__main__":
    if opts.test:
        test()
    else:
        train()
    print("# passes:", model.n_passes)
    print("Total # nodes processed:", model.n_nodes_total)
    print("Total # edges processed:", model.n_edges_total)
    print("Total # tree nodes processed:", model.n_tree_nodes_total)
    print("Graph decoder: # passes:", model.jtmpn.n_passes)
    print(
        "Graph decoder: Total # candidates processed:",
        model.jtmpn.n_samples_total,
    )
    print("Graph decoder: Total # nodes processed:", model.jtmpn.n_nodes_total)
    print("Graph decoder: Total # edges processed:", model.jtmpn.n_edges_total)
    print("Graph encoder: # passes:", model.mpn.n_passes)
    print(
        "Graph encoder: Total # candidates processed:",
        model.mpn.n_samples_total,
    )
    print("Graph encoder: Total # nodes processed:", model.mpn.n_nodes_total)
    print("Graph encoder: Total # edges processed:", model.mpn.n_edges_total)
import argparse
import torch
import dgl
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.nn import LabelPropagation


def main():
    # check cuda
    device = (
        f"cuda:{args.gpu}"
        if torch.cuda.is_available() and args.gpu >= 0
        else "cpu"
    )
    # load data
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataset == "Pubmed":
        dataset = PubmedGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    g = dataset[0]
    g = dgl.add_self_loop(g)

    labels = g.ndata.pop("label").to(device).long()

    # load masks for train / test, valid is not used.
    train_mask = g.ndata.pop("train_mask")
    test_mask = g.ndata.pop("test_mask")

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)

    g = g.to(device)

    # label propagation
    lp = LabelPropagation(args.num_layers, args.alpha)
    logits = lp(g, labels, mask=train_idx)

    test_acc = torch.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)
    print("Test Acc {:.4f}".format(test_acc))


if __name__ == "__main__":
    """
    Label Propagation Hyperparameters
    """
    parser = argparse.ArgumentParser(description="LP")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--dataset", type=str, default="Cora")
    parser.add_argument("--num-layers", type=int, default=10)
    parser.add_argument("--alpha", type=float, default=0.5)

    args = parser.parse_args()
    print(args)
...
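
A minimal, self-contained sketch of the same LabelPropagation call on a toy graph (the graph, labels, and seed indices are made up):

# Hypothetical toy run mirroring the lp(g, labels, mask=train_idx) call above.
import torch
import dgl
from dgl.nn import LabelPropagation

g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])))
labels = torch.tensor([0, 0, 1, 1])
train_idx = torch.tensor([0, 2])              # two labeled seed nodes
lp = LabelPropagation(10, 0.5)                # (num_layers, alpha), as in main()
print(lp(g, labels, mask=train_idx).argmax(dim=1))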
@@ -17,49 +17,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from time import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as ss
import torch
from lda_model import LatentDirichletAllocation as LDAModel
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

import dgl
from dgl import function as fn

n_samples = 2000
n_features = 1000
n_components = 10
n_top_words = 20
device = "cuda"


def plot_top_words(model, feature_names, n_top_words, title):
    fig, axes = plt.subplots(2, 5, figsize=(30, 15), sharex=True)
    axes = axes.flatten()
    for topic_idx, topic in enumerate(model.components_):
        top_features_ind = topic.argsort()[: -n_top_words - 1 : -1]
        top_features = [feature_names[i] for i in top_features_ind]
        weights = topic[top_features_ind]

        ax = axes[topic_idx]
        ax.barh(top_features, weights, height=0.7)
        ax.set_title(f"Topic {topic_idx +1}", fontdict={"fontsize": 30})
        ax.invert_yaxis()
        ax.tick_params(axis="both", which="major", labelsize=20)
        for i in "top right left".split():
            ax.spines[i].set_visible(False)
        fig.suptitle(title, fontsize=40)

    plt.subplots_adjust(top=0.90, bottom=0.05, wspace=0.90, hspace=0.3)
    plt.show()


# Load the 20 newsgroups dataset and vectorize it. We use a few heuristics
# to filter out useless terms early on: the posts are stripped of headers,
# footers and quoted replies, and common English words, words occurring in
@@ -67,43 +67,50 @@ def plot_top_words(model, feature_names, n_top_words, title):
print("Loading dataset...")
t0 = time()
data, _ = fetch_20newsgroups(
    shuffle=True,
    random_state=1,
    remove=("headers", "footers", "quotes"),
    return_X_y=True,
)
data_samples = data[:n_samples]
data_test = data[n_samples : 2 * n_samples]
print("done in %0.3fs." % (time() - t0))

# Use tf (raw term count) features for LDA.
print("Extracting tf features for LDA...")
tf_vectorizer = CountVectorizer(
    max_df=0.95, min_df=2, max_features=n_features, stop_words="english"
)
t0 = time()
tf_vectorizer.fit(data)
tf = tf_vectorizer.transform(data_samples)
tt = tf_vectorizer.transform(data_test)
tf_feature_names = tf_vectorizer.get_feature_names()
tf_uv = [
    (u, v)
    for u, v, e in zip(tf.tocoo().row, tf.tocoo().col, tf.tocoo().data)
    for _ in range(e)
]
tt_uv = [
    (u, v)
    for u, v, e in zip(tt.tocoo().row, tt.tocoo().col, tt.tocoo().data)
    for _ in range(e)
]
print("done in %0.3fs." % (time() - t0))
print()

print("Preparing dgl graphs...")
t0 = time()
G = dgl.heterograph({("doc", "topic", "word"): tf_uv}, device=device)
Gt = dgl.heterograph({("doc", "topic", "word"): tt_uv}, device=device)
print("done in %0.3fs." % (time() - t0))
print()

print("Training dgl-lda model...")
t0 = time()
model = LDAModel(G.num_nodes("word"), n_components)
model.fit(G)
print("done in %0.3fs." % (time() - t0))
print()
@@ -113,20 +120,27 @@ print(f"dgl-lda testing perplexity {model.perplexity(Gt):.3f}")
word_nphi = np.vstack([nphi.tolist() for nphi in model.word_data.nphi])
plot_top_words(
    type("dummy", (object,), {"components_": word_nphi}),
    tf_feature_names,
    n_top_words,
    "Topics in LDA model",
)

print("Training scikit-learn model...")
print(
    "\n" * 2,
    "Fitting LDA models with tf features, "
    "n_samples=%d and n_features=%d..." % (n_samples, n_features),
)
lda = LatentDirichletAllocation(
    n_components=n_components,
    max_iter=5,
    learning_method="online",
    learning_offset=50.0,
    random_state=0,
    verbose=1,
)
t0 = time()
lda.fit(tf)
print("done in %0.3fs." % (time() - t0))
...
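
A tiny sketch of the doc-word bipartite graph construction used above, with a made-up edge list standing in for tf_uv:

# Hypothetical miniature of the heterograph built above; the edges are made up.
import dgl

toy_uv = [(0, 2), (0, 5), (1, 2), (1, 7)]     # (doc_id, word_id) token occurrences
toy_G = dgl.heterograph({("doc", "topic", "word"): toy_uv})
print(toy_G.num_nodes("doc"), toy_G.num_nodes("word"))   # 2 8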
@@ -17,8 +17,17 @@
# limitations under the License.

import collections
import functools
import io
import os
import warnings

import numpy as np
import scipy as sp
import torch

import dgl

try:
    from functools import cached_property
@@ -37,17 +46,21 @@ class EdgeData:
    @property
    def loglike(self):
        return (self.src_data["Elog"] + self.dst_data["Elog"]).logsumexp(1)

    @property
    def phi(self):
        return (
            self.src_data["Elog"]
            + self.dst_data["Elog"]
            - self.loglike.unsqueeze(1)
        ).exp()

    @property
    def expectation(self):
        return (
            self.src_data["expectation"] * self.dst_data["expectation"]
        ).sum(1)


class _Dirichlet:
@@ -55,10 +68,13 @@ class _Dirichlet:
        self.prior = prior
        self.nphi = nphi
        self.device = nphi.device
        self._sum_by_parts = lambda map_fn: functools.reduce(
            torch.add,
            [
                map_fn(slice(i, min(i + _chunksize, nphi.shape[1]))).sum(1)
                for i in list(range(0, nphi.shape[1], _chunksize))
            ],
        )

    def _posterior(self, _ID=slice(None)):
        return self.prior + self.nphi[:, _ID]
@@ -68,14 +84,15 @@ class _Dirichlet:
        return self.nphi.sum(1) + self.prior * self.nphi.shape[1]

    def _Elog(self, _ID=slice(None)):
        return torch.digamma(self._posterior(_ID)) - torch.digamma(
            self.posterior_sum.unsqueeze(1)
        )

    @cached_property
    def loglike(self):
        neg_evid = -self._sum_by_parts(
            lambda s: (self.nphi[:, s] * self._Elog(s))
        )

        prior = torch.as_tensor(self.prior).to(self.nphi)
        K = self.nphi.shape[1]
@@ -83,7 +100,7 @@ class _Dirichlet:
        log_B_posterior = self._sum_by_parts(
            lambda s: torch.lgamma(self._posterior(s))
        ) - torch.lgamma(self.posterior_sum)

        return neg_evid - log_B_prior + log_B_posterior
@@ -105,9 +122,15 @@ class _Dirichlet:
    @cached_property
    def Bayesian_gap(self):
        return 1.0 - self._sum_by_parts(lambda s: self._Elog(s).exp())

    _cached_properties = [
        "posterior_sum",
        "loglike",
        "n",
        "cdf",
        "Bayesian_gap",
    ]

    def clear_cache(self):
        for name in self._cached_properties:
@@ -117,27 +140,29 @@ class _Dirichlet:
                pass

    def update(self, new, _ID=slice(None), rho=1):
        """inplace: old * (1-rho) + new * rho"""
        self.clear_cache()
        mean_change = (self.nphi[:, _ID] - new).abs().mean().tolist()
        self.nphi *= 1 - rho
        self.nphi[:, _ID] += new * rho
        return mean_change


class DocData(_Dirichlet):
    """nphi (n_docs by n_topics)"""

    def prepare_graph(self, G, key="Elog"):
        G.nodes["doc"].data[key] = getattr(self, "_" + key)().to(G.device)

    def update_from(self, G, mult):
        new = G.nodes["doc"].data["nphi"] * mult
        return self.update(new.to(self.device))


class _Distributed(collections.UserList):
    """split on dim=0 and store on multiple devices"""

    def __init__(self, prior, nphi):
        self.prior = prior
        self.nphi = nphi
@@ -146,36 +171,38 @@ class _Distributed(collections.UserList):
    def split_device(self, other, dim=0):
        split_sections = [x.shape[0] for x in self.nphi]
        out = torch.split(other, split_sections, dim)
        return [y.to(x.device) for x, y in zip(self.nphi, out)]


class WordData(_Distributed):
    """distributed nphi (n_topics by n_words), transpose to/from graph nodes data"""

    def prepare_graph(self, G, key="Elog"):
        if "_ID" in G.nodes["word"].data:
            _ID = G.nodes["word"].data["_ID"]
        else:
            _ID = slice(None)
        out = [getattr(part, "_" + key)(_ID).to(G.device) for part in self]
        G.nodes["word"].data[key] = torch.cat(out).T

    def update_from(self, G, mult, rho):
        nphi = G.nodes["word"].data["nphi"].T * mult
        if "_ID" in G.nodes["word"].data:
            _ID = G.nodes["word"].data["_ID"]
        else:
            _ID = slice(None)

        mean_change = [
            x.update(y, _ID, rho) for x, y in zip(self, self.split_device(nphi))
        ]
        return np.mean(mean_change)


class Gamma(collections.namedtuple("Gamma", "concentration, rate")):
    """articulate the difference between torch gamma and numpy gamma"""

    @property
    def shape(self):
        return self.concentration
@@ -218,20 +245,23 @@ class LatentDirichletAllocation:
    (NIPS 2010).
    [2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
    """

    def __init__(
        self,
        n_words,
        n_components,
        prior=None,
        rho=1,
        mult={"doc": 1, "word": 1},
        init={"doc": (100.0, 100.0), "word": (100.0, 100.0)},
        device_list=["cpu"],
        verbose=True,
    ):
        self.n_words = n_words
        self.n_components = n_components
        if prior is None:
            prior = {"doc": 1.0 / n_components, "word": 1.0 / n_components}
        self.prior = prior
        self.rho = rho
@@ -239,117 +269,128 @@ class LatentDirichletAllocation:
        self.init = init

        assert not isinstance(device_list, str), "plz wrap devices in a list"
        self.device_list = device_list[:n_components]  # avoid edge cases
        self.verbose = verbose

        self._init_word_data()

    def _init_word_data(self):
        split_sections = np.diff(
            np.linspace(0, self.n_components, len(self.device_list) + 1).astype(
                int
            )
        )
        word_nphi = [
            Gamma(*self.init["word"]).sample((s, self.n_words), device)
            for s, device in zip(split_sections, self.device_list)
        ]
        self.word_data = WordData(self.prior["word"], word_nphi)

    def _init_doc_data(self, n_docs, device):
        doc_nphi = Gamma(*self.init["doc"]).sample(
            (n_docs, self.n_components), device
        )
        return DocData(self.prior["doc"], doc_nphi)

    def save(self, f):
        for w in self.word_data:
            w.clear_cache()
        torch.save(
            {
                "prior": self.prior,
                "rho": self.rho,
                "mult": self.mult,
                "init": self.init,
                "word_data": [part.nphi for part in self.word_data],
            },
            f,
        )

    def _prepare_graph(self, G, doc_data, key="Elog"):
        doc_data.prepare_graph(G, key)
        self.word_data.prepare_graph(G, key)

    def _e_step(self, G, doc_data=None, mean_change_tol=1e-3, max_iters=100):
        """_e_step implements doc data sampling until convergence or max_iters"""
        if doc_data is None:
            doc_data = self._init_doc_data(G.num_nodes("doc"), G.device)

        G_rev = G.reverse()  # word -> doc
        self.word_data.prepare_graph(G_rev)

        for i in range(max_iters):
            doc_data.prepare_graph(G_rev)
            G_rev.update_all(
                lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
                dgl.function.sum("phi", "nphi"),
            )
            mean_change = doc_data.update_from(G_rev, self.mult["doc"])
            if mean_change < mean_change_tol:
                break

        if self.verbose:
            print(
                f"e-step num_iters={i+1} with mean_change={mean_change:.4f}, "
                f"perplexity={self.perplexity(G, doc_data):.4f}"
            )

        return doc_data

    transform = _e_step

    def predict(self, doc_data):
        pred_scores = [
            # d_exp @ w._expectation()
            (lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)(
                d_exp / w.posterior_sum.unsqueeze(0)
            )
            for (d_exp, w) in zip(
                self.word_data.split_device(doc_data._expectation(), dim=1),
                self.word_data,
            )
        ]

        x = torch.zeros_like(pred_scores[0], device=doc_data.device)
        for p in pred_scores:
            x += p.to(x.device)
        return x

    def sample(self, doc_data, num_samples):
        """draw independent words and return the marginal probabilities,
        i.e., the expectations in Dirichlet distributions.
        """

        def fn(cdf):
            u = torch.rand(cdf.shape[0], num_samples, device=cdf.device)
            return torch.searchsorted(cdf, u).to(doc_data.device)

        topic_ids = fn(doc_data.cdf)
        word_ids = torch.cat([fn(part.cdf) for part in self.word_data])
        ids = torch.gather(
            word_ids, 0, topic_ids
        )  # pick components by topic_ids

        # compute expectation scores on sampled ids
        src_ids = (
            torch.arange(ids.shape[0], dtype=ids.dtype, device=ids.device)
            .reshape((-1, 1))
            .expand(ids.shape)
        )
        unique_ids, inverse_ids = torch.unique(
            ids, sorted=False, return_inverse=True
        )

        G = dgl.heterograph(
            {("doc", "", "word"): (src_ids.ravel(), inverse_ids.ravel())}
        )
        G.nodes["word"].data["_ID"] = unique_ids
        self._prepare_graph(G, doc_data, "expectation")
        G.apply_edges(
            lambda e: {"expectation": EdgeData(e.src, e.dst).expectation}
        )
        expectation = G.edata.pop("expectation").reshape(ids.shape)

        return ids, expectation

    def _m_step(self, G, doc_data):
        """_m_step implements word data sampling and stores word_z stats.
        mean_change is in the sense of full graph with rho=1.
@@ -357,26 +398,25 @@ class LatentDirichletAllocation:
        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.update_all(
            lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
            dgl.function.sum("phi", "nphi"),
        )
        self._last_mean_change = self.word_data.update_from(
            G, self.mult["word"], self.rho
        )
        if self.verbose:
            print(f"m-step mean_change={self._last_mean_change:.4f}, ", end="")
            Bayesian_gap = np.mean(
                [part.Bayesian_gap.mean().tolist() for part in self.word_data]
            )
            print(f"Bayesian_gap={Bayesian_gap:.4f}")

    def partial_fit(self, G):
        doc_data = self._e_step(G)
        self._m_step(G, doc_data)
        return self

    def fit(self, G, mean_change_tol=1e-3, max_epochs=10):
        for i in range(max_epochs):
            if self.verbose:
@@ -387,7 +427,6 @@ class LatentDirichletAllocation:
                break
        return self

    def perplexity(self, G, doc_data=None):
        """ppl = exp{-sum[log(p(w1,...,wn|d))] / n}
        Follows Eq (15) in Hoffman et al., 2010.
@@ -398,45 +437,50 @@ class LatentDirichletAllocation:
        # compute E[log p(docs | theta, beta)]
        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.apply_edges(
            lambda edges: {"loglike": EdgeData(edges.src, edges.dst).loglike}
        )
        edge_elbo = (G.edata["loglike"].sum() / G.num_edges()).tolist()
        if self.verbose:
            print(f"neg_elbo phi: {-edge_elbo:.3f}", end=" ")

        # compute E[log p(theta | alpha) - log q(theta | gamma)]
        doc_elbo = (doc_data.loglike.sum() / doc_data.n.sum()).tolist()
        if self.verbose:
            print(f"theta: {-doc_elbo:.3f}", end=" ")

        # compute E[log p(beta | eta) - log q(beta | lambda)]
        # The denominator n for extrapolation perplexity is undefined.
        # We use the train set, whereas sklearn uses the test set.
        word_elbo = sum(
            [part.loglike.sum().tolist() for part in self.word_data]
        ) / sum([part.n.sum().tolist() for part in self.word_data])
        if self.verbose:
            print(f"beta: {-word_elbo:.3f}")

        ppl = np.exp(-edge_elbo - doc_elbo - word_elbo)
        if G.num_edges() > 0 and np.isnan(ppl):
            warnings.warn("numerical issue in perplexity")
        return ppl


def doc_subgraph(G, doc_ids):
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    _, _, (block,) = sampler.sample(
        G.reverse(), {"doc": torch.as_tensor(doc_ids)}
    )
    B = dgl.DGLHeteroGraph(
        block._graph, ["_", "word", "doc", "_"], block.etypes
    ).reverse()
    B.nodes["word"].data["_ID"] = block.nodes["word"].data["_ID"]
    return B


if __name__ == "__main__":
    print("Testing LatentDirichletAllocation ...")
    G = dgl.heterograph(
        {("doc", "", "word"): [(0, 0), (1, 3)]}, {"doc": 2, "word": 5}
    )
    model = LatentDirichletAllocation(n_words=5, n_components=10, verbose=False)
    model.fit(G)
    model.transform(G)
@@ -454,4 +498,4 @@ if __name__ == '__main__':
    f.seek(0)
    print(torch.load(f))
    print("Testing LatentDirichletAllocation passed!")
...@@ -6,11 +6,12 @@ Author's implementation: https://github.com/joanbruna/GNN_community ...@@ -6,11 +6,12 @@ Author's implementation: https://github.com/joanbruna/GNN_community
""" """
from __future__ import division from __future__ import division
import time
import argparse import argparse
import time
from itertools import permutations from itertools import permutations
import gnn
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn.functional as F import torch.nn.functional as F
...@@ -18,37 +19,51 @@ import torch.optim as optim ...@@ -18,37 +19,51 @@ import torch.optim as optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from dgl.data import SBMMixtureDataset from dgl.data import SBMMixtureDataset
import gnn
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, help='Batch size', default=1) parser.add_argument("--batch-size", type=int, help="Batch size", default=1)
parser.add_argument('--gpu', type=int, help='GPU index', default=-1) parser.add_argument("--gpu", type=int, help="GPU index", default=-1)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001) parser.add_argument("--lr", type=float, help="Learning rate", default=0.001)
parser.add_argument('--n-communities', type=int, help='Number of communities', default=2) parser.add_argument(
parser.add_argument('--n-epochs', type=int, help='Number of epochs', default=100) "--n-communities", type=int, help="Number of communities", default=2
parser.add_argument('--n-features', type=int, help='Number of features', default=16) )
parser.add_argument('--n-graphs', type=int, help='Number of graphs', default=10) parser.add_argument(
parser.add_argument('--n-layers', type=int, help='Number of layers', default=30) "--n-epochs", type=int, help="Number of epochs", default=100
parser.add_argument('--n-nodes', type=int, help='Number of nodes', default=10000) )
parser.add_argument('--optim', type=str, help='Optimizer', default='Adam') parser.add_argument(
parser.add_argument('--radius', type=int, help='Radius', default=3) "--n-features", type=int, help="Number of features", default=16
parser.add_argument('--verbose', action='store_true') )
parser.add_argument("--n-graphs", type=int, help="Number of graphs", default=10)
parser.add_argument("--n-layers", type=int, help="Number of layers", default=30)
parser.add_argument(
"--n-nodes", type=int, help="Number of nodes", default=10000
)
parser.add_argument("--optim", type=str, help="Optimizer", default="Adam")
parser.add_argument("--radius", type=int, help="Radius", default=3)
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args() args = parser.parse_args()
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu) dev = th.device("cpu") if args.gpu < 0 else th.device("cuda:%d" % args.gpu)
K = args.n_communities K = args.n_communities
training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K) training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(training_dataset, args.batch_size, training_loader = DataLoader(
collate_fn=training_dataset.collate_fn, drop_last=True) training_dataset,
args.batch_size,
collate_fn=training_dataset.collate_fn,
drop_last=True,
)
ones = th.ones(args.n_nodes // K) ones = th.ones(args.n_nodes // K)
y_list = [th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))] y_list = [
th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))
]
feats = [1] + [args.n_features] * args.n_layers + [K] feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev) model = gnn.GNN(feats, args.radius, K).to(dev)
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr) optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)
def compute_overlap(z_list): def compute_overlap(z_list):
ybar_list = [th.max(z, 1)[1] for z in z_list] ybar_list = [th.max(z, 1)[1] for z in z_list]
overlap_list = [] overlap_list = []
...@@ -58,15 +73,20 @@ def compute_overlap(z_list): ...@@ -58,15 +73,20 @@ def compute_overlap(z_list):
overlap_list.append(overlap) overlap_list.append(overlap)
return sum(overlap_list) / len(overlap_list) return sum(overlap_list) / len(overlap_list)
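For reference, the overlap reported by this script is a permutation-invariant accuracy over the K communities. The elided body above is not reproduced here; the sketch below only illustrates the standard definition, reusing the module-level y_list and K defined earlier, so that random guessing maps to 0 and perfect recovery to 1.

def overlap_sketch(z):
    # Predicted community per node.
    ybar = th.max(z, 1)[1]
    # Best agreement over all relabellings of the ground truth.
    acc = max((ybar == y).float().mean().item() for y in y_list)
    # Rescale so that chance level (1 / K) maps to 0 and perfect recovery to 1.
    return (acc - 1.0 / K) / (1.0 - 1.0 / K)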
def from_np(f, *args): def from_np(f, *args):
def wrap(*args): def wrap(*args):
new = [th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args] new = [
th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args
]
return f(*new) return f(*new)
return wrap return wrap
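from_np converts any NumPy-array positional arguments into torch tensors before dispatching to the wrapped function, which lets the NumPy batches coming out of the DataLoader feed the model directly. A minimal illustration; the decorated function below is invented for the example.

@from_np
def sum_of_products(a, b):
    # a and b arrive as torch tensors even if the caller passed NumPy arrays.
    return (a * b).sum()

# sum_of_products(np.ones(3), np.arange(3.0)) -> tensor(3., dtype=torch.float64)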
@from_np @from_np
def step(i, j, g, lg, deg_g, deg_lg, pm_pd): def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
""" One step of training. """ """One step of training."""
g = g.to(dev) g = g.to(dev)
lg = lg.to(dev) lg = lg.to(dev)
deg_g = deg_g.to(dev).unsqueeze(1) deg_g = deg_g.to(dev).unsqueeze(1)
...@@ -77,7 +97,10 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd): ...@@ -77,7 +97,10 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
t_forward = time.time() - t0 t_forward = time.time() - t0
z_list = th.chunk(z, args.batch_size, 0) z_list = th.chunk(z, args.batch_size, 0)
loss = sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list) / args.batch_size loss = (
sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list)
/ args.batch_size
)
overlap = compute_overlap(z_list) overlap = compute_overlap(z_list)
optimizer.zero_grad() optimizer.zero_grad()
...@@ -88,6 +111,7 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd): ...@@ -88,6 +111,7 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
return loss, overlap, t_forward, t_backward return loss, overlap, t_forward, t_backward
@from_np @from_np
def inference(g, lg, deg_g, deg_lg, pm_pd): def inference(g, lg, deg_g, deg_lg, pm_pd):
g = g.to(dev) g = g.to(dev)
...@@ -99,9 +123,11 @@ def inference(g, lg, deg_g, deg_lg, pm_pd): ...@@ -99,9 +123,11 @@ def inference(g, lg, deg_g, deg_lg, pm_pd):
z = model(g, lg, deg_g, deg_lg, pm_pd) z = model(g, lg, deg_g, deg_lg, pm_pd)
return z return z
def test(): def test():
p_list =[6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0] p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
q_list =[0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6] q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
N = 1 N = 1
overlap_list = [] overlap_list = []
for p, q in zip(p_list, q_list): for p, q in zip(p_list, q_list):
...@@ -112,31 +138,38 @@ def test(): ...@@ -112,31 +138,38 @@ def test():
overlap_list.append(compute_overlap(th.chunk(z, N, 0))) overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
return overlap_list return overlap_list
n_iterations = args.n_graphs // args.batch_size n_iterations = args.n_graphs // args.batch_size
for i in range(args.n_epochs): for i in range(args.n_epochs):
total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0 total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader): for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd) loss, overlap, t_forward, t_backward = step(
i, j, g, lg, deg_g, deg_lg, pm_pd
)
total_loss += loss total_loss += loss
total_overlap += overlap total_overlap += overlap
s_forward += t_forward s_forward += t_forward
s_backward += t_backward s_backward += t_backward
epoch = '0' * (len(str(args.n_epochs)) - len(str(i))) epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
iteration = '0' * (len(str(n_iterations)) - len(str(j))) iteration = "0" * (len(str(n_iterations)) - len(str(j)))
if args.verbose: if args.verbose:
print('[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f' print(
% (epoch, i, iteration, j, loss, overlap)) "[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f"
% (epoch, i, iteration, j, loss, overlap)
)
epoch = '0' * (len(str(args.n_epochs)) - len(str(i))) epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
loss = total_loss / (j + 1) loss = total_loss / (j + 1)
overlap = total_overlap / (j + 1) overlap = total_overlap / (j + 1)
t_forward = s_forward / (j + 1) t_forward = s_forward / (j + 1)
t_backward = s_backward / (j + 1) t_backward = s_backward / (j + 1)
print('[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs' print(
% (epoch, i, loss, overlap, t_forward, t_backward)) "[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs"
% (epoch, i, loss, overlap, t_forward, t_backward)
)
overlap_list = test() overlap_list = test()
overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list]) overlap_str = " - ".join(["%.3f" % overlap for overlap in overlap_list])
print('[epoch %s%d]overlap: %s' % (epoch, i, overlap_str)) print("[epoch %s%d]overlap: %s" % (epoch, i, overlap_str))
...@@ -2,55 +2,55 @@ import torch as th ...@@ -2,55 +2,55 @@ import torch as th
import torch.nn.functional as F import torch.nn.functional as F
GCN_CONFIG = { GCN_CONFIG = {
'extra_args': [16, 1, F.relu, 0.5], "extra_args": [16, 1, F.relu, 0.5],
'lr': 1e-2, "lr": 1e-2,
'weight_decay': 5e-4, "weight_decay": 5e-4,
} }
GAT_CONFIG = { GAT_CONFIG = {
'extra_args': [8, 1, [8] * 1 + [1], F.elu, 0.6, 0.6, 0.2, False], "extra_args": [8, 1, [8] * 1 + [1], F.elu, 0.6, 0.6, 0.2, False],
'lr': 0.005, "lr": 0.005,
'weight_decay': 5e-4, "weight_decay": 5e-4,
} }
GRAPHSAGE_CONFIG = { GRAPHSAGE_CONFIG = {
'extra_args': [16, 1, F.relu, 0.5, 'gcn'], "extra_args": [16, 1, F.relu, 0.5, "gcn"],
'lr': 1e-2, "lr": 1e-2,
'weight_decay': 5e-4, "weight_decay": 5e-4,
} }
APPNP_CONFIG = { APPNP_CONFIG = {
'extra_args': [64, 1, F.relu, 0.5, 0.5, 0.1, 10], "extra_args": [64, 1, F.relu, 0.5, 0.5, 0.1, 10],
'lr': 1e-2, "lr": 1e-2,
'weight_decay': 5e-4, "weight_decay": 5e-4,
} }
TAGCN_CONFIG = { TAGCN_CONFIG = {
'extra_args': [16, 1, F.relu, 0.5], "extra_args": [16, 1, F.relu, 0.5],
'lr': 1e-2, "lr": 1e-2,
'weight_decay': 5e-4, "weight_decay": 5e-4,
} }
AGNN_CONFIG = { AGNN_CONFIG = {
'extra_args': [32, 2, 1.0, True, 0.5], "extra_args": [32, 2, 1.0, True, 0.5],
'lr': 1e-2, "lr": 1e-2,
'weight_decay': 5e-4, "weight_decay": 5e-4,
} }
SGC_CONFIG = { SGC_CONFIG = {
'extra_args': [None, 2, False], "extra_args": [None, 2, False],
'lr': 0.2, "lr": 0.2,
'weight_decay': 5e-6, "weight_decay": 5e-6,
} }
GIN_CONFIG = { GIN_CONFIG = {
'extra_args': [16, 1, 0, True], "extra_args": [16, 1, 0, True],
'lr': 1e-2, "lr": 1e-2,
'weight_decay': 5e-6, "weight_decay": 5e-6,
} }
CHEBNET_CONFIG = { CHEBNET_CONFIG = {
'extra_args': [32, 1, 2, True], "extra_args": [32, 1, 2, True],
'lr': 1e-2, "lr": 1e-2,
'weight_decay': 5e-4, "weight_decay": 5e-4,
} }
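Each *_CONFIG entry pairs a model's extra constructor arguments with the optimizer hyperparameters used for it. A hedged sketch of how a training script might consume one of these entries; the constructor signature (graph, input size, number of classes, then the extras) is an assumption for illustration, not taken from this diff.

import torch.optim as optim

def build_model_and_optimizer(model_cls, g, in_feats, n_classes, config):
    # Common arguments first, then the model-specific extras from the config.
    model = model_cls(g, in_feats, n_classes, *config["extra_args"])
    optimizer = optim.Adam(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )
    return model, optimizer

# e.g. build_model_and_optimizer(GCN, g, features.shape[1], n_classes, GCN_CONFIG)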
...@@ -31,7 +31,7 @@ def laplacian(W, normalized=True): ...@@ -31,7 +31,7 @@ def laplacian(W, normalized=True):
def rescale_L(L, lmax=2): def rescale_L(L, lmax=2):
"""Rescale Laplacian eigenvalues to [-1,1]""" """Rescale Laplacian eigenvalues to [-1,1]"""
M, M = L.shape M, M = L.shape
I = scipy.sparse.identity(M, format='csr', dtype=L.dtype) I = scipy.sparse.identity(M, format="csr", dtype=L.dtype)
L /= lmax * 2 L /= lmax * 2
L -= I L -= I
return L return L
...@@ -39,7 +39,9 @@ def rescale_L(L, lmax=2): ...@@ -39,7 +39,9 @@ def rescale_L(L, lmax=2):
def lmax_L(L): def lmax_L(L):
"""Compute largest Laplacian eigenvalue""" """Compute largest Laplacian eigenvalue"""
return scipy.sparse.linalg.eigsh(L, k=1, which='LM', return_eigenvectors=False)[0] return scipy.sparse.linalg.eigsh(
L, k=1, which="LM", return_eigenvectors=False
)[0]
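lmax_L and rescale_L are used together to prepare a graph Laplacian for the Chebyshev polynomial recursion. A minimal sketch of the intended call sequence on a toy adjacency matrix; the matrix is invented for the example, and laplacian() is the helper defined earlier in this file.

import numpy as np
import scipy.sparse

# Toy 3-node path graph, purely illustrative.
W = scipy.sparse.csr_matrix(
    np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
)
L = laplacian(W, normalized=True)  # normalized Laplacian of the toy graph
lmax = lmax_L(L)                   # largest eigenvalue, close to 2 here
L_hat = rescale_L(L, lmax)         # shifted spectrum fed to the Chebyshev recursion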
# graph coarsening with Heavy Edge Matching # graph coarsening with Heavy Edge Matching
...@@ -57,7 +59,11 @@ def coarsen(A, levels): ...@@ -57,7 +59,11 @@ def coarsen(A, levels):
A = A.tocsr() A = A.tocsr()
A.eliminate_zeros() A.eliminate_zeros()
Mnew, Mnew = A.shape Mnew, Mnew = A.shape
print('Layer {0}: M_{0} = |V| = {1} nodes ({2} added), |E| = {3} edges'.format(i, Mnew, Mnew - M, A.nnz // 2)) print(
"Layer {0}: M_{0} = |V| = {1} nodes ({2} added), |E| = {3} edges".format(
i, Mnew, Mnew - M, A.nnz // 2
)
)
L = laplacian(A, normalized=True) L = laplacian(A, normalized=True)
laplacians.append(L) laplacians.append(L)
...@@ -95,7 +101,7 @@ def HEM(W, levels, rid=None): ...@@ -95,7 +101,7 @@ def HEM(W, levels, rid=None):
graphs = [] graphs = []
graphs.append(W) graphs.append(W)
print('Heavy Edge Matching coarsening with Xavier version') print("Heavy Edge Matching coarsening with Xavier version")
for _ in range(levels): for _ in range(levels):
...@@ -183,7 +189,9 @@ def HEM_one_level(rr, cc, vv, rid, weights): ...@@ -183,7 +189,9 @@ def HEM_one_level(rr, cc, vv, rid, weights):
# First approach # First approach
if 2 == 1: if 2 == 1:
tval = vv[rs + jj] * (1.0 / weights[tid] + 1.0 / weights[nid]) tval = vv[rs + jj] * (
1.0 / weights[tid] + 1.0 / weights[nid]
)
# Second approach # Second approach
if 1 == 1: if 1 == 1:
...@@ -192,7 +200,7 @@ def HEM_one_level(rr, cc, vv, rid, weights): ...@@ -192,7 +200,7 @@ def HEM_one_level(rr, cc, vv, rid, weights):
Wjj = vv[rowstart[nid]] Wjj = vv[rowstart[nid]]
di = weights[tid] di = weights[tid]
dj = weights[nid] dj = weights[nid]
tval = (2. * Wij + Wii + Wjj) * 1. / (di + dj + 1e-9) tval = (2.0 * Wij + Wii + Wjj) * 1.0 / (di + dj + 1e-9)
if tval > wmax: if tval > wmax:
wmax = tval wmax = tval
...@@ -247,7 +255,7 @@ def compute_perm(parents): ...@@ -247,7 +255,7 @@ def compute_perm(parents):
# Sanity checks. # Sanity checks.
for i, indices_layer in enumerate(indices): for i, indices_layer in enumerate(indices):
M = M_last * 2 ** i M = M_last * 2**i
# Reduction by 2 at each layer (binary tree). # Reduction by 2 at each layer (binary tree).
assert len(indices[0] == M) assert len(indices[0] == M)
# The new ordering does not omit an index. # The new ordering does not omit an index.
...@@ -256,8 +264,9 @@ def compute_perm(parents): ...@@ -256,8 +264,9 @@ def compute_perm(parents):
return indices[::-1] return indices[::-1]
assert (compute_perm([np.array([4, 1, 1, 2, 2, 3, 0, 0, 3]), np.array([2, 1, 0, 1, 0])]) assert compute_perm(
== [[3, 4, 0, 9, 1, 2, 5, 8, 6, 7, 10, 11], [2, 4, 1, 3, 0, 5], [0, 1, 2]]) [np.array([4, 1, 1, 2, 2, 3, 0, 0, 3]), np.array([2, 1, 0, 1, 0])]
) == [[3, 4, 0, 9, 1, 2, 5, 8, 6, 7, 10, 11], [2, 4, 1, 3, 0, 5], [0, 1, 2]]
def perm_adjacency(A, indices): def perm_adjacency(A, indices):
......
...@@ -2,6 +2,8 @@ import torch as th ...@@ -2,6 +2,8 @@ import torch as th
"""Compute x,y coordinate for nodes in the graph""" """Compute x,y coordinate for nodes in the graph"""
eps = 1e-8 eps = 1e-8
def get_coordinates(graphs, grid_side, coarsening_levels, perm): def get_coordinates(graphs, grid_side, coarsening_levels, perm):
rst = [] rst = []
for l in range(coarsening_levels + 1): for l in range(coarsening_levels + 1):
...@@ -10,21 +12,25 @@ def get_coordinates(graphs, grid_side, coarsening_levels, perm): ...@@ -10,21 +12,25 @@ def get_coordinates(graphs, grid_side, coarsening_levels, perm):
cnt = eps cnt = eps
x_accum = 0 x_accum = 0
y_accum = 0 y_accum = 0
for j in range(i * 2 ** l, (i + 1) * 2 ** l): for j in range(i * 2**l, (i + 1) * 2**l):
if perm[j] < grid_side ** 2: if perm[j] < grid_side**2:
x_accum += (perm[j] // grid_side) x_accum += perm[j] // grid_side
y_accum += (perm[j] % grid_side) y_accum += perm[j] % grid_side
cnt += 1 cnt += 1
xs.append(x_accum / cnt) xs.append(x_accum / cnt)
ys.append(y_accum / cnt) ys.append(y_accum / cnt)
rst.append(th.cat([th.tensor(xs).view(-1, 1), th.tensor(ys).view(-1, 1)], -1)) rst.append(
th.cat([th.tensor(xs).view(-1, 1), th.tensor(ys).view(-1, 1)], -1)
)
return rst return rst
"""Cartesian coordinate to polar coordinate""" """Cartesian coordinate to polar coordinate"""
def z2polar(edges): def z2polar(edges):
z = edges.dst['xy'] - edges.src['xy'] z = edges.dst["xy"] - edges.src["xy"]
rho = th.norm(z, dim=-1, p=2) rho = th.norm(z, dim=-1, p=2)
x, y = z.unbind(dim=-1) x, y = z.unbind(dim=-1)
phi = th.atan2(y, x) phi = th.atan2(y, x)
return {'u': th.cat([rho.unsqueeze(-1), phi.unsqueeze(-1)], -1)} return {"u": th.cat([rho.unsqueeze(-1), phi.unsqueeze(-1)], -1)}
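z2polar is a DGL edge UDF: for every edge it takes the difference of the endpoints' 'xy' node features and stores the polar form (rho, phi) as edge feature 'u'. A small usage sketch; the toy graph and coordinates are invented for the example.

import dgl
import torch as th

# Toy triangle graph with 2-D coordinates attached to the nodes.
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata["xy"] = th.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])

# Apply the UDF over all edges; each edge gets u = [rho, phi].
g.apply_edges(z2polar)
print(g.edata["u"].shape)  # torch.Size([3, 2])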