Unverified Commit 0b9df9d7 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
import numpy as np
import rdkit.Chem as Chem

import dgl

from .chemutils import (
    decode_stereo,
    enum_assemble_nx,
    get_clique_mol,
    get_mol,
    get_smiles,
    set_atommap,
    tree_decomp,
)


class DGLMolTree(object):
    def __init__(self, smiles):
@@ -38,11 +48,13 @@ class DGLMolTree(object):

        # The clique with atom ID 0 becomes root
        if root > 0:
            for attr in self.nodes_dict[0]:
                self.nodes_dict[0][attr], self.nodes_dict[root][attr] = (
                    self.nodes_dict[root][attr],
                    self.nodes_dict[0][attr],
                )

        src = np.zeros((len(edges) * 2,), dtype="int")
        dst = np.zeros((len(edges) * 2,), dtype="int")
        for i, (_x, _y) in enumerate(edges):
            x = 0 if _x == root else root if _x == 0 else _x
            y = 0 if _y == root else root if _y == 0 else _y
@@ -53,10 +65,12 @@ class DGLMolTree(object):

        self.graph = dgl.graph((src, dst), num_nodes=len(cliques))

        for i in self.nodes_dict:
            self.nodes_dict[i]["nid"] = i + 1
            if self.graph.out_degrees(i) > 1:  # Leaf node mol is not marked
                set_atommap(
                    self.nodes_dict[i]["mol"], self.nodes_dict[i]["nid"]
                )
            self.nodes_dict[i]["is_leaf"] = self.graph.out_degrees(i) == 1

    def treesize(self):
        return self.graph.number_of_nodes()
@@ -65,49 +79,65 @@ class DGLMolTree(object):

        node = self.nodes_dict[i]
        clique = []
        clique.extend(node["clique"])
        if not node["is_leaf"]:
            for cidx in node["clique"]:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node["nid"])

        for j in self.graph.successors(i).numpy():
            nei_node = self.nodes_dict[j]
            clique.extend(nei_node["clique"])
            if nei_node["is_leaf"]:  # Leaf node, no need to mark
                continue
            for cidx in nei_node["clique"]:
                # allow singleton node override the atom mapping
                if cidx not in node["clique"] or len(nei_node["clique"]) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node["nid"])

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        node["label"] = Chem.MolToSmiles(
            Chem.MolFromSmiles(get_smiles(label_mol))
        )
        node["label_mol"] = get_mol(node["label"])

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return node["label"]

    def _assemble_node(self, i):
        neighbors = [
            self.nodes_dict[j]
            for j in self.graph.successors(i).numpy()
            if self.nodes_dict[j]["mol"].GetNumAtoms() > 1
        ]
        neighbors = sorted(
            neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
        )
        singletons = [
            self.nodes_dict[j]
            for j in self.graph.successors(i).numpy()
            if self.nodes_dict[j]["mol"].GetNumAtoms() == 1
        ]
        neighbors = singletons + neighbors

        cands = enum_assemble_nx(self.nodes_dict[i], neighbors)

        if len(cands) > 0:
            (
                self.nodes_dict[i]["cands"],
                self.nodes_dict[i]["cand_mols"],
                _,
            ) = list(zip(*cands))
            self.nodes_dict[i]["cands"] = list(self.nodes_dict[i]["cands"])
            self.nodes_dict[i]["cand_mols"] = list(
                self.nodes_dict[i]["cand_mols"]
            )
        else:
            self.nodes_dict[i]["cands"] = []
            self.nodes_dict[i]["cand_mols"] = []

    def recover(self):
        for i in self.nodes_dict:
...
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as DGLF
from dgl import line_graph, mean_nodes

from .chemutils import get_mol

ELEM_LIST = [
    "C",
    "N",
    "O",
    "S",
    "F",
    "Si",
    "P",
    "Cl",
    "Br",
    "Mg",
    "Na",
    "Ca",
    "Fe",
    "Al",
    "I",
    "B",
    "K",
    "Se",
    "Zn",
    "H",
    "Cu",
    "Mn",
    "unknown",
]

ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
MAX_NB = 6
def onek_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]
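# Illustrative sketch (not part of this commit): onek_encoding_unk one-hot
# encodes x against allowable_set and falls back to the last entry for unseen
# values, so the trailing "unknown" in ELEM_LIST catches rare elements. The
# feature width follows directly:
# ATOM_FDIM = 23 (elements) + 6 (degree) + 5 (charge) + 4 (chirality) + 1 = 39.
#
#     onek_encoding_unk("N", ["C", "N", "O"])   # -> [False, True, False]
#     onek_encoding_unk("Xe", ELEM_LIST)        # only the "unknown" slot is True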
def atom_features(atom):
    return torch.Tensor(
        onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
        + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
        + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
        + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
        + [atom.GetIsAromatic()]
    )


def bond_features(bond):
    bt = bond.GetBondType()
    stereo = int(bond.GetStereo())
    fbond = [
        bt == Chem.rdchem.BondType.SINGLE,
        bt == Chem.rdchem.BondType.DOUBLE,
        bt == Chem.rdchem.BondType.TRIPLE,
        bt == Chem.rdchem.BondType.AROMATIC,
        bond.IsInRing(),
    ]
    fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
    return torch.Tensor(fbond + fstereo)
def mol2dgl_single(smiles):
    n_edges = 0

@@ -61,8 +98,11 @@ def mol2dgl_single(smiles):

        bond_x.append(features)
    graph = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
    n_edges += n_bonds
    return (
        graph,
        torch.stack(atom_x),
        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0),
    )
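# Usage sketch (illustrative only; the SMILES string is a made-up example):
# mol2dgl_single turns one molecule into a DGL graph plus stacked atom and bond
# feature tensors, returning an empty tensor when the molecule has no bonds.
#
#     g, atom_feats, bond_feats = mol2dgl_single("C1=CC=CC=C1")
#     assert atom_feats.shape[1] == ATOM_FDIM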
class LoopyBPUpdate(nn.Module):

@@ -73,10 +113,10 @@ class LoopyBPUpdate(nn.Module):

        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, nodes):
        msg_input = nodes.data["msg_input"]
        msg_delta = self.W_h(nodes.data["accum_msg"])
        msg = F.relu(msg_input + msg_delta)
        return {"msg": msg}


class GatherUpdate(nn.Module):

@@ -87,9 +127,9 @@ class GatherUpdate(nn.Module):

        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, nodes):
        m = nodes.data["m"]
        return {
            "h": F.relu(self.W_o(torch.cat([nodes.data["x"], m], 1))),
        }
@@ -121,7 +161,7 @@ class DGLMPN(nn.Module):

        mol_graph = self.run(mol_graph, mol_line_graph)

        # TODO: replace with unbatch or readout
        g_repr = mean_nodes(mol_graph, "h")

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes

@@ -134,32 +174,38 @@ class DGLMPN(nn.Module):

        n_nodes = mol_graph.number_of_nodes()

        mol_graph.apply_edges(
            func=lambda edges: {"src_x": edges.src["x"]},
        )

        mol_line_graph.ndata.update(mol_graph.edata)

        e_repr = mol_line_graph.ndata
        bond_features = e_repr["x"]
        source_features = e_repr["src_x"]

        features = torch.cat([source_features, bond_features], 1)
        msg_input = self.W_i(features)
        mol_line_graph.ndata.update(
            {
                "msg_input": msg_input,
                "msg": F.relu(msg_input),
                "accum_msg": torch.zeros_like(msg_input),
            }
        )
        mol_graph.ndata.update(
            {
                "m": bond_features.new(n_nodes, self.hidden_size).zero_(),
                "h": bond_features.new(n_nodes, self.hidden_size).zero_(),
            }
        )

        for i in range(self.depth - 1):
            mol_line_graph.update_all(
                DGLF.copy_u("msg", "msg"), DGLF.sum("msg", "accum_msg")
            )
            mol_line_graph.apply_nodes(self.loopy_bp_updater)

        mol_graph.edata.update(mol_line_graph.ndata)

        mol_graph.update_all(DGLF.copy_e("msg", "msg"), DGLF.sum("msg", "m"))
        mol_graph.apply_nodes(self.gather_updater)

        return mol_graph
import os

import torch
import torch.nn as nn

import dgl


def cuda(x):
    if torch.cuda.is_available() and not os.getenv("NOCUDA", None):
        return x.to(torch.device("cuda"))  # works for both DGLGraph and tensor
    else:
        return x
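# Usage sketch (illustrative, not part of the commit): `cuda` moves either a
# tensor or a DGLGraph onto the GPU when one is available, and the NOCUDA
# environment variable forces everything to stay on CPU, which is handy for
# debugging.
#
#     os.environ["NOCUDA"] = "1"       # keep everything on CPU
#     h = cuda(torch.zeros(4, 16))
#     g = cuda(dgl.rand_graph(4, 8))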
@@ -22,27 +24,28 @@ class GRUUpdate(nn.Module):

        self.W_h = nn.Linear(2 * hidden_size, hidden_size)

    def update_zm(self, node):
        src_x = node.data["src_x"]
        s = node.data["s"]
        rm = node.data["accum_rm"]
        z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
        m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
        m = (1 - z) * s + z * m
        return {"m": m, "z": z}

    def update_r(self, node, zm=None):
        dst_x = node.data["dst_x"]
        m = node.data["m"] if zm is None else zm["m"]
        r_1 = self.W_r(dst_x)
        r_2 = self.U_r(m)
        r = torch.sigmoid(r_1 + r_2)
        return {"r": r, "rm": r * m}

    def forward(self, node):
        dic = self.update_zm(node)
        dic.update(self.update_r(node, zm=dic))
        return dic
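    # Reading of the update above (descriptive comment, not part of the commit):
    # per edge, z = sigmoid(W_z [src_x, s]) gates between the accumulated child
    # state s and the candidate tanh(W_h [src_x, accum_rm]), while update_r
    # computes r = sigmoid(W_r dst_x + U_r m) and the gated message rm = r * m
    # that the next hop accumulates, i.e. a GRU-style tree message function.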
def tocpu(g):
    src, dst = g.edges()
    src = src.cpu()
...
import math
import random
import sys
from collections import deque
from optparse import OptionParser

import rdkit
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import tqdm
from jtnn import *
from torch.utils.data import DataLoader

torch.multiprocessing.set_sharing_strategy("file_system")
def worker_init_fn(id_):
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)


worker_init_fn(None)

parser = OptionParser()
parser.add_option(
    "-t", "--train", dest="train", default="train", help="Training file name"
)
parser.add_option(
    "-v", "--vocab", dest="vocab", default="vocab", help="Vocab file name"
)
parser.add_option("-s", "--save_dir", dest="save_path")
parser.add_option("-m", "--model", dest="model_path", default=None)
parser.add_option("-b", "--batch", dest="batch_size", default=40)
@@ -31,7 +39,7 @@ parser.add_option("-d", "--depth", dest="depth", default=3)

parser.add_option("-z", "--beta", dest="beta", default=1.0)
parser.add_option("-q", "--lr", dest="lr", default=1e-3)
parser.add_option("-T", "--test", dest="test", action="store_true")
opts, args = parser.parse_args()

dataset = JTNNDataset(data=opts.train, vocab=opts.vocab, training=True)
vocab = dataset.vocab
@@ -55,7 +63,10 @@ else:

        nn.init.xavier_normal(param)

model = cuda(model)
print(
    "Model #Params: %dK"
    % (sum([x.nelement() for x in model.parameters()]) / 1000,)
)

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
@@ -64,6 +75,7 @@ scheduler.step()

MAX_EPOCH = 100
PRINT_ITER = 20


def train():
    dataset.training = True
    dataloader = DataLoader(

@@ -73,17 +85,18 @@ def train():

        num_workers=4,
        collate_fn=JTNNCollator(vocab, True),
        drop_last=True,
        worker_init_fn=worker_init_fn,
    )

    for epoch in range(MAX_EPOCH):
        word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0

        for it, batch in enumerate(tqdm.tqdm(dataloader)):
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
            except:
                print([t.smiles for t in batch["mol_trees"]])
                raise
            loss.backward()
            optimizer.step()
@@ -99,20 +112,34 @@ def train():

                assm_acc = assm_acc / PRINT_ITER * 100
                steo_acc = steo_acc / PRINT_ITER * 100

                print(
                    "KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f"
                    % (
                        kl_div,
                        word_acc,
                        topo_acc,
                        assm_acc,
                        steo_acc,
                        loss.item(),
                    )
                )
                word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
                sys.stdout.flush()

            if (it + 1) % 1500 == 0:  # Fast annealing
                scheduler.step()
                print("learning rate: %.6f" % scheduler.get_lr()[0])
                torch.save(
                    model.state_dict(),
                    opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1),
                )

        scheduler.step()
        print("learning rate: %.6f" % scheduler.get_lr()[0])
        torch.save(
            model.state_dict(), opts.save_path + "/model.iter-" + str(epoch)
        )


def test():
    dataset.training = False
@@ -123,12 +150,13 @@ def test():

        num_workers=0,
        collate_fn=JTNNCollator(vocab, False),
        drop_last=True,
        worker_init_fn=worker_init_fn,
    )

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    for it, batch in enumerate(dataloader):
        gt_smiles = batch["mol_trees"][0].smiles
        print(gt_smiles)
        model.move_to_cuda(batch)
        _, tree_vec, mol_vec = model.encode(batch)

@@ -136,21 +164,28 @@ def test():

        smiles = model.decode(tree_vec, mol_vec)
        print(smiles)

if __name__ == "__main__":
    if opts.test:
        test()
    else:
        train()
    print("# passes:", model.n_passes)
    print("Total # nodes processed:", model.n_nodes_total)
    print("Total # edges processed:", model.n_edges_total)
    print("Total # tree nodes processed:", model.n_tree_nodes_total)
    print("Graph decoder: # passes:", model.jtmpn.n_passes)
    print(
        "Graph decoder: Total # candidates processed:",
        model.jtmpn.n_samples_total,
    )
    print("Graph decoder: Total # nodes processed:", model.jtmpn.n_nodes_total)
    print("Graph decoder: Total # edges processed:", model.jtmpn.n_edges_total)
    print("Graph encoder: # passes:", model.mpn.n_passes)
    print(
        "Graph encoder: Total # candidates processed:",
        model.mpn.n_samples_total,
    )
    print("Graph encoder: Total # nodes processed:", model.mpn.n_nodes_total)
    print("Graph encoder: Total # edges processed:", model.mpn.n_edges_total)
import argparse

import torch

import dgl
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.nn import LabelPropagation


def main():
    # check cuda
    device = (
        f"cuda:{args.gpu}"
        if torch.cuda.is_available() and args.gpu >= 0
        else "cpu"
    )
    # load data
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataset == "Pubmed":
        dataset = PubmedGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    g = dataset[0]
    g = dgl.add_self_loop(g)

    labels = g.ndata.pop("label").to(device).long()

    # load masks for train / test, valid is not used.
    train_mask = g.ndata.pop("train_mask")
    test_mask = g.ndata.pop("test_mask")

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)
@@ -37,19 +43,21 @@ def main():

    lp = LabelPropagation(args.num_layers, args.alpha)
    logits = lp(g, labels, mask=train_idx)

    test_acc = torch.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)
    print("Test Acc {:.4f}".format(test_acc))


if __name__ == "__main__":
    """
    Label Propagation Hyperparameters
    """
    parser = argparse.ArgumentParser(description="LP")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--dataset", type=str, default="Cora")
    parser.add_argument("--num-layers", type=int, default=10)
    parser.add_argument("--alpha", type=float, default=0.5)
    args = parser.parse_args()
    print(args)
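    # Illustrative note (not part of the commit): LabelPropagation iteratively
    # smooths the one-hot training labels over the graph; a common form of the
    # update it implements is
    #     Y_{t+1} = alpha * D^{-1/2} A D^{-1/2} Y_t + (1 - alpha) * Y_0
    # repeated num_layers times, after which main() scores the test nodes as
    # shown above.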
...
@@ -17,49 +17,49 @@

# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from time import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as ss
import torch
from lda_model import LatentDirichletAllocation as LDAModel
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

import dgl
from dgl import function as fn

n_samples = 2000
n_features = 1000
n_components = 10
n_top_words = 20
device = "cuda"


def plot_top_words(model, feature_names, n_top_words, title):
    fig, axes = plt.subplots(2, 5, figsize=(30, 15), sharex=True)
    axes = axes.flatten()
    for topic_idx, topic in enumerate(model.components_):
        top_features_ind = topic.argsort()[: -n_top_words - 1 : -1]
        top_features = [feature_names[i] for i in top_features_ind]
        weights = topic[top_features_ind]

        ax = axes[topic_idx]
        ax.barh(top_features, weights, height=0.7)
        ax.set_title(f"Topic {topic_idx +1}", fontdict={"fontsize": 30})
        ax.invert_yaxis()
        ax.tick_params(axis="both", which="major", labelsize=20)
        for i in "top right left".split():
            ax.spines[i].set_visible(False)
        fig.suptitle(title, fontsize=40)

    plt.subplots_adjust(top=0.90, bottom=0.05, wspace=0.90, hspace=0.3)
    plt.show()


# Load the 20 newsgroups dataset and vectorize it. We use a few heuristics
# to filter out useless terms early on: the posts are stripped of headers,
# footers and quoted replies, and common English words, words occurring in
@@ -67,43 +67,50 @@ def plot_top_words(model, feature_names, n_top_words, title):

print("Loading dataset...")
t0 = time()
data, _ = fetch_20newsgroups(
    shuffle=True,
    random_state=1,
    remove=("headers", "footers", "quotes"),
    return_X_y=True,
)
data_samples = data[:n_samples]
data_test = data[n_samples : 2 * n_samples]
print("done in %0.3fs." % (time() - t0))

# Use tf (raw term count) features for LDA.
print("Extracting tf features for LDA...")
tf_vectorizer = CountVectorizer(
    max_df=0.95, min_df=2, max_features=n_features, stop_words="english"
)
t0 = time()
tf_vectorizer.fit(data)
tf = tf_vectorizer.transform(data_samples)
tt = tf_vectorizer.transform(data_test)
tf_feature_names = tf_vectorizer.get_feature_names()
tf_uv = [
    (u, v)
    for u, v, e in zip(tf.tocoo().row, tf.tocoo().col, tf.tocoo().data)
    for _ in range(e)
]
tt_uv = [
    (u, v)
    for u, v, e in zip(tt.tocoo().row, tt.tocoo().col, tt.tocoo().data)
    for _ in range(e)
]
print("done in %0.3fs." % (time() - t0))
print()

print("Preparing dgl graphs...")
t0 = time()
G = dgl.heterograph({("doc", "topic", "word"): tf_uv}, device=device)
Gt = dgl.heterograph({("doc", "topic", "word"): tt_uv}, device=device)
print("done in %0.3fs." % (time() - t0))
print()

print("Training dgl-lda model...")
t0 = time()
model = LDAModel(G.num_nodes("word"), n_components)
model.fit(G)
print("done in %0.3fs." % (time() - t0))
print()
@@ -113,20 +120,27 @@ print(f"dgl-lda testing perplexity {model.perplexity(Gt):.3f}")

word_nphi = np.vstack([nphi.tolist() for nphi in model.word_data.nphi])
plot_top_words(
    type("dummy", (object,), {"components_": word_nphi}),
    tf_feature_names,
    n_top_words,
    "Topics in LDA model",
)

print("Training scikit-learn model...")
print(
    "\n" * 2,
    "Fitting LDA models with tf features, "
    "n_samples=%d and n_features=%d..." % (n_samples, n_features),
)
lda = LatentDirichletAllocation(
    n_components=n_components,
    max_iter=5,
    learning_method="online",
    learning_offset=50.0,
    random_state=0,
    verbose=1,
)
t0 = time()
lda.fit(tf)
print("done in %0.3fs." % (time() - t0))
...
@@ -17,8 +17,17 @@

# limitations under the License.

import collections
import functools
import io
import os
import warnings

import numpy as np
import scipy as sp
import torch

import dgl

try:
    from functools import cached_property
@@ -37,17 +46,21 @@ class EdgeData:

    @property
    def loglike(self):
        return (self.src_data["Elog"] + self.dst_data["Elog"]).logsumexp(1)

    @property
    def phi(self):
        return (
            self.src_data["Elog"]
            + self.dst_data["Elog"]
            - self.loglike.unsqueeze(1)
        ).exp()

    @property
    def expectation(self):
        return (
            self.src_data["expectation"] * self.dst_data["expectation"]
        ).sum(1)


class _Dirichlet:
@@ -55,10 +68,13 @@ class _Dirichlet:

        self.prior = prior
        self.nphi = nphi
        self.device = nphi.device
        self._sum_by_parts = lambda map_fn: functools.reduce(
            torch.add,
            [
                map_fn(slice(i, min(i + _chunksize, nphi.shape[1]))).sum(1)
                for i in list(range(0, nphi.shape[1], _chunksize))
            ],
        )

    def _posterior(self, _ID=slice(None)):
        return self.prior + self.nphi[:, _ID]
@@ -68,8 +84,9 @@ class _Dirichlet:

        return self.nphi.sum(1) + self.prior * self.nphi.shape[1]

    def _Elog(self, _ID=slice(None)):
        return torch.digamma(self._posterior(_ID)) - torch.digamma(
            self.posterior_sum.unsqueeze(1)
        )

    @cached_property
    def loglike(self):
@@ -105,9 +122,15 @@ class _Dirichlet:

    @cached_property
    def Bayesian_gap(self):
        return 1.0 - self._sum_by_parts(lambda s: self._Elog(s).exp())

    _cached_properties = [
        "posterior_sum",
        "loglike",
        "n",
        "cdf",
        "Bayesian_gap",
    ]

    def clear_cache(self):
        for name in self._cached_properties:
            pass

    def update(self, new, _ID=slice(None), rho=1):
        """inplace: old * (1-rho) + new * rho"""
        self.clear_cache()
        mean_change = (self.nphi[:, _ID] - new).abs().mean().tolist()
        self.nphi *= 1 - rho
        self.nphi[:, _ID] += new * rho
        return mean_change


class DocData(_Dirichlet):
    """nphi (n_docs by n_topics)"""

    def prepare_graph(self, G, key="Elog"):
        G.nodes["doc"].data[key] = getattr(self, "_" + key)().to(G.device)

    def update_from(self, G, mult):
        new = G.nodes["doc"].data["nphi"] * mult
        return self.update(new.to(self.device))


class _Distributed(collections.UserList):
    """split on dim=0 and store on multiple devices"""

    def __init__(self, prior, nphi):
        self.prior = prior
        self.nphi = nphi
def split_device(self, other, dim=0): def split_device(self, other, dim=0):
split_sections = [x.shape[0] for x in self.nphi] split_sections = [x.shape[0] for x in self.nphi]
out = torch.split(other, split_sections, dim) out = torch.split(other, split_sections, dim)
return [y.to(x.device) for x,y in zip(self.nphi, out)] return [y.to(x.device) for x, y in zip(self.nphi, out)]
class WordData(_Distributed): class WordData(_Distributed):
""" distributed nphi (n_topics by n_words), transpose to/from graph nodes data """ """distributed nphi (n_topics by n_words), transpose to/from graph nodes data"""
def prepare_graph(self, G, key="Elog"): def prepare_graph(self, G, key="Elog"):
if '_ID' in G.nodes['word'].data: if "_ID" in G.nodes["word"].data:
_ID = G.nodes['word'].data['_ID'] _ID = G.nodes["word"].data["_ID"]
else: else:
_ID = slice(None) _ID = slice(None)
out = [getattr(part, '_'+key)(_ID).to(G.device) for part in self] out = [getattr(part, "_" + key)(_ID).to(G.device) for part in self]
G.nodes['word'].data[key] = torch.cat(out).T G.nodes["word"].data[key] = torch.cat(out).T
def update_from(self, G, mult, rho): def update_from(self, G, mult, rho):
nphi = G.nodes['word'].data['nphi'].T * mult nphi = G.nodes["word"].data["nphi"].T * mult
if '_ID' in G.nodes['word'].data: if "_ID" in G.nodes["word"].data:
_ID = G.nodes['word'].data['_ID'] _ID = G.nodes["word"].data["_ID"]
else: else:
_ID = slice(None) _ID = slice(None)
mean_change = [x.update(y, _ID, rho) mean_change = [
for x, y in zip(self, self.split_device(nphi))] x.update(y, _ID, rho) for x, y in zip(self, self.split_device(nphi))
]
return np.mean(mean_change) return np.mean(mean_change)
class Gamma(collections.namedtuple('Gamma', "concentration, rate")): class Gamma(collections.namedtuple("Gamma", "concentration, rate")):
""" articulate the difference between torch gamma and numpy gamma """ """articulate the difference between torch gamma and numpy gamma"""
@property @property
def shape(self): def shape(self):
return self.concentration return self.concentration
@@ -218,20 +245,23 @@ class LatentDirichletAllocation:

    (NIPS 2010).
    [2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
    """

    def __init__(
        self,
        n_words,
        n_components,
        prior=None,
        rho=1,
        mult={"doc": 1, "word": 1},
        init={"doc": (100.0, 100.0), "word": (100.0, 100.0)},
        device_list=["cpu"],
        verbose=True,
    ):
        self.n_words = n_words
        self.n_components = n_components

        if prior is None:
            prior = {"doc": 1.0 / n_components, "word": 1.0 / n_components}
        self.prior = prior

        self.rho = rho
@@ -244,46 +274,46 @@ class LatentDirichletAllocation:

        self._init_word_data()

    def _init_word_data(self):
        split_sections = np.diff(
            np.linspace(0, self.n_components, len(self.device_list) + 1).astype(
                int
            )
        )
        word_nphi = [
            Gamma(*self.init["word"]).sample((s, self.n_words), device)
            for s, device in zip(split_sections, self.device_list)
        ]
        self.word_data = WordData(self.prior["word"], word_nphi)

    def _init_doc_data(self, n_docs, device):
        doc_nphi = Gamma(*self.init["doc"]).sample(
            (n_docs, self.n_components), device
        )
        return DocData(self.prior["doc"], doc_nphi)

    def save(self, f):
        for w in self.word_data:
            w.clear_cache()
        torch.save(
            {
                "prior": self.prior,
                "rho": self.rho,
                "mult": self.mult,
                "init": self.init,
                "word_data": [part.nphi for part in self.word_data],
            },
            f,
        )

    def _prepare_graph(self, G, doc_data, key="Elog"):
        doc_data.prepare_graph(G, key)
        self.word_data.prepare_graph(G, key)

    def _e_step(self, G, doc_data=None, mean_change_tol=1e-3, max_iters=100):
        """_e_step implements doc data sampling until convergence or max_iters"""
        if doc_data is None:
            doc_data = self._init_doc_data(G.num_nodes("doc"), G.device)

        G_rev = G.reverse()  # word -> doc
        self.word_data.prepare_graph(G_rev)
@@ -291,65 +321,76 @@ class LatentDirichletAllocation:

        for i in range(max_iters):
            doc_data.prepare_graph(G_rev)
            G_rev.update_all(
                lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
                dgl.function.sum("phi", "nphi"),
            )
            mean_change = doc_data.update_from(G_rev, self.mult["doc"])
            if mean_change < mean_change_tol:
                break

        if self.verbose:
            print(
                f"e-step num_iters={i+1} with mean_change={mean_change:.4f}, "
                f"perplexity={self.perplexity(G, doc_data):.4f}"
            )

        return doc_data

    transform = _e_step

    def predict(self, doc_data):
        pred_scores = [
            # d_exp @ w._expectation()
            (lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)(
                d_exp / w.posterior_sum.unsqueeze(0)
            )
            for (d_exp, w) in zip(
                self.word_data.split_device(doc_data._expectation(), dim=1),
                self.word_data,
            )
        ]

        x = torch.zeros_like(pred_scores[0], device=doc_data.device)
        for p in pred_scores:
            x += p.to(x.device)
        return x

    def sample(self, doc_data, num_samples):
        """draw independent words and return the marginal probabilities,
        i.e., the expectations in Dirichlet distributions.
        """

        def fn(cdf):
            u = torch.rand(cdf.shape[0], num_samples, device=cdf.device)
            return torch.searchsorted(cdf, u).to(doc_data.device)

        topic_ids = fn(doc_data.cdf)
        word_ids = torch.cat([fn(part.cdf) for part in self.word_data])
        ids = torch.gather(
            word_ids, 0, topic_ids
        )  # pick components by topic_ids

        # compute expectation scores on sampled ids
        src_ids = (
            torch.arange(ids.shape[0], dtype=ids.dtype, device=ids.device)
            .reshape((-1, 1))
            .expand(ids.shape)
        )
        unique_ids, inverse_ids = torch.unique(
            ids, sorted=False, return_inverse=True
        )

        G = dgl.heterograph(
            {("doc", "", "word"): (src_ids.ravel(), inverse_ids.ravel())}
        )
        G.nodes["word"].data["_ID"] = unique_ids
        self._prepare_graph(G, doc_data, "expectation")
        G.apply_edges(
            lambda e: {"expectation": EdgeData(e.src, e.dst).expectation}
        )
        expectation = G.edata.pop("expectation").reshape(ids.shape)

        return ids, expectation

    def _m_step(self, G, doc_data):
        """_m_step implements word data sampling and stores word_z stats.
        mean_change is in the sense of full graph with rho=1.
@@ -357,26 +398,25 @@ class LatentDirichletAllocation:

        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.update_all(
            lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
            dgl.function.sum("phi", "nphi"),
        )
        self._last_mean_change = self.word_data.update_from(
            G, self.mult["word"], self.rho
        )
        if self.verbose:
            print(f"m-step mean_change={self._last_mean_change:.4f}, ", end="")
            Bayesian_gap = np.mean(
                [part.Bayesian_gap.mean().tolist() for part in self.word_data]
            )
            print(f"Bayesian_gap={Bayesian_gap:.4f}")

    def partial_fit(self, G):
        doc_data = self._e_step(G)
        self._m_step(G, doc_data)
        return self

    def fit(self, G, mean_change_tol=1e-3, max_epochs=10):
        for i in range(max_epochs):
            if self.verbose:

@@ -387,7 +427,6 @@ class LatentDirichletAllocation:

                break
        return self

    def perplexity(self, G, doc_data=None):
        """ppl = exp{-sum[log(p(w1,...,wn|d))] / n}
        Follows Eq (15) in Hoffman et al., 2010.
@@ -398,45 +437,50 @@ class LatentDirichletAllocation:

        # compute E[log p(docs | theta, beta)]
        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.apply_edges(
            lambda edges: {"loglike": EdgeData(edges.src, edges.dst).loglike}
        )
        edge_elbo = (G.edata["loglike"].sum() / G.num_edges()).tolist()
        if self.verbose:
            print(f"neg_elbo phi: {-edge_elbo:.3f}", end=" ")

        # compute E[log p(theta | alpha) - log q(theta | gamma)]
        doc_elbo = (doc_data.loglike.sum() / doc_data.n.sum()).tolist()
        if self.verbose:
            print(f"theta: {-doc_elbo:.3f}", end=" ")

        # compute E[log p(beta | eta) - log q(beta | lambda)]
        # The denominator n for extrapolation perplexity is undefined.
        # We use the train set, whereas sklearn uses the test set.
        word_elbo = sum(
            [part.loglike.sum().tolist() for part in self.word_data]
        ) / sum([part.n.sum().tolist() for part in self.word_data])
        if self.verbose:
            print(f"beta: {-word_elbo:.3f}")

        ppl = np.exp(-edge_elbo - doc_elbo - word_elbo)
        if G.num_edges() > 0 and np.isnan(ppl):
            warnings.warn("numerical issue in perplexity")

        return ppl
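    # Note (descriptive, not part of the commit): the perplexity above is
    # exp(-(edge_elbo + doc_elbo + word_elbo)), i.e. the three per-token ELBO
    # terms printed as "phi", "theta" and "beta" in verbose mode, following
    # Eq (15) of Hoffman et al., 2010.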

def doc_subgraph(G, doc_ids):
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    _, _, (block,) = sampler.sample(
        G.reverse(), {"doc": torch.as_tensor(doc_ids)}
    )
    B = dgl.DGLHeteroGraph(
        block._graph, ["_", "word", "doc", "_"], block.etypes
    ).reverse()
    B.nodes["word"].data["_ID"] = block.nodes["word"].data["_ID"]
    return B


if __name__ == "__main__":
    print("Testing LatentDirichletAllocation ...")
    G = dgl.heterograph(
        {("doc", "", "word"): [(0, 0), (1, 3)]}, {"doc": 2, "word": 5}
    )
    model = LatentDirichletAllocation(n_words=5, n_components=10, verbose=False)
    model.fit(G)
    model.transform(G)
@@ -454,4 +498,4 @@ if __name__ == '__main__':

    f.seek(0)
    print(torch.load(f))

    print("Testing LatentDirichletAllocation passed!")
@@ -6,11 +6,12 @@ Author's implementation: https://github.com/joanbruna/GNN_community

"""
from __future__ import division

import argparse
import time
from itertools import permutations

import gnn
import numpy as np
import torch as th
import torch.nn.functional as F
@@ -18,37 +19,51 @@ import torch.optim as optim

from torch.utils.data import DataLoader

from dgl.data import SBMMixtureDataset

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, help="Batch size", default=1)
parser.add_argument("--gpu", type=int, help="GPU index", default=-1)
parser.add_argument("--lr", type=float, help="Learning rate", default=0.001)
parser.add_argument(
    "--n-communities", type=int, help="Number of communities", default=2
)
parser.add_argument(
    "--n-epochs", type=int, help="Number of epochs", default=100
)
parser.add_argument(
    "--n-features", type=int, help="Number of features", default=16
)
parser.add_argument("--n-graphs", type=int, help="Number of graphs", default=10)
parser.add_argument("--n-layers", type=int, help="Number of layers", default=30)
parser.add_argument(
    "--n-nodes", type=int, help="Number of nodes", default=10000
)
parser.add_argument("--optim", type=str, help="Optimizer", default="Adam")
parser.add_argument("--radius", type=int, help="Radius", default=3)
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()

dev = th.device("cpu") if args.gpu < 0 else th.device("cuda:%d" % args.gpu)

K = args.n_communities
training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(
    training_dataset,
    args.batch_size,
    collate_fn=training_dataset.collate_fn,
    drop_last=True,
)

ones = th.ones(args.n_nodes // K)
y_list = [
    th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))
]

feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)


def compute_overlap(z_list):
    ybar_list = [th.max(z, 1)[1] for z in z_list]
    overlap_list = []
...@@ -58,15 +73,20 @@ def compute_overlap(z_list): ...@@ -58,15 +73,20 @@ def compute_overlap(z_list):
overlap_list.append(overlap) overlap_list.append(overlap)
return sum(overlap_list) / len(overlap_list) return sum(overlap_list) / len(overlap_list)
def from_np(f, *args):
    def wrap(*args):
        new = [
            th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args
        ]
        return f(*new)

    return wrap
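# from_np is applied below as a decorator: wrap converts any NumPy ndarray
# argument into a torch tensor via th.from_numpy before calling the wrapped
# function, presumably because the dataset's collate_fn yields NumPy arrays,
# so step() and inference() only ever see tensors.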
@from_np
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    """One step of training."""
    g = g.to(dev)
    lg = lg.to(dev)
    deg_g = deg_g.to(dev).unsqueeze(1)
...@@ -77,7 +97,10 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    t_forward = time.time() - t0

    z_list = th.chunk(z, args.batch_size, 0)
    loss = (
        sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list)
        / args.batch_size
    )
    overlap = compute_overlap(z_list)

    optimizer.zero_grad()
...@@ -88,6 +111,7 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    return loss, overlap, t_forward, t_backward
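# Because community labels are only identified up to permutation, the loss above
# takes, for each graph in the batch, the smallest cross-entropy against any of
# the permuted label vectors in y_list and then averages over the batch. A
# minimal illustration (not part of the original script), assuming K = 2 and a
# two-node graph:
#
#     z = th.tensor([[2.0, 0.0], [0.0, 2.0]])      # per-node logits
#     ys = [th.tensor([0, 1]), th.tensor([1, 0])]  # both labelings
#     loss = min(F.cross_entropy(z, y) for y in ys)
#
# so a prediction that merely swaps the two community ids incurs no penalty.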
@from_np
def inference(g, lg, deg_g, deg_lg, pm_pd):
    g = g.to(dev)
...@@ -99,9 +123,11 @@ def inference(g, lg, deg_g, deg_lg, pm_pd):
    z = model(g, lg, deg_g, deg_lg, pm_pd)
    return z
def test():
    p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
    q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
    N = 1
    overlap_list = []
    for p, q in zip(p_list, q_list):
...@@ -112,31 +138,38 @@ def test():
        overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
    return overlap_list
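# Note: the loop body of test() is collapsed in this diff. p_list and q_list
# appear to be the intra- and inter-community connection parameters of the
# generated SBM graphs, swept from a strongly assortative setting (p = 6, q = 0)
# to a strongly disassortative one (p = 0, q = 6), so the returned overlap_list
# traces recovery quality across the two regimes.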
n_iterations = args.n_graphs // args.batch_size

for i in range(args.n_epochs):
    total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
    for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
        loss, overlap, t_forward, t_backward = step(
            i, j, g, lg, deg_g, deg_lg, pm_pd
        )
        total_loss += loss
        total_overlap += overlap
        s_forward += t_forward
        s_backward += t_backward

        epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
        iteration = "0" * (len(str(n_iterations)) - len(str(j)))
        if args.verbose:
            print(
                "[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f"
                % (epoch, i, iteration, j, loss, overlap)
            )

    epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
    loss = total_loss / (j + 1)
    overlap = total_overlap / (j + 1)
    t_forward = s_forward / (j + 1)
    t_backward = s_backward / (j + 1)
    print(
        "[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs"
        % (epoch, i, loss, overlap, t_forward, t_backward)
    )

    overlap_list = test()
    overlap_str = " - ".join(["%.3f" % overlap for overlap in overlap_list])
    print("[epoch %s%d]overlap: %s" % (epoch, i, overlap_str))
...@@ -2,16 +2,18 @@
import argparse
import copy
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange

import dgl
import dgl.function as fn
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
class MixHopConv(nn.Module):
    r"""
...@@ -44,13 +46,16 @@ class MixHopConv(nn.Module):
    batchnorm: bool, optional
        If True, use batch normalization. Defaults: ``False``.
    """
    def __init__(
        self,
        in_dim,
        out_dim,
        p=[0, 1, 2],
        dropout=0,
        activation=None,
        batchnorm=False,
    ):
        super(MixHopConv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
...@@ -66,9 +71,9 @@ class MixHopConv(nn.Module):
            self.bn = nn.BatchNorm1d(out_dim * len(p))

        # define weight dict for each power j
        self.weights = nn.ModuleDict(
            {str(j): nn.Linear(in_dim, out_dim, bias=False) for j in p}
        )
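        # One Linear map is kept per adjacency power j in p, so with the default
        # p = [0, 1, 2] the layer learns three weight matrices whose outputs are
        # concatenated in forward(); this is also why the BatchNorm above is
        # sized out_dim * len(p).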
    def forward(self, graph, feats):
        with graph.local_scope():
...@@ -84,9 +89,9 @@ class MixHopConv(nn.Module):
                    outputs.append(output)

                feats = feats * norm
                graph.ndata["h"] = feats
                graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                feats = graph.ndata.pop("h")
                feats = feats * norm

            final = torch.cat(outputs, dim=1)
...@@ -101,8 +106,10 @@ class MixHopConv(nn.Module):
            return final
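    # Sketch of the propagation above: `norm` (defined in the collapsed lines,
    # presumably the D^{-1/2} degree factors) is applied before and after the
    # copy_u/sum message passing, so each pass multiplies the features by the
    # symmetrically normalized adjacency. After j passes feats is (roughly)
    # \hat{A}^j X; the matching weight W_j is applied whenever j is in p, and the
    # per-power outputs are concatenated into `final`.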
class MixHop(nn.Module):
    def __init__(
        self,
        in_dim,
        hid_dim,
        out_dim,
...@@ -111,7 +118,8 @@ class MixHop(nn.Module):
        input_dropout=0.0,
        layer_dropout=0.0,
        activation=None,
        batchnorm=False,
    ):
        super(MixHop, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
...@@ -127,23 +135,33 @@ class MixHop(nn.Module):
        self.dropout = nn.Dropout(self.input_dropout)

        # Input layer
        self.layers.append(
            MixHopConv(
                self.in_dim,
                self.hid_dim,
                p=self.p,
                dropout=self.input_dropout,
                activation=self.activation,
                batchnorm=self.batchnorm,
            )
        )

        # Hidden layers with n - 1 MixHopConv layers
        for i in range(self.num_layers - 2):
            self.layers.append(
                MixHopConv(
                    self.hid_dim * len(args.p),
                    self.hid_dim,
                    p=self.p,
                    dropout=self.layer_dropout,
                    activation=self.activation,
                    batchnorm=self.batchnorm,
                )
            )

        self.fc_layers = nn.Linear(
            self.hid_dim * len(args.p), self.out_dim, bias=False
        )
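        # Each MixHopConv emits hid_dim * len(p) features (one hid_dim block per
        # power), which is why the hidden layers and the final fc_layers
        # classifier take hid_dim * len(args.p) as their input width. Note these
        # sizes read the global `args.p` rather than `self.p`, so they only match
        # when the model is constructed with p = args.p, as main() is expected to
        # do.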
    def forward(self, graph, feats):
        feats = self.dropout(feats)
...@@ -154,41 +172,42 @@ class MixHop(nn.Module):
        return feats
def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataset == "Pubmed":
        dataset = PubmedGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    graph = dataset[0]
    graph = dgl.add_self_loop(graph)

    # check cuda
    if args.gpu >= 0 and torch.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
...@@ -197,7 +216,8 @@ def main(args):
    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = MixHop(
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_layers=args.num_layers,
...@@ -205,7 +225,8 @@ def main(args):
        input_dropout=args.input_dropout,
        layer_dropout=args.layer_dropout,
        activation=torch.tanh,
        batchnorm=True,
    )

    model = model.to(device)
    best_model = copy.deepcopy(model)
...@@ -218,7 +239,7 @@ def main(args):
    # Step 4: training epochs =============================================================== #
    acc = 0
    no_improvement = 0
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        # Training using a full graph
...@@ -228,7 +249,9 @@ def main(args):
        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])
        train_acc = torch.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)

        # backward
        opt.zero_grad()
...@@ -240,16 +263,21 @@ def main(args):
        with torch.no_grad():
            valid_loss = loss_fn(logits[val_idx], labels[val_idx])
            valid_acc = torch.sum(
                logits[val_idx].argmax(dim=1) == labels[val_idx]
            ).item() / len(val_idx)

            # Print out performance
            epochs.set_description(
                "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                    train_acc, train_loss.item(), valid_acc, valid_loss.item()
                )
            )

            if valid_acc < acc:
                no_improvement += 1
                if no_improvement == args.early_stopping:
                    print("Early stop.")
                    break
            else:
                no_improvement = 0
...@@ -260,34 +288,74 @@ def main(args):
    best_model.eval()
    logits = best_model(graph, feats)
    test_acc = torch.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)

    print("Test Acc {:.4f}".format(test_acc))

    return test_acc
if __name__ == "__main__":
    """
    MixHop Model Hyperparameters
    """
    parser = argparse.ArgumentParser(description="MixHop GCN")

    # data source params
    parser.add_argument(
        "--dataset", type=str, default="Cora", help="Name of dataset."
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument(
        "--epochs", type=int, default=2000, help="Training epochs."
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=200,
        help="Patience (epochs to wait) before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.5, help="Learning rate.")
    parser.add_argument("--lamb", type=float, default=5e-4, help="L2 reg.")
    parser.add_argument(
        "--step-size",
        type=int,
        default=40,
        help="Period of learning rate decay.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.01,
        help="Multiplicative factor of learning rate decay.",
    )
    # model params
    parser.add_argument(
        "--hid-dim", type=int, default=60, help="Hidden layer dimensionalities."
    )
    parser.add_argument(
        "--num-layers", type=int, default=4, help="Number of GNN layers."
    )
    parser.add_argument(
        "--input-dropout",
        type=float,
        default=0.7,
        help="Dropout applied at input layer.",
    )
    parser.add_argument(
        "--layer-dropout",
        type=float,
        default=0.9,
        help="Dropout applied at hidden layers.",
    )
    parser.add_argument(
        "--p", nargs="+", type=int, help="List of powers of adjacency matrix."
    )

    parser.set_defaults(p=[0, 1, 2])
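    # With the defaults above, a typical invocation would look something like
    # the following (the script name is assumed, it is not shown in this diff):
    #
    #     python mixhop.py --dataset Cora --gpu 0 --num-layers 4 --p 0 1 2
    #
    # where --p takes a space-separated list of adjacency powers and falls back
    # to [0, 1, 2] via set_defaults when omitted.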
...@@ -304,7 +372,7 @@ if __name__ == "__main__":
    mean = np.around(np.mean(acc_lists_top, axis=0), decimals=3)
    std = np.around(np.std(acc_lists_top, axis=0), decimals=3)

    print("Total acc: ", acc_lists)
    print("Top 50 acc:", acc_lists_top)
    print("mean", mean)
    print("std", std)
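    # The collapsed lines above apparently run main() repeatedly and collect the
    # resulting test accuracies; acc_lists_top keeps the top-50 runs, and the
    # printed mean/std summarize those accuracies to three decimals.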