"vscode:/vscode.git/clone" did not exist on "89655cfda245911c1c834013d9befc5571565349"
Unverified Commit 0b9df9d7 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
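
For reference, the changes below are mechanical reformatting (quote normalization, import sorting, and line wrapping) rather than behavioral edits. The following is a minimal sketch of how such a pass could be reproduced with Black's Python API; it assumes the `black` package is installed, and the 79-column line length is only a guess since the repository's configured value is not shown in this diff.

# Illustrative only; not part of the commit. Assumes `pip install black`.
import black

SRC = """src = np.zeros((len(edges) * 2,), dtype='int')
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
"""

# black.format_str parses the snippet and re-emits it in Black style:
# single quotes become double quotes, and calls longer than the line
# length are wrapped across lines, matching the pattern in the hunks below.
formatted = black.format_str(SRC, mode=black.Mode(line_length=79))
print(formatted)
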
import dgl
import rdkit.Chem as Chem
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
import numpy as np
import rdkit.Chem as Chem
import dgl
from .chemutils import (
decode_stereo,
enum_assemble_nx,
get_clique_mol,
get_mol,
get_smiles,
set_atommap,
tree_decomp,
)
class DGLMolTree(object):
def __init__(self, smiles):
......@@ -38,11 +48,13 @@ class DGLMolTree(object):
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes_dict[0]:
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
self.nodes_dict[root][attr], self.nodes_dict[0][attr]
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = (
self.nodes_dict[root][attr],
self.nodes_dict[0][attr],
)
src = np.zeros((len(edges) * 2,), dtype='int')
dst = np.zeros((len(edges) * 2,), dtype='int')
src = np.zeros((len(edges) * 2,), dtype="int")
dst = np.zeros((len(edges) * 2,), dtype="int")
for i, (_x, _y) in enumerate(edges):
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
......@@ -53,10 +65,12 @@ class DGLMolTree(object):
self.graph = dgl.graph((src, dst), num_nodes=len(cliques))
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
self.nodes_dict[i]["nid"] = i + 1
if self.graph.out_degrees(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.graph.out_degrees(i) == 1)
set_atommap(
self.nodes_dict[i]["mol"], self.nodes_dict[i]["nid"]
)
self.nodes_dict[i]["is_leaf"] = self.graph.out_degrees(i) == 1
def treesize(self):
return self.graph.number_of_nodes()
......@@ -65,49 +79,65 @@ class DGLMolTree(object):
node = self.nodes_dict[i]
clique = []
clique.extend(node['clique'])
if not node['is_leaf']:
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
clique.extend(node["clique"])
if not node["is_leaf"]:
for cidx in node["clique"]:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node["nid"])
for j in self.graph.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
clique.extend(nei_node["clique"])
if nei_node["is_leaf"]: # Leaf node, no need to mark
continue
for cidx in nei_node['clique']:
for cidx in nei_node["clique"]:
# allow singleton node override the atom mapping
if cidx not in node['clique'] or len(nei_node['clique']) == 1:
if cidx not in node["clique"] or len(nei_node["clique"]) == 1:
atom = original_mol.GetAtomWithIdx(cidx)
atom.SetAtomMapNum(nei_node['nid'])
atom.SetAtomMapNum(nei_node["nid"])
clique = list(set(clique))
label_mol = get_clique_mol(original_mol, clique)
node['label'] = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
node['label_mol'] = get_mol(node['label'])
node["label"] = Chem.MolToSmiles(
Chem.MolFromSmiles(get_smiles(label_mol))
)
node["label_mol"] = get_mol(node["label"])
for cidx in clique:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
return node['label']
return node["label"]
def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = [
self.nodes_dict[j]
for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]["mol"].GetNumAtoms() > 1
]
neighbors = sorted(
neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
)
singletons = [
self.nodes_dict[j]
for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]["mol"].GetNumAtoms() == 1
]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self.nodes_dict[i], neighbors)
if len(cands) > 0:
self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(zip(*cands))
self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
self.nodes_dict[i]['cand_mols'] = list(self.nodes_dict[i]['cand_mols'])
(
self.nodes_dict[i]["cands"],
self.nodes_dict[i]["cand_mols"],
_,
) = list(zip(*cands))
self.nodes_dict[i]["cands"] = list(self.nodes_dict[i]["cands"])
self.nodes_dict[i]["cand_mols"] = list(
self.nodes_dict[i]["cand_mols"]
)
else:
self.nodes_dict[i]['cands'] = []
self.nodes_dict[i]['cand_mols'] = []
self.nodes_dict[i]["cands"] = []
self.nodes_dict[i]["cand_mols"] = []
def recover(self):
for i in self.nodes_dict:
......
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import rdkit.Chem as Chem
import torch.nn.functional as F
from .chemutils import get_mol
import dgl
from dgl import mean_nodes, line_graph
import dgl.function as DGLF
from dgl import line_graph, mean_nodes
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
from .chemutils import get_mol
ELEM_LIST = [
"C",
"N",
"O",
"S",
"F",
"Si",
"P",
"Cl",
"Br",
"Mg",
"Na",
"Ca",
"Fe",
"Al",
"I",
"B",
"K",
"Se",
"Zn",
"H",
"Cu",
"Mn",
"unknown",
]
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
MAX_NB = 6
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0,1,2,3])
+ [atom.GetIsAromatic()]))
return torch.Tensor(
onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
+ [atom.GetIsAromatic()]
)
def bond_features(bond):
bt = bond.GetBondType()
stereo = int(bond.GetStereo())
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
fstereo = onek_encoding_unk(stereo, [0,1,2,3,4,5])
return (torch.Tensor(fbond + fstereo))
fbond = [
bt == Chem.rdchem.BondType.SINGLE,
bt == Chem.rdchem.BondType.DOUBLE,
bt == Chem.rdchem.BondType.TRIPLE,
bt == Chem.rdchem.BondType.AROMATIC,
bond.IsInRing(),
]
fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
return torch.Tensor(fbond + fstereo)
def mol2dgl_single(smiles):
n_edges = 0
......@@ -61,8 +98,11 @@ def mol2dgl_single(smiles):
bond_x.append(features)
graph = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
n_edges += n_bonds
return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
return (
graph,
torch.stack(atom_x),
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0),
)
class LoopyBPUpdate(nn.Module):
......@@ -73,10 +113,10 @@ class LoopyBPUpdate(nn.Module):
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, nodes):
msg_input = nodes.data['msg_input']
msg_delta = self.W_h(nodes.data['accum_msg'])
msg_input = nodes.data["msg_input"]
msg_delta = self.W_h(nodes.data["accum_msg"])
msg = F.relu(msg_input + msg_delta)
return {'msg': msg}
return {"msg": msg}
class GatherUpdate(nn.Module):
......@@ -87,9 +127,9 @@ class GatherUpdate(nn.Module):
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
def forward(self, nodes):
m = nodes.data['m']
m = nodes.data["m"]
return {
'h': F.relu(self.W_o(torch.cat([nodes.data['x'], m], 1))),
"h": F.relu(self.W_o(torch.cat([nodes.data["x"], m], 1))),
}
......@@ -121,7 +161,7 @@ class DGLMPN(nn.Module):
mol_graph = self.run(mol_graph, mol_line_graph)
# TODO: replace with unbatch or readout
g_repr = mean_nodes(mol_graph, 'h')
g_repr = mean_nodes(mol_graph, "h")
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
......@@ -134,32 +174,38 @@ class DGLMPN(nn.Module):
n_nodes = mol_graph.number_of_nodes()
mol_graph.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
func=lambda edges: {"src_x": edges.src["x"]},
)
mol_line_graph.ndata.update(mol_graph.edata)
e_repr = mol_line_graph.ndata
bond_features = e_repr['x']
source_features = e_repr['src_x']
bond_features = e_repr["x"]
source_features = e_repr["src_x"]
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
mol_line_graph.ndata.update({
'msg_input': msg_input,
'msg': F.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
mol_graph.ndata.update({
'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
})
mol_line_graph.ndata.update(
{
"msg_input": msg_input,
"msg": F.relu(msg_input),
"accum_msg": torch.zeros_like(msg_input),
}
)
mol_graph.ndata.update(
{
"m": bond_features.new(n_nodes, self.hidden_size).zero_(),
"h": bond_features.new(n_nodes, self.hidden_size).zero_(),
}
)
for i in range(self.depth - 1):
mol_line_graph.update_all(DGLF.copy_u('msg', 'msg'), DGLF.sum('msg', 'accum_msg'))
mol_line_graph.update_all(
DGLF.copy_u("msg", "msg"), DGLF.sum("msg", "accum_msg")
)
mol_line_graph.apply_nodes(self.loopy_bp_updater)
mol_graph.edata.update(mol_line_graph.ndata)
mol_graph.update_all(DGLF.copy_e('msg', 'msg'), DGLF.sum('msg', 'm'))
mol_graph.update_all(DGLF.copy_e("msg", "msg"), DGLF.sum("msg", "m"))
mol_graph.apply_nodes(self.gather_updater)
return mol_graph
import os
import torch
import torch.nn as nn
import os
import dgl
def cuda(x):
if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
return x.to(torch.device('cuda')) # works for both DGLGraph and tensor
if torch.cuda.is_available() and not os.getenv("NOCUDA", None):
return x.to(torch.device("cuda")) # works for both DGLGraph and tensor
else:
return x
......@@ -22,27 +24,28 @@ class GRUUpdate(nn.Module):
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
def update_zm(self, node):
src_x = node.data['src_x']
s = node.data['s']
rm = node.data['accum_rm']
src_x = node.data["src_x"]
s = node.data["s"]
rm = node.data["accum_rm"]
z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
m = (1 - z) * s + z * m
return {'m': m, 'z': z}
return {"m": m, "z": z}
def update_r(self, node, zm=None):
dst_x = node.data['dst_x']
m = node.data['m'] if zm is None else zm['m']
dst_x = node.data["dst_x"]
m = node.data["m"] if zm is None else zm["m"]
r_1 = self.W_r(dst_x)
r_2 = self.U_r(m)
r = torch.sigmoid(r_1 + r_2)
return {'r': r, 'rm': r * m}
return {"r": r, "rm": r * m}
def forward(self, node):
dic = self.update_zm(node)
dic.update(self.update_r(node, zm=dic))
return dic
def tocpu(g):
src, dst = g.edges()
src = src.cpu()
......
import math
import random
import sys
from collections import deque
from optparse import OptionParser
import rdkit
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
import math, random, sys
from optparse import OptionParser
from collections import deque
import rdkit
import tqdm
from jtnn import *
from torch.utils.data import DataLoader
torch.multiprocessing.set_sharing_strategy("file_system")
torch.multiprocessing.set_sharing_strategy('file_system')
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = OptionParser()
parser.add_option("-t", "--train", dest="train", default='train', help='Training file name')
parser.add_option("-v", "--vocab", dest="vocab", default='vocab', help='Vocab file name')
parser.add_option(
"-t", "--train", dest="train", default="train", help="Training file name"
)
parser.add_option(
"-v", "--vocab", dest="vocab", default="vocab", help="Vocab file name"
)
parser.add_option("-s", "--save_dir", dest="save_path")
parser.add_option("-m", "--model", dest="model_path", default=None)
parser.add_option("-b", "--batch", dest="batch_size", default=40)
......@@ -31,7 +39,7 @@ parser.add_option("-d", "--depth", dest="depth", default=3)
parser.add_option("-z", "--beta", dest="beta", default=1.0)
parser.add_option("-q", "--lr", dest="lr", default=1e-3)
parser.add_option("-T", "--test", dest="test", action="store_true")
opts,args = parser.parse_args()
opts, args = parser.parse_args()
dataset = JTNNDataset(data=opts.train, vocab=opts.vocab, training=True)
vocab = dataset.vocab
......@@ -55,7 +63,10 @@ else:
nn.init.xavier_normal(param)
model = cuda(model)
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
print(
"Model #Params: %dK"
% (sum([x.nelement() for x in model.parameters()]) / 1000,)
)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
......@@ -64,6 +75,7 @@ scheduler.step()
MAX_EPOCH = 100
PRINT_ITER = 20
def train():
dataset.training = True
dataloader = DataLoader(
......@@ -73,17 +85,18 @@ def train():
num_workers=4,
collate_fn=JTNNCollator(vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn)
worker_init_fn=worker_init_fn,
)
for epoch in range(MAX_EPOCH):
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
for it, batch in enumerate(tqdm.tqdm(dataloader)):
model.zero_grad()
try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
except:
print([t.smiles for t in batch['mol_trees']])
print([t.smiles for t in batch["mol_trees"]])
raise
loss.backward()
optimizer.step()
......@@ -99,20 +112,34 @@ def train():
assm_acc = assm_acc / PRINT_ITER * 100
steo_acc = steo_acc / PRINT_ITER * 100
print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
print(
"KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f"
% (
kl_div,
word_acc,
topo_acc,
assm_acc,
steo_acc,
loss.item(),
)
)
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
sys.stdout.flush()
if (it + 1) % 1500 == 0: #Fast annealing
if (it + 1) % 1500 == 0: # Fast annealing
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(),
opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1))
torch.save(
model.state_dict(),
opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1),
)
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(), opts.save_path + "/model.iter-" + str(epoch))
torch.save(
model.state_dict(), opts.save_path + "/model.iter-" + str(epoch)
)
def test():
dataset.training = False
......@@ -123,12 +150,13 @@ def test():
num_workers=0,
collate_fn=JTNNCollator(vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn)
worker_init_fn=worker_init_fn,
)
# Just an example of molecule decoding; in reality you may want to sample
# tree and molecule vectors.
for it, batch in enumerate(dataloader):
gt_smiles = batch['mol_trees'][0].smiles
gt_smiles = batch["mol_trees"][0].smiles
print(gt_smiles)
model.move_to_cuda(batch)
_, tree_vec, mol_vec = model.encode(batch)
......@@ -136,21 +164,28 @@ def test():
smiles = model.decode(tree_vec, mol_vec)
print(smiles)
if __name__ == '__main__':
if __name__ == "__main__":
if opts.test:
test()
else:
train()
print('# passes:', model.n_passes)
print('Total # nodes processed:', model.n_nodes_total)
print('Total # edges processed:', model.n_edges_total)
print('Total # tree nodes processed:', model.n_tree_nodes_total)
print('Graph decoder: # passes:', model.jtmpn.n_passes)
print('Graph decoder: Total # candidates processed:', model.jtmpn.n_samples_total)
print('Graph decoder: Total # nodes processed:', model.jtmpn.n_nodes_total)
print('Graph decoder: Total # edges processed:', model.jtmpn.n_edges_total)
print('Graph encoder: # passes:', model.mpn.n_passes)
print('Graph encoder: Total # candidates processed:', model.mpn.n_samples_total)
print('Graph encoder: Total # nodes processed:', model.mpn.n_nodes_total)
print('Graph encoder: Total # edges processed:', model.mpn.n_edges_total)
print("# passes:", model.n_passes)
print("Total # nodes processed:", model.n_nodes_total)
print("Total # edges processed:", model.n_edges_total)
print("Total # tree nodes processed:", model.n_tree_nodes_total)
print("Graph decoder: # passes:", model.jtmpn.n_passes)
print(
"Graph decoder: Total # candidates processed:",
model.jtmpn.n_samples_total,
)
print("Graph decoder: Total # nodes processed:", model.jtmpn.n_nodes_total)
print("Graph decoder: Total # edges processed:", model.jtmpn.n_edges_total)
print("Graph encoder: # passes:", model.mpn.n_passes)
print(
"Graph encoder: Total # candidates processed:",
model.mpn.n_samples_total,
)
print("Graph encoder: Total # nodes processed:", model.mpn.n_nodes_total)
print("Graph encoder: Total # edges processed:", model.mpn.n_edges_total)
import argparse
import torch
import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.nn import LabelPropagation
def main():
# check cuda
device = f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu >= 0 else 'cpu'
device = (
f"cuda:{args.gpu}"
if torch.cuda.is_available() and args.gpu >= 0
else "cpu"
)
# load data
if args.dataset == 'Cora':
if args.dataset == "Cora":
dataset = CoraGraphDataset()
elif args.dataset == 'Citeseer':
elif args.dataset == "Citeseer":
dataset = CiteseerGraphDataset()
elif args.dataset == 'Pubmed':
elif args.dataset == "Pubmed":
dataset = PubmedGraphDataset()
else:
raise ValueError('Dataset {} is invalid.'.format(args.dataset))
raise ValueError("Dataset {} is invalid.".format(args.dataset))
g = dataset[0]
g = dgl.add_self_loop(g)
labels = g.ndata.pop('label').to(device).long()
labels = g.ndata.pop("label").to(device).long()
# load masks for train / test, valid is not used.
train_mask = g.ndata.pop('train_mask')
test_mask = g.ndata.pop('test_mask')
train_mask = g.ndata.pop("train_mask")
test_mask = g.ndata.pop("test_mask")
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)
......@@ -37,19 +43,21 @@ def main():
lp = LabelPropagation(args.num_layers, args.alpha)
logits = lp(g, labels, mask=train_idx)
test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
test_acc = torch.sum(
logits[test_idx].argmax(dim=1) == labels[test_idx]
).item() / len(test_idx)
print("Test Acc {:.4f}".format(test_acc))
if __name__ == '__main__':
if __name__ == "__main__":
"""
Label Propagation Hyperparameters
"""
parser = argparse.ArgumentParser(description='LP')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--dataset', type=str, default='Cora')
parser.add_argument('--num-layers', type=int, default=10)
parser.add_argument('--alpha', type=float, default=0.5)
parser = argparse.ArgumentParser(description="LP")
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--dataset", type=str, default="Cora")
parser.add_argument("--num-layers", type=int, default=10)
parser.add_argument("--alpha", type=float, default=0.5)
args = parser.parse_args()
print(args)
......
......@@ -17,49 +17,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from time import time
import matplotlib.pyplot as plt
import warnings
from time import time
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as ss
import torch
import dgl
from dgl import function as fn
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation
from lda_model import LatentDirichletAllocation as LDAModel
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from lda_model import LatentDirichletAllocation as LDAModel
import dgl
from dgl import function as fn
n_samples = 2000
n_features = 1000
n_components = 10
n_top_words = 20
device = 'cuda'
device = "cuda"
def plot_top_words(model, feature_names, n_top_words, title):
fig, axes = plt.subplots(2, 5, figsize=(30, 15), sharex=True)
axes = axes.flatten()
for topic_idx, topic in enumerate(model.components_):
top_features_ind = topic.argsort()[:-n_top_words - 1:-1]
top_features_ind = topic.argsort()[: -n_top_words - 1 : -1]
top_features = [feature_names[i] for i in top_features_ind]
weights = topic[top_features_ind]
ax = axes[topic_idx]
ax.barh(top_features, weights, height=0.7)
ax.set_title(f'Topic {topic_idx +1}',
fontdict={'fontsize': 30})
ax.set_title(f"Topic {topic_idx +1}", fontdict={"fontsize": 30})
ax.invert_yaxis()
ax.tick_params(axis='both', which='major', labelsize=20)
for i in 'top right left'.split():
ax.tick_params(axis="both", which="major", labelsize=20)
for i in "top right left".split():
ax.spines[i].set_visible(False)
fig.suptitle(title, fontsize=40)
plt.subplots_adjust(top=0.90, bottom=0.05, wspace=0.90, hspace=0.3)
plt.show()
# Load the 20 newsgroups dataset and vectorize it. We use a few heuristics
# to filter out useless terms early on: the posts are stripped of headers,
# footers and quoted replies, and common English words, words occurring in
......@@ -67,43 +67,50 @@ def plot_top_words(model, feature_names, n_top_words, title):
print("Loading dataset...")
t0 = time()
data, _ = fetch_20newsgroups(shuffle=True, random_state=1,
remove=('headers', 'footers', 'quotes'),
return_X_y=True)
data, _ = fetch_20newsgroups(
shuffle=True,
random_state=1,
remove=("headers", "footers", "quotes"),
return_X_y=True,
)
data_samples = data[:n_samples]
data_test = data[n_samples:2*n_samples]
data_test = data[n_samples : 2 * n_samples]
print("done in %0.3fs." % (time() - t0))
# Use tf (raw term count) features for LDA.
print("Extracting tf features for LDA...")
tf_vectorizer = CountVectorizer(max_df=0.95, min_df=2,
max_features=n_features,
stop_words='english')
tf_vectorizer = CountVectorizer(
max_df=0.95, min_df=2, max_features=n_features, stop_words="english"
)
t0 = time()
tf_vectorizer.fit(data)
tf = tf_vectorizer.transform(data_samples)
tt = tf_vectorizer.transform(data_test)
tf_feature_names = tf_vectorizer.get_feature_names()
tf_uv = [(u,v)
for u,v,e in zip(tf.tocoo().row, tf.tocoo().col, tf.tocoo().data)
for _ in range(e)]
tt_uv = [(u,v)
for u,v,e in zip(tt.tocoo().row, tt.tocoo().col, tt.tocoo().data)
for _ in range(e)]
tf_uv = [
(u, v)
for u, v, e in zip(tf.tocoo().row, tf.tocoo().col, tf.tocoo().data)
for _ in range(e)
]
tt_uv = [
(u, v)
for u, v, e in zip(tt.tocoo().row, tt.tocoo().col, tt.tocoo().data)
for _ in range(e)
]
print("done in %0.3fs." % (time() - t0))
print()
print("Preparing dgl graphs...")
t0 = time()
G = dgl.heterograph({('doc','topic','word'): tf_uv}, device=device)
Gt = dgl.heterograph({('doc','topic','word'): tt_uv}, device=device)
G = dgl.heterograph({("doc", "topic", "word"): tf_uv}, device=device)
Gt = dgl.heterograph({("doc", "topic", "word"): tt_uv}, device=device)
print("done in %0.3fs." % (time() - t0))
print()
print("Training dgl-lda model...")
t0 = time()
model = LDAModel(G.num_nodes('word'), n_components)
model = LDAModel(G.num_nodes("word"), n_components)
model.fit(G)
print("done in %0.3fs." % (time() - t0))
print()
......@@ -113,20 +120,27 @@ print(f"dgl-lda testing perplexity {model.perplexity(Gt):.3f}")
word_nphi = np.vstack([nphi.tolist() for nphi in model.word_data.nphi])
plot_top_words(
type('dummy', (object,), {'components_': word_nphi}),
tf_feature_names, n_top_words, 'Topics in LDA model')
type("dummy", (object,), {"components_": word_nphi}),
tf_feature_names,
n_top_words,
"Topics in LDA model",
)
print("Training scikit-learn model...")
print('\n' * 2, "Fitting LDA models with tf features, "
"n_samples=%d and n_features=%d..."
% (n_samples, n_features))
lda = LatentDirichletAllocation(n_components=n_components, max_iter=5,
learning_method='online',
learning_offset=50.,
print(
"\n" * 2,
"Fitting LDA models with tf features, "
"n_samples=%d and n_features=%d..." % (n_samples, n_features),
)
lda = LatentDirichletAllocation(
n_components=n_components,
max_iter=5,
learning_method="online",
learning_offset=50.0,
random_state=0,
verbose=1,
)
)
t0 = time()
lda.fit(tf)
print("done in %0.3fs." % (time() - t0))
......
......@@ -17,8 +17,17 @@
# limitations under the License.
import os, functools, warnings, torch, collections, dgl, io
import numpy as np, scipy as sp
import collections
import functools
import io
import os
import warnings
import numpy as np
import scipy as sp
import torch
import dgl
try:
from functools import cached_property
......@@ -37,17 +46,21 @@ class EdgeData:
@property
def loglike(self):
return (self.src_data['Elog'] + self.dst_data['Elog']).logsumexp(1)
return (self.src_data["Elog"] + self.dst_data["Elog"]).logsumexp(1)
@property
def phi(self):
return (
self.src_data['Elog'] + self.dst_data['Elog'] - self.loglike.unsqueeze(1)
self.src_data["Elog"]
+ self.dst_data["Elog"]
- self.loglike.unsqueeze(1)
).exp()
@property
def expectation(self):
return (self.src_data['expectation'] * self.dst_data['expectation']).sum(1)
return (
self.src_data["expectation"] * self.dst_data["expectation"]
).sum(1)
class _Dirichlet:
......@@ -55,10 +68,13 @@ class _Dirichlet:
self.prior = prior
self.nphi = nphi
self.device = nphi.device
self._sum_by_parts = lambda map_fn: functools.reduce(torch.add, [
map_fn(slice(i, min(i+_chunksize, nphi.shape[1]))).sum(1)
self._sum_by_parts = lambda map_fn: functools.reduce(
torch.add,
[
map_fn(slice(i, min(i + _chunksize, nphi.shape[1]))).sum(1)
for i in list(range(0, nphi.shape[1], _chunksize))
])
],
)
def _posterior(self, _ID=slice(None)):
return self.prior + self.nphi[:, _ID]
......@@ -68,8 +84,9 @@ class _Dirichlet:
return self.nphi.sum(1) + self.prior * self.nphi.shape[1]
def _Elog(self, _ID=slice(None)):
return torch.digamma(self._posterior(_ID)) - \
torch.digamma(self.posterior_sum.unsqueeze(1))
return torch.digamma(self._posterior(_ID)) - torch.digamma(
self.posterior_sum.unsqueeze(1)
)
@cached_property
def loglike(self):
......@@ -105,9 +122,15 @@ class _Dirichlet:
@cached_property
def Bayesian_gap(self):
return 1. - self._sum_by_parts(lambda s: self._Elog(s).exp())
_cached_properties = ["posterior_sum", "loglike", "n", "cdf", "Bayesian_gap"]
return 1.0 - self._sum_by_parts(lambda s: self._Elog(s).exp())
_cached_properties = [
"posterior_sum",
"loglike",
"n",
"cdf",
"Bayesian_gap",
]
def clear_cache(self):
for name in self._cached_properties:
......@@ -117,27 +140,29 @@ class _Dirichlet:
pass
def update(self, new, _ID=slice(None), rho=1):
""" inplace: old * (1-rho) + new * rho """
"""inplace: old * (1-rho) + new * rho"""
self.clear_cache()
mean_change = (self.nphi[:, _ID] - new).abs().mean().tolist()
self.nphi *= (1 - rho)
self.nphi *= 1 - rho
self.nphi[:, _ID] += new * rho
return mean_change
class DocData(_Dirichlet):
""" nphi (n_docs by n_topics) """
"""nphi (n_docs by n_topics)"""
def prepare_graph(self, G, key="Elog"):
G.nodes['doc'].data[key] = getattr(self, '_'+key)().to(G.device)
G.nodes["doc"].data[key] = getattr(self, "_" + key)().to(G.device)
def update_from(self, G, mult):
new = G.nodes['doc'].data['nphi'] * mult
new = G.nodes["doc"].data["nphi"] * mult
return self.update(new.to(self.device))
class _Distributed(collections.UserList):
""" split on dim=0 and store on multiple devices """
"""split on dim=0 and store on multiple devices"""
def __init__(self, prior, nphi):
self.prior = prior
self.nphi = nphi
......@@ -146,36 +171,38 @@ class _Distributed(collections.UserList):
def split_device(self, other, dim=0):
split_sections = [x.shape[0] for x in self.nphi]
out = torch.split(other, split_sections, dim)
return [y.to(x.device) for x,y in zip(self.nphi, out)]
return [y.to(x.device) for x, y in zip(self.nphi, out)]
class WordData(_Distributed):
""" distributed nphi (n_topics by n_words), transpose to/from graph nodes data """
"""distributed nphi (n_topics by n_words), transpose to/from graph nodes data"""
def prepare_graph(self, G, key="Elog"):
if '_ID' in G.nodes['word'].data:
_ID = G.nodes['word'].data['_ID']
if "_ID" in G.nodes["word"].data:
_ID = G.nodes["word"].data["_ID"]
else:
_ID = slice(None)
out = [getattr(part, '_'+key)(_ID).to(G.device) for part in self]
G.nodes['word'].data[key] = torch.cat(out).T
out = [getattr(part, "_" + key)(_ID).to(G.device) for part in self]
G.nodes["word"].data[key] = torch.cat(out).T
def update_from(self, G, mult, rho):
nphi = G.nodes['word'].data['nphi'].T * mult
nphi = G.nodes["word"].data["nphi"].T * mult
if '_ID' in G.nodes['word'].data:
_ID = G.nodes['word'].data['_ID']
if "_ID" in G.nodes["word"].data:
_ID = G.nodes["word"].data["_ID"]
else:
_ID = slice(None)
mean_change = [x.update(y, _ID, rho)
for x, y in zip(self, self.split_device(nphi))]
mean_change = [
x.update(y, _ID, rho) for x, y in zip(self, self.split_device(nphi))
]
return np.mean(mean_change)
class Gamma(collections.namedtuple('Gamma', "concentration, rate")):
""" articulate the difference between torch gamma and numpy gamma """
class Gamma(collections.namedtuple("Gamma", "concentration, rate")):
"""articulate the difference between torch gamma and numpy gamma"""
@property
def shape(self):
return self.concentration
......@@ -218,20 +245,23 @@ class LatentDirichletAllocation:
(NIPS 2010).
[2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
"""
def __init__(
self, n_words, n_components,
self,
n_words,
n_components,
prior=None,
rho=1,
mult={'doc': 1, 'word': 1},
init={'doc': (100., 100.), 'word': (100., 100.)},
device_list=['cpu'],
mult={"doc": 1, "word": 1},
init={"doc": (100.0, 100.0), "word": (100.0, 100.0)},
device_list=["cpu"],
verbose=True,
):
self.n_words = n_words
self.n_components = n_components
if prior is None:
prior = {'doc': 1./n_components, 'word': 1./n_components}
prior = {"doc": 1.0 / n_components, "word": 1.0 / n_components}
self.prior = prior
self.rho = rho
......@@ -244,46 +274,46 @@ class LatentDirichletAllocation:
self._init_word_data()
def _init_word_data(self):
split_sections = np.diff(
np.linspace(0, self.n_components, len(self.device_list)+1).astype(int)
np.linspace(0, self.n_components, len(self.device_list) + 1).astype(
int
)
)
word_nphi = [
Gamma(*self.init['word']).sample((s, self.n_words), device)
Gamma(*self.init["word"]).sample((s, self.n_words), device)
for s, device in zip(split_sections, self.device_list)
]
self.word_data = WordData(self.prior['word'], word_nphi)
self.word_data = WordData(self.prior["word"], word_nphi)
def _init_doc_data(self, n_docs, device):
doc_nphi = Gamma(*self.init['doc']).sample(
(n_docs, self.n_components), device)
return DocData(self.prior['doc'], doc_nphi)
doc_nphi = Gamma(*self.init["doc"]).sample(
(n_docs, self.n_components), device
)
return DocData(self.prior["doc"], doc_nphi)
def save(self, f):
for w in self.word_data:
w.clear_cache()
torch.save({
'prior': self.prior,
'rho': self.rho,
'mult': self.mult,
'init': self.init,
'word_data': [part.nphi for part in self.word_data],
}, f)
torch.save(
{
"prior": self.prior,
"rho": self.rho,
"mult": self.mult,
"init": self.init,
"word_data": [part.nphi for part in self.word_data],
},
f,
)
def _prepare_graph(self, G, doc_data, key="Elog"):
doc_data.prepare_graph(G, key)
self.word_data.prepare_graph(G, key)
def _e_step(self, G, doc_data=None, mean_change_tol=1e-3, max_iters=100):
"""_e_step implements doc data sampling until convergence or max_iters
"""
"""_e_step implements doc data sampling until convergence or max_iters"""
if doc_data is None:
doc_data = self._init_doc_data(G.num_nodes('doc'), G.device)
doc_data = self._init_doc_data(G.num_nodes("doc"), G.device)
G_rev = G.reverse() # word -> doc
self.word_data.prepare_graph(G_rev)
......@@ -291,65 +321,76 @@ class LatentDirichletAllocation:
for i in range(max_iters):
doc_data.prepare_graph(G_rev)
G_rev.update_all(
lambda edges: {'phi': EdgeData(edges.src, edges.dst).phi},
dgl.function.sum('phi', 'nphi')
lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
dgl.function.sum("phi", "nphi"),
)
mean_change = doc_data.update_from(G_rev, self.mult['doc'])
mean_change = doc_data.update_from(G_rev, self.mult["doc"])
if mean_change < mean_change_tol:
break
if self.verbose:
print(f"e-step num_iters={i+1} with mean_change={mean_change:.4f}, "
f"perplexity={self.perplexity(G, doc_data):.4f}")
print(
f"e-step num_iters={i+1} with mean_change={mean_change:.4f}, "
f"perplexity={self.perplexity(G, doc_data):.4f}"
)
return doc_data
transform = _e_step
def predict(self, doc_data):
pred_scores = [
# d_exp @ w._expectation()
(lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)
(d_exp / w.posterior_sum.unsqueeze(0))
(lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)(
d_exp / w.posterior_sum.unsqueeze(0)
)
for (d_exp, w) in zip(
self.word_data.split_device(doc_data._expectation(), dim=1),
self.word_data)
self.word_data,
)
]
x = torch.zeros_like(pred_scores[0], device=doc_data.device)
for p in pred_scores:
x += p.to(x.device)
return x
def sample(self, doc_data, num_samples):
""" draw independent words and return the marginal probabilities,
"""draw independent words and return the marginal probabilities,
i.e., the expectations in Dirichlet distributions.
"""
def fn(cdf):
u = torch.rand(cdf.shape[0], num_samples, device=cdf.device)
return torch.searchsorted(cdf, u).to(doc_data.device)
topic_ids = fn(doc_data.cdf)
word_ids = torch.cat([fn(part.cdf) for part in self.word_data])
ids = torch.gather(word_ids, 0, topic_ids) # pick components by topic_ids
ids = torch.gather(
word_ids, 0, topic_ids
) # pick components by topic_ids
# compute expectation scores on sampled ids
src_ids = torch.arange(
ids.shape[0], dtype=ids.dtype, device=ids.device
).reshape((-1, 1)).expand(ids.shape)
unique_ids, inverse_ids = torch.unique(ids, sorted=False, return_inverse=True)
src_ids = (
torch.arange(ids.shape[0], dtype=ids.dtype, device=ids.device)
.reshape((-1, 1))
.expand(ids.shape)
)
unique_ids, inverse_ids = torch.unique(
ids, sorted=False, return_inverse=True
)
G = dgl.heterograph({('doc','','word'): (src_ids.ravel(), inverse_ids.ravel())})
G.nodes['word'].data['_ID'] = unique_ids
G = dgl.heterograph(
{("doc", "", "word"): (src_ids.ravel(), inverse_ids.ravel())}
)
G.nodes["word"].data["_ID"] = unique_ids
self._prepare_graph(G, doc_data, "expectation")
G.apply_edges(lambda e: {'expectation': EdgeData(e.src, e.dst).expectation})
expectation = G.edata.pop('expectation').reshape(ids.shape)
G.apply_edges(
lambda e: {"expectation": EdgeData(e.src, e.dst).expectation}
)
expectation = G.edata.pop("expectation").reshape(ids.shape)
return ids, expectation
def _m_step(self, G, doc_data):
"""_m_step implements word data sampling and stores word_z stats.
mean_change is in the sense of full graph with rho=1.
......@@ -357,26 +398,25 @@ class LatentDirichletAllocation:
G = G.clone()
self._prepare_graph(G, doc_data)
G.update_all(
lambda edges: {'phi': EdgeData(edges.src, edges.dst).phi},
dgl.function.sum('phi', 'nphi')
lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
dgl.function.sum("phi", "nphi"),
)
self._last_mean_change = self.word_data.update_from(
G, self.mult['word'], self.rho)
G, self.mult["word"], self.rho
)
if self.verbose:
print(f"m-step mean_change={self._last_mean_change:.4f}, ", end="")
Bayesian_gap = np.mean([
part.Bayesian_gap.mean().tolist() for part in self.word_data
])
Bayesian_gap = np.mean(
[part.Bayesian_gap.mean().tolist() for part in self.word_data]
)
print(f"Bayesian_gap={Bayesian_gap:.4f}")
def partial_fit(self, G):
doc_data = self._e_step(G)
self._m_step(G, doc_data)
return self
def fit(self, G, mean_change_tol=1e-3, max_epochs=10):
for i in range(max_epochs):
if self.verbose:
......@@ -387,7 +427,6 @@ class LatentDirichletAllocation:
break
return self
def perplexity(self, G, doc_data=None):
"""ppl = exp{-sum[log(p(w1,...,wn|d))] / n}
Follows Eq (15) in Hoffman et al., 2010.
......@@ -398,45 +437,50 @@ class LatentDirichletAllocation:
# compute E[log p(docs | theta, beta)]
G = G.clone()
self._prepare_graph(G, doc_data)
G.apply_edges(lambda edges: {'loglike': EdgeData(edges.src, edges.dst).loglike})
edge_elbo = (G.edata['loglike'].sum() / G.num_edges()).tolist()
G.apply_edges(
lambda edges: {"loglike": EdgeData(edges.src, edges.dst).loglike}
)
edge_elbo = (G.edata["loglike"].sum() / G.num_edges()).tolist()
if self.verbose:
print(f'neg_elbo phi: {-edge_elbo:.3f}', end=' ')
print(f"neg_elbo phi: {-edge_elbo:.3f}", end=" ")
# compute E[log p(theta | alpha) - log q(theta | gamma)]
doc_elbo = (doc_data.loglike.sum() / doc_data.n.sum()).tolist()
if self.verbose:
print(f'theta: {-doc_elbo:.3f}', end=' ')
print(f"theta: {-doc_elbo:.3f}", end=" ")
# compute E[log p(beta | eta) - log q(beta | lambda)]
# The denominator n for extrapolation perplexity is undefined.
# We use the train set, whereas sklearn uses the test set.
word_elbo = (
sum([part.loglike.sum().tolist() for part in self.word_data])
/ sum([part.n.sum().tolist() for part in self.word_data])
)
word_elbo = sum(
[part.loglike.sum().tolist() for part in self.word_data]
) / sum([part.n.sum().tolist() for part in self.word_data])
if self.verbose:
print(f'beta: {-word_elbo:.3f}')
print(f"beta: {-word_elbo:.3f}")
ppl = np.exp(-edge_elbo - doc_elbo - word_elbo)
if G.num_edges()>0 and np.isnan(ppl):
if G.num_edges() > 0 and np.isnan(ppl):
warnings.warn("numerical issue in perplexity")
return ppl
def doc_subgraph(G, doc_ids):
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
_, _, (block,) = sampler.sample(G.reverse(), {'doc': torch.as_tensor(doc_ids)})
_, _, (block,) = sampler.sample(
G.reverse(), {"doc": torch.as_tensor(doc_ids)}
)
B = dgl.DGLHeteroGraph(
block._graph, ['_', 'word', 'doc', '_'], block.etypes
block._graph, ["_", "word", "doc", "_"], block.etypes
).reverse()
B.nodes['word'].data['_ID'] = block.nodes['word'].data['_ID']
B.nodes["word"].data["_ID"] = block.nodes["word"].data["_ID"]
return B
if __name__ == '__main__':
print('Testing LatentDirichletAllocation ...')
G = dgl.heterograph({('doc', '', 'word'): [(0, 0), (1, 3)]}, {'doc': 2, 'word': 5})
if __name__ == "__main__":
print("Testing LatentDirichletAllocation ...")
G = dgl.heterograph(
{("doc", "", "word"): [(0, 0), (1, 3)]}, {"doc": 2, "word": 5}
)
model = LatentDirichletAllocation(n_words=5, n_components=10, verbose=False)
model.fit(G)
model.transform(G)
......@@ -454,4 +498,4 @@ if __name__ == '__main__':
f.seek(0)
print(torch.load(f))
print('Testing LatentDirichletAllocation passed!')
print("Testing LatentDirichletAllocation passed!")
......@@ -6,11 +6,12 @@ Author's implementation: https://github.com/joanbruna/GNN_community
"""
from __future__ import division
import time
import argparse
import time
from itertools import permutations
import gnn
import numpy as np
import torch as th
import torch.nn.functional as F
......@@ -18,37 +19,51 @@ import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.data import SBMMixtureDataset
import gnn
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
parser.add_argument('--gpu', type=int, help='GPU index', default=-1)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
parser.add_argument('--n-communities', type=int, help='Number of communities', default=2)
parser.add_argument('--n-epochs', type=int, help='Number of epochs', default=100)
parser.add_argument('--n-features', type=int, help='Number of features', default=16)
parser.add_argument('--n-graphs', type=int, help='Number of graphs', default=10)
parser.add_argument('--n-layers', type=int, help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int, help='Number of nodes', default=10000)
parser.add_argument('--optim', type=str, help='Optimizer', default='Adam')
parser.add_argument('--radius', type=int, help='Radius', default=3)
parser.add_argument('--verbose', action='store_true')
parser.add_argument("--batch-size", type=int, help="Batch size", default=1)
parser.add_argument("--gpu", type=int, help="GPU index", default=-1)
parser.add_argument("--lr", type=float, help="Learning rate", default=0.001)
parser.add_argument(
"--n-communities", type=int, help="Number of communities", default=2
)
parser.add_argument(
"--n-epochs", type=int, help="Number of epochs", default=100
)
parser.add_argument(
"--n-features", type=int, help="Number of features", default=16
)
parser.add_argument("--n-graphs", type=int, help="Number of graphs", default=10)
parser.add_argument("--n-layers", type=int, help="Number of layers", default=30)
parser.add_argument(
"--n-nodes", type=int, help="Number of nodes", default=10000
)
parser.add_argument("--optim", type=str, help="Optimizer", default="Adam")
parser.add_argument("--radius", type=int, help="Radius", default=3)
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
dev = th.device("cpu") if args.gpu < 0 else th.device("cuda:%d" % args.gpu)
K = args.n_communities
training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(training_dataset, args.batch_size,
collate_fn=training_dataset.collate_fn, drop_last=True)
training_loader = DataLoader(
training_dataset,
args.batch_size,
collate_fn=training_dataset.collate_fn,
drop_last=True,
)
ones = th.ones(args.n_nodes // K)
y_list = [th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))]
y_list = [
th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))
]
feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)
def compute_overlap(z_list):
ybar_list = [th.max(z, 1)[1] for z in z_list]
overlap_list = []
......@@ -58,15 +73,20 @@ def compute_overlap(z_list):
overlap_list.append(overlap)
return sum(overlap_list) / len(overlap_list)
def from_np(f, *args):
def wrap(*args):
new = [th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args]
new = [
th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args
]
return f(*new)
return wrap
@from_np
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
""" One step of training. """
"""One step of training."""
g = g.to(dev)
lg = lg.to(dev)
deg_g = deg_g.to(dev).unsqueeze(1)
......@@ -77,7 +97,10 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
t_forward = time.time() - t0
z_list = th.chunk(z, args.batch_size, 0)
loss = sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list) / args.batch_size
loss = (
sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list)
/ args.batch_size
)
overlap = compute_overlap(z_list)
optimizer.zero_grad()
......@@ -88,6 +111,7 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
return loss, overlap, t_forward, t_backward
@from_np
def inference(g, lg, deg_g, deg_lg, pm_pd):
g = g.to(dev)
......@@ -99,9 +123,11 @@ def inference(g, lg, deg_g, deg_lg, pm_pd):
z = model(g, lg, deg_g, deg_lg, pm_pd)
return z
def test():
p_list =[6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
q_list =[0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
N = 1
overlap_list = []
for p, q in zip(p_list, q_list):
......@@ -112,31 +138,38 @@ def test():
overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
return overlap_list
n_iterations = args.n_graphs // args.batch_size
for i in range(args.n_epochs):
total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd)
loss, overlap, t_forward, t_backward = step(
i, j, g, lg, deg_g, deg_lg, pm_pd
)
total_loss += loss
total_overlap += overlap
s_forward += t_forward
s_backward += t_backward
epoch = '0' * (len(str(args.n_epochs)) - len(str(i)))
iteration = '0' * (len(str(n_iterations)) - len(str(j)))
epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
iteration = "0" * (len(str(n_iterations)) - len(str(j)))
if args.verbose:
print('[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f'
% (epoch, i, iteration, j, loss, overlap))
print(
"[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f"
% (epoch, i, iteration, j, loss, overlap)
)
epoch = '0' * (len(str(args.n_epochs)) - len(str(i)))
epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
loss = total_loss / (j + 1)
overlap = total_overlap / (j + 1)
t_forward = s_forward / (j + 1)
t_backward = s_backward / (j + 1)
print('[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs'
% (epoch, i, loss, overlap, t_forward, t_backward))
print(
"[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs"
% (epoch, i, loss, overlap, t_forward, t_backward)
)
overlap_list = test()
overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list])
print('[epoch %s%d]overlap: %s' % (epoch, i, overlap_str))
overlap_str = " - ".join(["%.3f" % overlap for overlap in overlap_list])
print("[epoch %s%d]overlap: %s" % (epoch, i, overlap_str))
......@@ -2,16 +2,18 @@
import argparse
import copy
import random
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import random
import torch.optim as optim
from tqdm import trange
import dgl
import dgl.function as fn
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from tqdm import trange
class MixHopConv(nn.Module):
r"""
......@@ -44,13 +46,16 @@ class MixHopConv(nn.Module):
batchnorm: bool, optional
If True, use batch normalization. Defaults: ``False``.
"""
def __init__(self,
def __init__(
self,
in_dim,
out_dim,
p=[0, 1, 2],
dropout=0,
activation=None,
batchnorm=False):
batchnorm=False,
):
super(MixHopConv, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
......@@ -66,9 +71,9 @@ class MixHopConv(nn.Module):
self.bn = nn.BatchNorm1d(out_dim * len(p))
# define weight dict for each power j
self.weights = nn.ModuleDict({
str(j): nn.Linear(in_dim, out_dim, bias=False) for j in p
})
self.weights = nn.ModuleDict(
{str(j): nn.Linear(in_dim, out_dim, bias=False) for j in p}
)
def forward(self, graph, feats):
with graph.local_scope():
......@@ -84,9 +89,9 @@ class MixHopConv(nn.Module):
outputs.append(output)
feats = feats * norm
graph.ndata['h'] = feats
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
feats = graph.ndata.pop('h')
graph.ndata["h"] = feats
graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
feats = graph.ndata.pop("h")
feats = feats * norm
final = torch.cat(outputs, dim=1)
......@@ -101,8 +106,10 @@ class MixHopConv(nn.Module):
return final
class MixHop(nn.Module):
def __init__(self,
def __init__(
self,
in_dim,
hid_dim,
out_dim,
......@@ -111,7 +118,8 @@ class MixHop(nn.Module):
input_dropout=0.0,
layer_dropout=0.0,
activation=None,
batchnorm=False):
batchnorm=False,
):
super(MixHop, self).__init__()
self.in_dim = in_dim
self.hid_dim = hid_dim
......@@ -127,23 +135,33 @@ class MixHop(nn.Module):
self.dropout = nn.Dropout(self.input_dropout)
# Input layer
self.layers.append(MixHopConv(self.in_dim,
self.layers.append(
MixHopConv(
self.in_dim,
self.hid_dim,
p=self.p,
dropout=self.input_dropout,
activation=self.activation,
batchnorm=self.batchnorm))
batchnorm=self.batchnorm,
)
)
# Hidden layers with n - 1 MixHopConv layers
for i in range(self.num_layers - 2):
self.layers.append(MixHopConv(self.hid_dim * len(args.p),
self.layers.append(
MixHopConv(
self.hid_dim * len(args.p),
self.hid_dim,
p=self.p,
dropout=self.layer_dropout,
activation=self.activation,
batchnorm=self.batchnorm))
batchnorm=self.batchnorm,
)
)
self.fc_layers = nn.Linear(self.hid_dim * len(args.p), self.out_dim, bias=False)
self.fc_layers = nn.Linear(
self.hid_dim * len(args.p), self.out_dim, bias=False
)
def forward(self, graph, feats):
feats = self.dropout(feats)
......@@ -154,41 +172,42 @@ class MixHop(nn.Module):
return feats
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load from DGL dataset
if args.dataset == 'Cora':
if args.dataset == "Cora":
dataset = CoraGraphDataset()
elif args.dataset == 'Citeseer':
elif args.dataset == "Citeseer":
dataset = CiteseerGraphDataset()
elif args.dataset == 'Pubmed':
elif args.dataset == "Pubmed":
dataset = PubmedGraphDataset()
else:
raise ValueError('Dataset {} is invalid.'.format(args.dataset))
raise ValueError("Dataset {} is invalid.".format(args.dataset))
graph = dataset[0]
graph = dgl.add_self_loop(graph)
# check cuda
if args.gpu >= 0 and torch.cuda.is_available():
device = 'cuda:{}'.format(args.gpu)
device = "cuda:{}".format(args.gpu)
else:
device = 'cpu'
device = "cpu"
# retrieve the number of classes
n_classes = dataset.num_classes
# retrieve labels of ground truth
labels = graph.ndata.pop('label').to(device).long()
labels = graph.ndata.pop("label").to(device).long()
# Extract node features
feats = graph.ndata.pop('feat').to(device)
feats = graph.ndata.pop("feat").to(device)
n_features = feats.shape[-1]
# retrieve masks for train/validation/test
train_mask = graph.ndata.pop('train_mask')
val_mask = graph.ndata.pop('val_mask')
test_mask = graph.ndata.pop('test_mask')
train_mask = graph.ndata.pop("train_mask")
val_mask = graph.ndata.pop("val_mask")
test_mask = graph.ndata.pop("test_mask")
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
......@@ -197,7 +216,8 @@ def main(args):
graph = graph.to(device)
# Step 2: Create model =================================================================== #
model = MixHop(in_dim=n_features,
model = MixHop(
in_dim=n_features,
hid_dim=args.hid_dim,
out_dim=n_classes,
num_layers=args.num_layers,
......@@ -205,7 +225,8 @@ def main(args):
input_dropout=args.input_dropout,
layer_dropout=args.layer_dropout,
activation=torch.tanh,
batchnorm=True)
batchnorm=True,
)
model = model.to(device)
best_model = copy.deepcopy(model)
......@@ -218,7 +239,7 @@ def main(args):
# Step 4: training epoches =============================================================== #
acc = 0
no_improvement = 0
epochs = trange(args.epochs, desc='Accuracy & Loss')
epochs = trange(args.epochs, desc="Accuracy & Loss")
for _ in epochs:
# Training using a full graph
......@@ -228,7 +249,9 @@ def main(args):
# compute loss
train_loss = loss_fn(logits[train_idx], labels[train_idx])
train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx]).item() / len(train_idx)
train_acc = torch.sum(
logits[train_idx].argmax(dim=1) == labels[train_idx]
).item() / len(train_idx)
# backward
opt.zero_grad()
......@@ -240,16 +263,21 @@ def main(args):
with torch.no_grad():
valid_loss = loss_fn(logits[val_idx], labels[val_idx])
valid_acc = torch.sum(logits[val_idx].argmax(dim=1) == labels[val_idx]).item() / len(val_idx)
valid_acc = torch.sum(
logits[val_idx].argmax(dim=1) == labels[val_idx]
).item() / len(val_idx)
# Print out performance
epochs.set_description('Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}'.format(
train_acc, train_loss.item(), valid_acc, valid_loss.item()))
epochs.set_description(
"Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
train_acc, train_loss.item(), valid_acc, valid_loss.item()
)
)
if valid_acc < acc:
no_improvement += 1
if no_improvement == args.early_stopping:
print('Early stop.')
print("Early stop.")
break
else:
no_improvement = 0
......@@ -260,34 +288,74 @@ def main(args):
best_model.eval()
logits = best_model(graph, feats)
test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
test_acc = torch.sum(
logits[test_idx].argmax(dim=1) == labels[test_idx]
).item() / len(test_idx)
print("Test Acc {:.4f}".format(test_acc))
return test_acc
if __name__ == "__main__":
"""
MixHop Model Hyperparameters
"""
parser = argparse.ArgumentParser(description='MixHop GCN')
parser = argparse.ArgumentParser(description="MixHop GCN")
# data source params
parser.add_argument('--dataset', type=str, default='Cora', help='Name of dataset.')
parser.add_argument(
"--dataset", type=str, default="Cora", help="Name of dataset."
)
# cuda params
parser.add_argument('--gpu', type=int, default=-1, help='GPU index. Default: -1, using CPU.')
parser.add_argument(
"--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
)
# training params
parser.add_argument('--epochs', type=int, default=2000, help='Training epochs.')
parser.add_argument('--early-stopping', type=int, default=200, help='Patient epochs to wait before early stopping.')
parser.add_argument('--lr', type=float, default=0.5, help='Learning rate.')
parser.add_argument('--lamb', type=float, default=5e-4, help='L2 reg.')
parser.add_argument('--step-size', type=int, default=40, help='Period of learning rate decay.')
parser.add_argument('--gamma', type=float, default=0.01, help='Multiplicative factor of learning rate decay.')
parser.add_argument(
"--epochs", type=int, default=2000, help="Training epochs."
)
parser.add_argument(
"--early-stopping",
type=int,
default=200,
help="Patient epochs to wait before early stopping.",
)
parser.add_argument("--lr", type=float, default=0.5, help="Learning rate.")
parser.add_argument("--lamb", type=float, default=5e-4, help="L2 reg.")
parser.add_argument(
"--step-size",
type=int,
default=40,
help="Period of learning rate decay.",
)
parser.add_argument(
"--gamma",
type=float,
default=0.01,
help="Multiplicative factor of learning rate decay.",
)
# model params
parser.add_argument("--hid-dim", type=int, default=60, help='Hidden layer dimensionalities.')
parser.add_argument("--num-layers", type=int, default=4, help='Number of GNN layers.')
parser.add_argument("--input-dropout", type=float, default=0.7, help='Dropout applied at input layer.')
parser.add_argument("--layer-dropout", type=float, default=0.9, help='Dropout applied at hidden layers.')
parser.add_argument('--p', nargs='+', type=int, help='List of powers of adjacency matrix.')
parser.add_argument(
"--hid-dim", type=int, default=60, help="Hidden layer dimensionalities."
)
parser.add_argument(
"--num-layers", type=int, default=4, help="Number of GNN layers."
)
parser.add_argument(
"--input-dropout",
type=float,
default=0.7,
help="Dropout applied at input layer.",
)
parser.add_argument(
"--layer-dropout",
type=float,
default=0.9,
help="Dropout applied at hidden layers.",
)
parser.add_argument(
"--p", nargs="+", type=int, help="List of powers of adjacency matrix."
)
parser.set_defaults(p=[0, 1, 2])
......@@ -304,7 +372,7 @@ if __name__ == "__main__":
mean = np.around(np.mean(acc_lists_top, axis=0), decimals=3)
std = np.around(np.std(acc_lists_top, axis=0), decimals=3)
print('Total acc: ', acc_lists)
print('Top 50 acc:', acc_lists_top)
print('mean', mean)
print('std', std)
print("Total acc: ", acc_lists)
print("Top 50 acc:", acc_lists_top)
print("mean", mean)
print("std", std)