Unverified Commit 0b9df9d7 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
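The diff below is the output of running Black over these example scripts: string quotes are normalized to double quotes and over-long calls and literals are wrapped and re-indented. A minimal sketch of applying Black programmatically is shown here; line_length=79 is an assumption about the repository's Black configuration, not necessarily the exact invocation used for this commit.

# Hypothetical sketch: reformat a source string with Black's Python API.
# line_length=79 is an assumed project setting.
import black

src = "x = {  'a':1,'b' : 2 }\n"
formatted = black.format_str(src, mode=black.Mode(line_length=79))
print(formatted)  # -> x = {"a": 1, "b": 2}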
import dgl
import rdkit.Chem as Chem
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, \
set_atommap, enum_assemble_nx, decode_stereo
import numpy as np
import rdkit.Chem as Chem
import dgl
from .chemutils import (
decode_stereo,
enum_assemble_nx,
get_clique_mol,
get_mol,
get_smiles,
set_atommap,
tree_decomp,
)
class DGLMolTree(object):
def __init__(self, smiles):
......@@ -28,21 +38,23 @@ class DGLMolTree(object):
cmol = get_clique_mol(self.mol, c)
csmiles = get_smiles(cmol)
self.nodes_dict[i] = dict(
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
smiles=csmiles,
mol=get_mol(csmiles),
clique=c,
)
if min(c) == 0:
root = i
# The clique with atom ID 0 becomes root
if root > 0:
for attr in self.nodes_dict[0]:
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = \
self.nodes_dict[root][attr], self.nodes_dict[0][attr]
self.nodes_dict[0][attr], self.nodes_dict[root][attr] = (
self.nodes_dict[root][attr],
self.nodes_dict[0][attr],
)
src = np.zeros((len(edges) * 2,), dtype='int')
dst = np.zeros((len(edges) * 2,), dtype='int')
src = np.zeros((len(edges) * 2,), dtype="int")
dst = np.zeros((len(edges) * 2,), dtype="int")
for i, (_x, _y) in enumerate(edges):
x = 0 if _x == root else root if _x == 0 else _x
y = 0 if _y == root else root if _y == 0 else _y
......@@ -53,10 +65,12 @@ class DGLMolTree(object):
self.graph = dgl.graph((src, dst), num_nodes=len(cliques))
for i in self.nodes_dict:
self.nodes_dict[i]['nid'] = i + 1
if self.graph.out_degrees(i) > 1: # Leaf node mol is not marked
set_atommap(self.nodes_dict[i]['mol'], self.nodes_dict[i]['nid'])
self.nodes_dict[i]['is_leaf'] = (self.graph.out_degrees(i) == 1)
self.nodes_dict[i]["nid"] = i + 1
if self.graph.out_degrees(i) > 1: # Leaf node mol is not marked
set_atommap(
self.nodes_dict[i]["mol"], self.nodes_dict[i]["nid"]
)
self.nodes_dict[i]["is_leaf"] = self.graph.out_degrees(i) == 1
def treesize(self):
return self.graph.number_of_nodes()
......@@ -65,49 +79,65 @@ class DGLMolTree(object):
node = self.nodes_dict[i]
clique = []
clique.extend(node['clique'])
if not node['is_leaf']:
for cidx in node['clique']:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node['nid'])
clique.extend(node["clique"])
if not node["is_leaf"]:
for cidx in node["clique"]:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node["nid"])
for j in self.graph.successors(i).numpy():
nei_node = self.nodes_dict[j]
clique.extend(nei_node['clique'])
if nei_node['is_leaf']: # Leaf node, no need to mark
clique.extend(nei_node["clique"])
if nei_node["is_leaf"]: # Leaf node, no need to mark
continue
for cidx in nei_node['clique']:
for cidx in nei_node["clique"]:
# allow singleton node override the atom mapping
if cidx not in node['clique'] or len(nei_node['clique']) == 1:
if cidx not in node["clique"] or len(nei_node["clique"]) == 1:
atom = original_mol.GetAtomWithIdx(cidx)
atom.SetAtomMapNum(nei_node['nid'])
atom.SetAtomMapNum(nei_node["nid"])
clique = list(set(clique))
label_mol = get_clique_mol(original_mol, clique)
node['label'] = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
node['label_mol'] = get_mol(node['label'])
node["label"] = Chem.MolToSmiles(
Chem.MolFromSmiles(get_smiles(label_mol))
)
node["label_mol"] = get_mol(node["label"])
for cidx in clique:
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
return node['label']
return node["label"]
def _assemble_node(self, i):
neighbors = [self.nodes_dict[j] for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [self.nodes_dict[j] for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]['mol'].GetNumAtoms() == 1]
neighbors = [
self.nodes_dict[j]
for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]["mol"].GetNumAtoms() > 1
]
neighbors = sorted(
neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
)
singletons = [
self.nodes_dict[j]
for j in self.graph.successors(i).numpy()
if self.nodes_dict[j]["mol"].GetNumAtoms() == 1
]
neighbors = singletons + neighbors
cands = enum_assemble_nx(self.nodes_dict[i], neighbors)
if len(cands) > 0:
self.nodes_dict[i]['cands'], self.nodes_dict[i]['cand_mols'], _ = list(zip(*cands))
self.nodes_dict[i]['cands'] = list(self.nodes_dict[i]['cands'])
self.nodes_dict[i]['cand_mols'] = list(self.nodes_dict[i]['cand_mols'])
(
self.nodes_dict[i]["cands"],
self.nodes_dict[i]["cand_mols"],
_,
) = list(zip(*cands))
self.nodes_dict[i]["cands"] = list(self.nodes_dict[i]["cands"])
self.nodes_dict[i]["cand_mols"] = list(
self.nodes_dict[i]["cand_mols"]
)
else:
self.nodes_dict[i]['cands'] = []
self.nodes_dict[i]['cand_mols'] = []
self.nodes_dict[i]["cands"] = []
self.nodes_dict[i]["cand_mols"] = []
def recover(self):
for i in self.nodes_dict:
......
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import rdkit.Chem as Chem
import torch.nn.functional as F
from .chemutils import get_mol
import dgl
from dgl import mean_nodes, line_graph
import dgl.function as DGLF
from dgl import line_graph, mean_nodes
from .chemutils import get_mol
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ELEM_LIST = [
"C",
"N",
"O",
"S",
"F",
"Si",
"P",
"Cl",
"Br",
"Mg",
"Na",
"Ca",
"Fe",
"Al",
"I",
"B",
"K",
"Se",
"Zn",
"H",
"Cu",
"Mn",
"unknown",
]
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
BOND_FDIM = 5 + 6
MAX_NB = 6
def onek_encoding_unk(x, allowable_set):
if x not in allowable_set:
x = allowable_set[-1]
return [x == s for s in allowable_set]
def atom_features(atom):
return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0,1,2,3])
+ [atom.GetIsAromatic()]))
return torch.Tensor(
onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+ onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
+ onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+ onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
+ [atom.GetIsAromatic()]
)
def bond_features(bond):
bt = bond.GetBondType()
stereo = int(bond.GetStereo())
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
fstereo = onek_encoding_unk(stereo, [0,1,2,3,4,5])
return (torch.Tensor(fbond + fstereo))
fbond = [
bt == Chem.rdchem.BondType.SINGLE,
bt == Chem.rdchem.BondType.DOUBLE,
bt == Chem.rdchem.BondType.TRIPLE,
bt == Chem.rdchem.BondType.AROMATIC,
bond.IsInRing(),
]
fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
return torch.Tensor(fbond + fstereo)
def mol2dgl_single(smiles):
n_edges = 0
......@@ -61,8 +98,11 @@ def mol2dgl_single(smiles):
bond_x.append(features)
graph = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)
n_edges += n_bonds
return graph, torch.stack(atom_x), \
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0)
return (
graph,
torch.stack(atom_x),
torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0),
)
class LoopyBPUpdate(nn.Module):
......@@ -73,10 +113,10 @@ class LoopyBPUpdate(nn.Module):
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(self, nodes):
msg_input = nodes.data['msg_input']
msg_delta = self.W_h(nodes.data['accum_msg'])
msg_input = nodes.data["msg_input"]
msg_delta = self.W_h(nodes.data["accum_msg"])
msg = F.relu(msg_input + msg_delta)
return {'msg': msg}
return {"msg": msg}
class GatherUpdate(nn.Module):
......@@ -87,9 +127,9 @@ class GatherUpdate(nn.Module):
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
def forward(self, nodes):
m = nodes.data['m']
m = nodes.data["m"]
return {
'h': F.relu(self.W_o(torch.cat([nodes.data['x'], m], 1))),
"h": F.relu(self.W_o(torch.cat([nodes.data["x"], m], 1))),
}
......@@ -121,7 +161,7 @@ class DGLMPN(nn.Module):
mol_graph = self.run(mol_graph, mol_line_graph)
# TODO: replace with unbatch or readout
g_repr = mean_nodes(mol_graph, 'h')
g_repr = mean_nodes(mol_graph, "h")
self.n_samples_total += n_samples
self.n_nodes_total += n_nodes
......@@ -134,32 +174,38 @@ class DGLMPN(nn.Module):
n_nodes = mol_graph.number_of_nodes()
mol_graph.apply_edges(
func=lambda edges: {'src_x': edges.src['x']},
func=lambda edges: {"src_x": edges.src["x"]},
)
mol_line_graph.ndata.update(mol_graph.edata)
e_repr = mol_line_graph.ndata
bond_features = e_repr['x']
source_features = e_repr['src_x']
bond_features = e_repr["x"]
source_features = e_repr["src_x"]
features = torch.cat([source_features, bond_features], 1)
msg_input = self.W_i(features)
mol_line_graph.ndata.update({
'msg_input': msg_input,
'msg': F.relu(msg_input),
'accum_msg': torch.zeros_like(msg_input),
})
mol_graph.ndata.update({
'm': bond_features.new(n_nodes, self.hidden_size).zero_(),
'h': bond_features.new(n_nodes, self.hidden_size).zero_(),
})
mol_line_graph.ndata.update(
{
"msg_input": msg_input,
"msg": F.relu(msg_input),
"accum_msg": torch.zeros_like(msg_input),
}
)
mol_graph.ndata.update(
{
"m": bond_features.new(n_nodes, self.hidden_size).zero_(),
"h": bond_features.new(n_nodes, self.hidden_size).zero_(),
}
)
for i in range(self.depth - 1):
mol_line_graph.update_all(DGLF.copy_u('msg', 'msg'), DGLF.sum('msg', 'accum_msg'))
mol_line_graph.update_all(
DGLF.copy_u("msg", "msg"), DGLF.sum("msg", "accum_msg")
)
mol_line_graph.apply_nodes(self.loopy_bp_updater)
mol_graph.edata.update(mol_line_graph.ndata)
mol_graph.update_all(DGLF.copy_e('msg', 'msg'), DGLF.sum('msg', 'm'))
mol_graph.update_all(DGLF.copy_e("msg", "msg"), DGLF.sum("msg", "m"))
mol_graph.apply_nodes(self.gather_updater)
return mol_graph
import os
import torch
import torch.nn as nn
import os
import dgl
def cuda(x):
if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
return x.to(torch.device('cuda')) # works for both DGLGraph and tensor
if torch.cuda.is_available() and not os.getenv("NOCUDA", None):
return x.to(torch.device("cuda")) # works for both DGLGraph and tensor
else:
return x
......@@ -22,27 +24,28 @@ class GRUUpdate(nn.Module):
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
def update_zm(self, node):
src_x = node.data['src_x']
s = node.data['s']
rm = node.data['accum_rm']
src_x = node.data["src_x"]
s = node.data["s"]
rm = node.data["accum_rm"]
z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))
m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))
m = (1 - z) * s + z * m
return {'m': m, 'z': z}
return {"m": m, "z": z}
def update_r(self, node, zm=None):
dst_x = node.data['dst_x']
m = node.data['m'] if zm is None else zm['m']
dst_x = node.data["dst_x"]
m = node.data["m"] if zm is None else zm["m"]
r_1 = self.W_r(dst_x)
r_2 = self.U_r(m)
r = torch.sigmoid(r_1 + r_2)
return {'r': r, 'rm': r * m}
return {"r": r, "rm": r * m}
def forward(self, node):
dic = self.update_zm(node)
dic.update(self.update_r(node, zm=dic))
return dic
def tocpu(g):
src, dst = g.edges()
src = src.cpu()
......
import math
import random
import sys
from collections import deque
from optparse import OptionParser
import rdkit
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
import math, random, sys
from optparse import OptionParser
from collections import deque
import rdkit
import tqdm
from jtnn import *
from torch.utils.data import DataLoader
torch.multiprocessing.set_sharing_strategy("file_system")
torch.multiprocessing.set_sharing_strategy('file_system')
def worker_init_fn(id_):
lg = rdkit.RDLogger.logger()
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
worker_init_fn(None)
parser = OptionParser()
parser.add_option("-t", "--train", dest="train", default='train', help='Training file name')
parser.add_option("-v", "--vocab", dest="vocab", default='vocab', help='Vocab file name')
parser.add_option(
"-t", "--train", dest="train", default="train", help="Training file name"
)
parser.add_option(
"-v", "--vocab", dest="vocab", default="vocab", help="Vocab file name"
)
parser.add_option("-s", "--save_dir", dest="save_path")
parser.add_option("-m", "--model", dest="model_path", default=None)
parser.add_option("-b", "--batch", dest="batch_size", default=40)
......@@ -31,7 +39,7 @@ parser.add_option("-d", "--depth", dest="depth", default=3)
parser.add_option("-z", "--beta", dest="beta", default=1.0)
parser.add_option("-q", "--lr", dest="lr", default=1e-3)
parser.add_option("-T", "--test", dest="test", action="store_true")
opts,args = parser.parse_args()
opts, args = parser.parse_args()
dataset = JTNNDataset(data=opts.train, vocab=opts.vocab, training=True)
vocab = dataset.vocab
......@@ -55,7 +63,10 @@ else:
nn.init.xavier_normal(param)
model = cuda(model)
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))
print(
"Model #Params: %dK"
% (sum([x.nelement() for x in model.parameters()]) / 1000,)
)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
......@@ -64,26 +75,28 @@ scheduler.step()
MAX_EPOCH = 100
PRINT_ITER = 20
def train():
dataset.training = True
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
collate_fn=JTNNCollator(vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn)
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
collate_fn=JTNNCollator(vocab, True),
drop_last=True,
worker_init_fn=worker_init_fn,
)
for epoch in range(MAX_EPOCH):
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
for it, batch in enumerate(tqdm.tqdm(dataloader)):
model.zero_grad()
try:
loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
except:
print([t.smiles for t in batch['mol_trees']])
print([t.smiles for t in batch["mol_trees"]])
raise
loss.backward()
optimizer.step()
......@@ -99,36 +112,51 @@ def train():
assm_acc = assm_acc / PRINT_ITER * 100
steo_acc = steo_acc / PRINT_ITER * 100
print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
print(
"KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f"
% (
kl_div,
word_acc,
topo_acc,
assm_acc,
steo_acc,
loss.item(),
)
)
word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
sys.stdout.flush()
if (it + 1) % 1500 == 0: #Fast annealing
if (it + 1) % 1500 == 0: # Fast annealing
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(),
opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1))
torch.save(
model.state_dict(),
opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1),
)
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
torch.save(model.state_dict(), opts.save_path + "/model.iter-" + str(epoch))
torch.save(
model.state_dict(), opts.save_path + "/model.iter-" + str(epoch)
)
def test():
dataset.training = False
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=JTNNCollator(vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn)
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=JTNNCollator(vocab, False),
drop_last=True,
worker_init_fn=worker_init_fn,
)
# Just an example of molecule decoding; in reality you may want to sample
# tree and molecule vectors.
for it, batch in enumerate(dataloader):
gt_smiles = batch['mol_trees'][0].smiles
gt_smiles = batch["mol_trees"][0].smiles
print(gt_smiles)
model.move_to_cuda(batch)
_, tree_vec, mol_vec = model.encode(batch)
......@@ -136,21 +164,28 @@ def test():
smiles = model.decode(tree_vec, mol_vec)
print(smiles)
if __name__ == '__main__':
if __name__ == "__main__":
if opts.test:
test()
else:
train()
print('# passes:', model.n_passes)
print('Total # nodes processed:', model.n_nodes_total)
print('Total # edges processed:', model.n_edges_total)
print('Total # tree nodes processed:', model.n_tree_nodes_total)
print('Graph decoder: # passes:', model.jtmpn.n_passes)
print('Graph decoder: Total # candidates processed:', model.jtmpn.n_samples_total)
print('Graph decoder: Total # nodes processed:', model.jtmpn.n_nodes_total)
print('Graph decoder: Total # edges processed:', model.jtmpn.n_edges_total)
print('Graph encoder: # passes:', model.mpn.n_passes)
print('Graph encoder: Total # candidates processed:', model.mpn.n_samples_total)
print('Graph encoder: Total # nodes processed:', model.mpn.n_nodes_total)
print('Graph encoder: Total # edges processed:', model.mpn.n_edges_total)
print("# passes:", model.n_passes)
print("Total # nodes processed:", model.n_nodes_total)
print("Total # edges processed:", model.n_edges_total)
print("Total # tree nodes processed:", model.n_tree_nodes_total)
print("Graph decoder: # passes:", model.jtmpn.n_passes)
print(
"Graph decoder: Total # candidates processed:",
model.jtmpn.n_samples_total,
)
print("Graph decoder: Total # nodes processed:", model.jtmpn.n_nodes_total)
print("Graph decoder: Total # edges processed:", model.jtmpn.n_edges_total)
print("Graph encoder: # passes:", model.mpn.n_passes)
print(
"Graph encoder: Total # candidates processed:",
model.mpn.n_samples_total,
)
print("Graph encoder: Total # nodes processed:", model.mpn.n_nodes_total)
print("Graph encoder: Total # edges processed:", model.mpn.n_edges_total)
import argparse
import torch
import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.nn import LabelPropagation
def main():
# check cuda
device = f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu >= 0 else 'cpu'
device = (
f"cuda:{args.gpu}"
if torch.cuda.is_available() and args.gpu >= 0
else "cpu"
)
# load data
if args.dataset == 'Cora':
if args.dataset == "Cora":
dataset = CoraGraphDataset()
elif args.dataset == 'Citeseer':
elif args.dataset == "Citeseer":
dataset = CiteseerGraphDataset()
elif args.dataset == 'Pubmed':
elif args.dataset == "Pubmed":
dataset = PubmedGraphDataset()
else:
raise ValueError('Dataset {} is invalid.'.format(args.dataset))
raise ValueError("Dataset {} is invalid.".format(args.dataset))
g = dataset[0]
g = dgl.add_self_loop(g)
labels = g.ndata.pop('label').to(device).long()
labels = g.ndata.pop("label").to(device).long()
# load masks for train / test, valid is not used.
train_mask = g.ndata.pop('train_mask')
test_mask = g.ndata.pop('test_mask')
train_mask = g.ndata.pop("train_mask")
test_mask = g.ndata.pop("test_mask")
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)
g = g.to(device)
# label propagation
lp = LabelPropagation(args.num_layers, args.alpha)
logits = lp(g, labels, mask=train_idx)
test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
test_acc = torch.sum(
logits[test_idx].argmax(dim=1) == labels[test_idx]
).item() / len(test_idx)
print("Test Acc {:.4f}".format(test_acc))
if __name__ == '__main__':
if __name__ == "__main__":
"""
Label Propagation Hyperparameters
"""
parser = argparse.ArgumentParser(description='LP')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--dataset', type=str, default='Cora')
parser.add_argument('--num-layers', type=int, default=10)
parser.add_argument('--alpha', type=float, default=0.5)
parser = argparse.ArgumentParser(description="LP")
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--dataset", type=str, default="Cora")
parser.add_argument("--num-layers", type=int, default=10)
parser.add_argument("--alpha", type=float, default=0.5)
args = parser.parse_args()
print(args)
......
......@@ -17,49 +17,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from time import time
import matplotlib.pyplot as plt
import warnings
from time import time
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as ss
import torch
import dgl
from dgl import function as fn
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation
from lda_model import LatentDirichletAllocation as LDAModel
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from lda_model import LatentDirichletAllocation as LDAModel
import dgl
from dgl import function as fn
n_samples = 2000
n_features = 1000
n_components = 10
n_top_words = 20
device = 'cuda'
device = "cuda"
def plot_top_words(model, feature_names, n_top_words, title):
fig, axes = plt.subplots(2, 5, figsize=(30, 15), sharex=True)
axes = axes.flatten()
for topic_idx, topic in enumerate(model.components_):
top_features_ind = topic.argsort()[:-n_top_words - 1:-1]
top_features_ind = topic.argsort()[: -n_top_words - 1 : -1]
top_features = [feature_names[i] for i in top_features_ind]
weights = topic[top_features_ind]
ax = axes[topic_idx]
ax.barh(top_features, weights, height=0.7)
ax.set_title(f'Topic {topic_idx +1}',
fontdict={'fontsize': 30})
ax.set_title(f"Topic {topic_idx +1}", fontdict={"fontsize": 30})
ax.invert_yaxis()
ax.tick_params(axis='both', which='major', labelsize=20)
for i in 'top right left'.split():
ax.tick_params(axis="both", which="major", labelsize=20)
for i in "top right left".split():
ax.spines[i].set_visible(False)
fig.suptitle(title, fontsize=40)
plt.subplots_adjust(top=0.90, bottom=0.05, wspace=0.90, hspace=0.3)
plt.show()
# Load the 20 newsgroups dataset and vectorize it. We use a few heuristics
# to filter out useless terms early on: the posts are stripped of headers,
# footers and quoted replies, and common English words, words occurring in
......@@ -67,43 +67,50 @@ def plot_top_words(model, feature_names, n_top_words, title):
print("Loading dataset...")
t0 = time()
data, _ = fetch_20newsgroups(shuffle=True, random_state=1,
remove=('headers', 'footers', 'quotes'),
return_X_y=True)
data, _ = fetch_20newsgroups(
shuffle=True,
random_state=1,
remove=("headers", "footers", "quotes"),
return_X_y=True,
)
data_samples = data[:n_samples]
data_test = data[n_samples:2*n_samples]
data_test = data[n_samples : 2 * n_samples]
print("done in %0.3fs." % (time() - t0))
# Use tf (raw term count) features for LDA.
print("Extracting tf features for LDA...")
tf_vectorizer = CountVectorizer(max_df=0.95, min_df=2,
max_features=n_features,
stop_words='english')
tf_vectorizer = CountVectorizer(
max_df=0.95, min_df=2, max_features=n_features, stop_words="english"
)
t0 = time()
tf_vectorizer.fit(data)
tf = tf_vectorizer.transform(data_samples)
tt = tf_vectorizer.transform(data_test)
tf_feature_names = tf_vectorizer.get_feature_names()
tf_uv = [(u,v)
for u,v,e in zip(tf.tocoo().row, tf.tocoo().col, tf.tocoo().data)
for _ in range(e)]
tt_uv = [(u,v)
for u,v,e in zip(tt.tocoo().row, tt.tocoo().col, tt.tocoo().data)
for _ in range(e)]
tf_uv = [
(u, v)
for u, v, e in zip(tf.tocoo().row, tf.tocoo().col, tf.tocoo().data)
for _ in range(e)
]
tt_uv = [
(u, v)
for u, v, e in zip(tt.tocoo().row, tt.tocoo().col, tt.tocoo().data)
for _ in range(e)
]
print("done in %0.3fs." % (time() - t0))
print()
print("Preparing dgl graphs...")
t0 = time()
G = dgl.heterograph({('doc','topic','word'): tf_uv}, device=device)
Gt = dgl.heterograph({('doc','topic','word'): tt_uv}, device=device)
G = dgl.heterograph({("doc", "topic", "word"): tf_uv}, device=device)
Gt = dgl.heterograph({("doc", "topic", "word"): tt_uv}, device=device)
print("done in %0.3fs." % (time() - t0))
print()
print("Training dgl-lda model...")
t0 = time()
model = LDAModel(G.num_nodes('word'), n_components)
model = LDAModel(G.num_nodes("word"), n_components)
model.fit(G)
print("done in %0.3fs." % (time() - t0))
print()
......@@ -113,20 +120,27 @@ print(f"dgl-lda testing perplexity {model.perplexity(Gt):.3f}")
word_nphi = np.vstack([nphi.tolist() for nphi in model.word_data.nphi])
plot_top_words(
type('dummy', (object,), {'components_': word_nphi}),
tf_feature_names, n_top_words, 'Topics in LDA model')
type("dummy", (object,), {"components_": word_nphi}),
tf_feature_names,
n_top_words,
"Topics in LDA model",
)
print("Training scikit-learn model...")
print('\n' * 2, "Fitting LDA models with tf features, "
"n_samples=%d and n_features=%d..."
% (n_samples, n_features))
lda = LatentDirichletAllocation(n_components=n_components, max_iter=5,
learning_method='online',
learning_offset=50.,
random_state=0,
verbose=1,
)
print(
"\n" * 2,
"Fitting LDA models with tf features, "
"n_samples=%d and n_features=%d..." % (n_samples, n_features),
)
lda = LatentDirichletAllocation(
n_components=n_components,
max_iter=5,
learning_method="online",
learning_offset=50.0,
random_state=0,
verbose=1,
)
t0 = time()
lda.fit(tf)
print("done in %0.3fs." % (time() - t0))
......
......@@ -17,8 +17,17 @@
# limitations under the License.
import os, functools, warnings, torch, collections, dgl, io
import numpy as np, scipy as sp
import collections
import functools
import io
import os
import warnings
import numpy as np
import scipy as sp
import torch
import dgl
try:
from functools import cached_property
......@@ -37,17 +46,21 @@ class EdgeData:
@property
def loglike(self):
return (self.src_data['Elog'] + self.dst_data['Elog']).logsumexp(1)
return (self.src_data["Elog"] + self.dst_data["Elog"]).logsumexp(1)
@property
def phi(self):
return (
self.src_data['Elog'] + self.dst_data['Elog'] - self.loglike.unsqueeze(1)
self.src_data["Elog"]
+ self.dst_data["Elog"]
- self.loglike.unsqueeze(1)
).exp()
@property
def expectation(self):
return (self.src_data['expectation'] * self.dst_data['expectation']).sum(1)
return (
self.src_data["expectation"] * self.dst_data["expectation"]
).sum(1)
class _Dirichlet:
......@@ -55,10 +68,13 @@ class _Dirichlet:
self.prior = prior
self.nphi = nphi
self.device = nphi.device
self._sum_by_parts = lambda map_fn: functools.reduce(torch.add, [
map_fn(slice(i, min(i+_chunksize, nphi.shape[1]))).sum(1)
for i in list(range(0, nphi.shape[1], _chunksize))
])
self._sum_by_parts = lambda map_fn: functools.reduce(
torch.add,
[
map_fn(slice(i, min(i + _chunksize, nphi.shape[1]))).sum(1)
for i in list(range(0, nphi.shape[1], _chunksize))
],
)
def _posterior(self, _ID=slice(None)):
return self.prior + self.nphi[:, _ID]
......@@ -68,14 +84,15 @@ class _Dirichlet:
return self.nphi.sum(1) + self.prior * self.nphi.shape[1]
def _Elog(self, _ID=slice(None)):
return torch.digamma(self._posterior(_ID)) - \
torch.digamma(self.posterior_sum.unsqueeze(1))
return torch.digamma(self._posterior(_ID)) - torch.digamma(
self.posterior_sum.unsqueeze(1)
)
@cached_property
def loglike(self):
neg_evid = -self._sum_by_parts(
lambda s: (self.nphi[:, s] * self._Elog(s))
)
)
prior = torch.as_tensor(self.prior).to(self.nphi)
K = self.nphi.shape[1]
......@@ -83,7 +100,7 @@ class _Dirichlet:
log_B_posterior = self._sum_by_parts(
lambda s: torch.lgamma(self._posterior(s))
) - torch.lgamma(self.posterior_sum)
) - torch.lgamma(self.posterior_sum)
return neg_evid - log_B_prior + log_B_posterior
......@@ -105,9 +122,15 @@ class _Dirichlet:
@cached_property
def Bayesian_gap(self):
return 1. - self._sum_by_parts(lambda s: self._Elog(s).exp())
return 1.0 - self._sum_by_parts(lambda s: self._Elog(s).exp())
_cached_properties = ["posterior_sum", "loglike", "n", "cdf", "Bayesian_gap"]
_cached_properties = [
"posterior_sum",
"loglike",
"n",
"cdf",
"Bayesian_gap",
]
def clear_cache(self):
for name in self._cached_properties:
......@@ -117,27 +140,29 @@ class _Dirichlet:
pass
def update(self, new, _ID=slice(None), rho=1):
""" inplace: old * (1-rho) + new * rho """
"""inplace: old * (1-rho) + new * rho"""
self.clear_cache()
mean_change = (self.nphi[:, _ID] - new).abs().mean().tolist()
self.nphi *= (1 - rho)
self.nphi *= 1 - rho
self.nphi[:, _ID] += new * rho
return mean_change
class DocData(_Dirichlet):
""" nphi (n_docs by n_topics) """
"""nphi (n_docs by n_topics)"""
def prepare_graph(self, G, key="Elog"):
G.nodes['doc'].data[key] = getattr(self, '_'+key)().to(G.device)
G.nodes["doc"].data[key] = getattr(self, "_" + key)().to(G.device)
def update_from(self, G, mult):
new = G.nodes['doc'].data['nphi'] * mult
new = G.nodes["doc"].data["nphi"] * mult
return self.update(new.to(self.device))
class _Distributed(collections.UserList):
""" split on dim=0 and store on multiple devices """
"""split on dim=0 and store on multiple devices"""
def __init__(self, prior, nphi):
self.prior = prior
self.nphi = nphi
......@@ -146,36 +171,38 @@ class _Distributed(collections.UserList):
def split_device(self, other, dim=0):
split_sections = [x.shape[0] for x in self.nphi]
out = torch.split(other, split_sections, dim)
return [y.to(x.device) for x,y in zip(self.nphi, out)]
return [y.to(x.device) for x, y in zip(self.nphi, out)]
class WordData(_Distributed):
""" distributed nphi (n_topics by n_words), transpose to/from graph nodes data """
"""distributed nphi (n_topics by n_words), transpose to/from graph nodes data"""
def prepare_graph(self, G, key="Elog"):
if '_ID' in G.nodes['word'].data:
_ID = G.nodes['word'].data['_ID']
if "_ID" in G.nodes["word"].data:
_ID = G.nodes["word"].data["_ID"]
else:
_ID = slice(None)
out = [getattr(part, '_'+key)(_ID).to(G.device) for part in self]
G.nodes['word'].data[key] = torch.cat(out).T
out = [getattr(part, "_" + key)(_ID).to(G.device) for part in self]
G.nodes["word"].data[key] = torch.cat(out).T
def update_from(self, G, mult, rho):
nphi = G.nodes['word'].data['nphi'].T * mult
nphi = G.nodes["word"].data["nphi"].T * mult
if '_ID' in G.nodes['word'].data:
_ID = G.nodes['word'].data['_ID']
if "_ID" in G.nodes["word"].data:
_ID = G.nodes["word"].data["_ID"]
else:
_ID = slice(None)
mean_change = [x.update(y, _ID, rho)
for x, y in zip(self, self.split_device(nphi))]
mean_change = [
x.update(y, _ID, rho) for x, y in zip(self, self.split_device(nphi))
]
return np.mean(mean_change)
class Gamma(collections.namedtuple('Gamma', "concentration, rate")):
""" articulate the difference between torch gamma and numpy gamma """
class Gamma(collections.namedtuple("Gamma", "concentration, rate")):
"""articulate the difference between torch gamma and numpy gamma"""
@property
def shape(self):
return self.concentration
......@@ -218,20 +245,23 @@ class LatentDirichletAllocation:
(NIPS 2010).
[2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
"""
def __init__(
self, n_words, n_components,
self,
n_words,
n_components,
prior=None,
rho=1,
mult={'doc': 1, 'word': 1},
init={'doc': (100., 100.), 'word': (100., 100.)},
device_list=['cpu'],
mult={"doc": 1, "word": 1},
init={"doc": (100.0, 100.0), "word": (100.0, 100.0)},
device_list=["cpu"],
verbose=True,
):
):
self.n_words = n_words
self.n_components = n_components
if prior is None:
prior = {'doc': 1./n_components, 'word': 1./n_components}
prior = {"doc": 1.0 / n_components, "word": 1.0 / n_components}
self.prior = prior
self.rho = rho
......@@ -239,117 +269,128 @@ class LatentDirichletAllocation:
self.init = init
assert not isinstance(device_list, str), "plz wrap devices in a list"
self.device_list = device_list[:n_components] # avoid edge cases
self.device_list = device_list[:n_components] # avoid edge cases
self.verbose = verbose
self._init_word_data()
def _init_word_data(self):
split_sections = np.diff(
np.linspace(0, self.n_components, len(self.device_list)+1).astype(int)
np.linspace(0, self.n_components, len(self.device_list) + 1).astype(
int
)
)
word_nphi = [
Gamma(*self.init['word']).sample((s, self.n_words), device)
Gamma(*self.init["word"]).sample((s, self.n_words), device)
for s, device in zip(split_sections, self.device_list)
]
self.word_data = WordData(self.prior['word'], word_nphi)
self.word_data = WordData(self.prior["word"], word_nphi)
def _init_doc_data(self, n_docs, device):
doc_nphi = Gamma(*self.init['doc']).sample(
(n_docs, self.n_components), device)
return DocData(self.prior['doc'], doc_nphi)
doc_nphi = Gamma(*self.init["doc"]).sample(
(n_docs, self.n_components), device
)
return DocData(self.prior["doc"], doc_nphi)
def save(self, f):
for w in self.word_data:
w.clear_cache()
torch.save({
'prior': self.prior,
'rho': self.rho,
'mult': self.mult,
'init': self.init,
'word_data': [part.nphi for part in self.word_data],
}, f)
torch.save(
{
"prior": self.prior,
"rho": self.rho,
"mult": self.mult,
"init": self.init,
"word_data": [part.nphi for part in self.word_data],
},
f,
)
def _prepare_graph(self, G, doc_data, key="Elog"):
doc_data.prepare_graph(G, key)
self.word_data.prepare_graph(G, key)
def _e_step(self, G, doc_data=None, mean_change_tol=1e-3, max_iters=100):
"""_e_step implements doc data sampling until convergence or max_iters
"""
"""_e_step implements doc data sampling until convergence or max_iters"""
if doc_data is None:
doc_data = self._init_doc_data(G.num_nodes('doc'), G.device)
doc_data = self._init_doc_data(G.num_nodes("doc"), G.device)
G_rev = G.reverse() # word -> doc
G_rev = G.reverse() # word -> doc
self.word_data.prepare_graph(G_rev)
for i in range(max_iters):
doc_data.prepare_graph(G_rev)
G_rev.update_all(
lambda edges: {'phi': EdgeData(edges.src, edges.dst).phi},
dgl.function.sum('phi', 'nphi')
lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
dgl.function.sum("phi", "nphi"),
)
mean_change = doc_data.update_from(G_rev, self.mult['doc'])
mean_change = doc_data.update_from(G_rev, self.mult["doc"])
if mean_change < mean_change_tol:
break
if self.verbose:
print(f"e-step num_iters={i+1} with mean_change={mean_change:.4f}, "
f"perplexity={self.perplexity(G, doc_data):.4f}")
print(
f"e-step num_iters={i+1} with mean_change={mean_change:.4f}, "
f"perplexity={self.perplexity(G, doc_data):.4f}"
)
return doc_data
transform = _e_step
def predict(self, doc_data):
pred_scores = [
# d_exp @ w._expectation()
(lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)
(d_exp / w.posterior_sum.unsqueeze(0))
(lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)(
d_exp / w.posterior_sum.unsqueeze(0)
)
for (d_exp, w) in zip(
self.word_data.split_device(doc_data._expectation(), dim=1),
self.word_data)
self.word_data,
)
]
x = torch.zeros_like(pred_scores[0], device=doc_data.device)
for p in pred_scores:
x += p.to(x.device)
return x
def sample(self, doc_data, num_samples):
""" draw independent words and return the marginal probabilities,
"""draw independent words and return the marginal probabilities,
i.e., the expectations in Dirichlet distributions.
"""
def fn(cdf):
u = torch.rand(cdf.shape[0], num_samples, device=cdf.device)
return torch.searchsorted(cdf, u).to(doc_data.device)
topic_ids = fn(doc_data.cdf)
word_ids = torch.cat([fn(part.cdf) for part in self.word_data])
ids = torch.gather(word_ids, 0, topic_ids) # pick components by topic_ids
ids = torch.gather(
word_ids, 0, topic_ids
) # pick components by topic_ids
# compute expectation scores on sampled ids
src_ids = torch.arange(
ids.shape[0], dtype=ids.dtype, device=ids.device
).reshape((-1, 1)).expand(ids.shape)
unique_ids, inverse_ids = torch.unique(ids, sorted=False, return_inverse=True)
src_ids = (
torch.arange(ids.shape[0], dtype=ids.dtype, device=ids.device)
.reshape((-1, 1))
.expand(ids.shape)
)
unique_ids, inverse_ids = torch.unique(
ids, sorted=False, return_inverse=True
)
G = dgl.heterograph({('doc','','word'): (src_ids.ravel(), inverse_ids.ravel())})
G.nodes['word'].data['_ID'] = unique_ids
G = dgl.heterograph(
{("doc", "", "word"): (src_ids.ravel(), inverse_ids.ravel())}
)
G.nodes["word"].data["_ID"] = unique_ids
self._prepare_graph(G, doc_data, "expectation")
G.apply_edges(lambda e: {'expectation': EdgeData(e.src, e.dst).expectation})
expectation = G.edata.pop('expectation').reshape(ids.shape)
G.apply_edges(
lambda e: {"expectation": EdgeData(e.src, e.dst).expectation}
)
expectation = G.edata.pop("expectation").reshape(ids.shape)
return ids, expectation
def _m_step(self, G, doc_data):
"""_m_step implements word data sampling and stores word_z stats.
mean_change is in the sense of full graph with rho=1.
......@@ -357,26 +398,25 @@ class LatentDirichletAllocation:
G = G.clone()
self._prepare_graph(G, doc_data)
G.update_all(
lambda edges: {'phi': EdgeData(edges.src, edges.dst).phi},
dgl.function.sum('phi', 'nphi')
lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
dgl.function.sum("phi", "nphi"),
)
self._last_mean_change = self.word_data.update_from(
G, self.mult['word'], self.rho)
G, self.mult["word"], self.rho
)
if self.verbose:
print(f"m-step mean_change={self._last_mean_change:.4f}, ", end="")
Bayesian_gap = np.mean([
part.Bayesian_gap.mean().tolist() for part in self.word_data
])
Bayesian_gap = np.mean(
[part.Bayesian_gap.mean().tolist() for part in self.word_data]
)
print(f"Bayesian_gap={Bayesian_gap:.4f}")
def partial_fit(self, G):
doc_data = self._e_step(G)
self._m_step(G, doc_data)
return self
def fit(self, G, mean_change_tol=1e-3, max_epochs=10):
for i in range(max_epochs):
if self.verbose:
......@@ -387,7 +427,6 @@ class LatentDirichletAllocation:
break
return self
def perplexity(self, G, doc_data=None):
"""ppl = exp{-sum[log(p(w1,...,wn|d))] / n}
Follows Eq (15) in Hoffman et al., 2010.
......@@ -398,45 +437,50 @@ class LatentDirichletAllocation:
# compute E[log p(docs | theta, beta)]
G = G.clone()
self._prepare_graph(G, doc_data)
G.apply_edges(lambda edges: {'loglike': EdgeData(edges.src, edges.dst).loglike})
edge_elbo = (G.edata['loglike'].sum() / G.num_edges()).tolist()
G.apply_edges(
lambda edges: {"loglike": EdgeData(edges.src, edges.dst).loglike}
)
edge_elbo = (G.edata["loglike"].sum() / G.num_edges()).tolist()
if self.verbose:
print(f'neg_elbo phi: {-edge_elbo:.3f}', end=' ')
print(f"neg_elbo phi: {-edge_elbo:.3f}", end=" ")
# compute E[log p(theta | alpha) - log q(theta | gamma)]
doc_elbo = (doc_data.loglike.sum() / doc_data.n.sum()).tolist()
if self.verbose:
print(f'theta: {-doc_elbo:.3f}', end=' ')
print(f"theta: {-doc_elbo:.3f}", end=" ")
# compute E[log p(beta | eta) - log q(beta | lambda)]
# The denominator n for extrapolation perplexity is undefined.
# We use the train set, whereas sklearn uses the test set.
word_elbo = (
sum([part.loglike.sum().tolist() for part in self.word_data])
/ sum([part.n.sum().tolist() for part in self.word_data])
)
word_elbo = sum(
[part.loglike.sum().tolist() for part in self.word_data]
) / sum([part.n.sum().tolist() for part in self.word_data])
if self.verbose:
print(f'beta: {-word_elbo:.3f}')
print(f"beta: {-word_elbo:.3f}")
ppl = np.exp(-edge_elbo - doc_elbo - word_elbo)
if G.num_edges()>0 and np.isnan(ppl):
if G.num_edges() > 0 and np.isnan(ppl):
warnings.warn("numerical issue in perplexity")
return ppl
def doc_subgraph(G, doc_ids):
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
_, _, (block,) = sampler.sample(G.reverse(), {'doc': torch.as_tensor(doc_ids)})
_, _, (block,) = sampler.sample(
G.reverse(), {"doc": torch.as_tensor(doc_ids)}
)
B = dgl.DGLHeteroGraph(
block._graph, ['_', 'word', 'doc', '_'], block.etypes
block._graph, ["_", "word", "doc", "_"], block.etypes
).reverse()
B.nodes['word'].data['_ID'] = block.nodes['word'].data['_ID']
B.nodes["word"].data["_ID"] = block.nodes["word"].data["_ID"]
return B
if __name__ == '__main__':
print('Testing LatentDirichletAllocation ...')
G = dgl.heterograph({('doc', '', 'word'): [(0, 0), (1, 3)]}, {'doc': 2, 'word': 5})
if __name__ == "__main__":
print("Testing LatentDirichletAllocation ...")
G = dgl.heterograph(
{("doc", "", "word"): [(0, 0), (1, 3)]}, {"doc": 2, "word": 5}
)
model = LatentDirichletAllocation(n_words=5, n_components=10, verbose=False)
model.fit(G)
model.transform(G)
......@@ -454,4 +498,4 @@ if __name__ == '__main__':
f.seek(0)
print(torch.load(f))
print('Testing LatentDirichletAllocation passed!')
print("Testing LatentDirichletAllocation passed!")
......@@ -6,11 +6,12 @@ Author's implementation: https://github.com/joanbruna/GNN_community
"""
from __future__ import division
import time
import argparse
import time
from itertools import permutations
import gnn
import numpy as np
import torch as th
import torch.nn.functional as F
......@@ -18,37 +19,51 @@ import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.data import SBMMixtureDataset
import gnn
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
parser.add_argument('--gpu', type=int, help='GPU index', default=-1)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
parser.add_argument('--n-communities', type=int, help='Number of communities', default=2)
parser.add_argument('--n-epochs', type=int, help='Number of epochs', default=100)
parser.add_argument('--n-features', type=int, help='Number of features', default=16)
parser.add_argument('--n-graphs', type=int, help='Number of graphs', default=10)
parser.add_argument('--n-layers', type=int, help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int, help='Number of nodes', default=10000)
parser.add_argument('--optim', type=str, help='Optimizer', default='Adam')
parser.add_argument('--radius', type=int, help='Radius', default=3)
parser.add_argument('--verbose', action='store_true')
parser.add_argument("--batch-size", type=int, help="Batch size", default=1)
parser.add_argument("--gpu", type=int, help="GPU index", default=-1)
parser.add_argument("--lr", type=float, help="Learning rate", default=0.001)
parser.add_argument(
"--n-communities", type=int, help="Number of communities", default=2
)
parser.add_argument(
"--n-epochs", type=int, help="Number of epochs", default=100
)
parser.add_argument(
"--n-features", type=int, help="Number of features", default=16
)
parser.add_argument("--n-graphs", type=int, help="Number of graphs", default=10)
parser.add_argument("--n-layers", type=int, help="Number of layers", default=30)
parser.add_argument(
"--n-nodes", type=int, help="Number of nodes", default=10000
)
parser.add_argument("--optim", type=str, help="Optimizer", default="Adam")
parser.add_argument("--radius", type=int, help="Radius", default=3)
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
dev = th.device("cpu") if args.gpu < 0 else th.device("cuda:%d" % args.gpu)
K = args.n_communities
training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(training_dataset, args.batch_size,
collate_fn=training_dataset.collate_fn, drop_last=True)
training_loader = DataLoader(
training_dataset,
args.batch_size,
collate_fn=training_dataset.collate_fn,
drop_last=True,
)
ones = th.ones(args.n_nodes // K)
y_list = [th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))]
y_list = [
th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))
]
feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)
def compute_overlap(z_list):
ybar_list = [th.max(z, 1)[1] for z in z_list]
overlap_list = []
......@@ -58,15 +73,20 @@ def compute_overlap(z_list):
overlap_list.append(overlap)
return sum(overlap_list) / len(overlap_list)
def from_np(f, *args):
def wrap(*args):
new = [th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args]
new = [
th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args
]
return f(*new)
return wrap
@from_np
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
""" One step of training. """
"""One step of training."""
g = g.to(dev)
lg = lg.to(dev)
deg_g = deg_g.to(dev).unsqueeze(1)
......@@ -77,7 +97,10 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
t_forward = time.time() - t0
z_list = th.chunk(z, args.batch_size, 0)
loss = sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list) / args.batch_size
loss = (
sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list)
/ args.batch_size
)
overlap = compute_overlap(z_list)
optimizer.zero_grad()
......@@ -88,6 +111,7 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
return loss, overlap, t_forward, t_backward
@from_np
def inference(g, lg, deg_g, deg_lg, pm_pd):
g = g.to(dev)
......@@ -99,9 +123,11 @@ def inference(g, lg, deg_g, deg_lg, pm_pd):
z = model(g, lg, deg_g, deg_lg, pm_pd)
return z
def test():
p_list =[6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
q_list =[0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
N = 1
overlap_list = []
for p, q in zip(p_list, q_list):
......@@ -112,31 +138,38 @@ def test():
overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
return overlap_list
n_iterations = args.n_graphs // args.batch_size
for i in range(args.n_epochs):
total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd)
loss, overlap, t_forward, t_backward = step(
i, j, g, lg, deg_g, deg_lg, pm_pd
)
total_loss += loss
total_overlap += overlap
s_forward += t_forward
s_backward += t_backward
epoch = '0' * (len(str(args.n_epochs)) - len(str(i)))
iteration = '0' * (len(str(n_iterations)) - len(str(j)))
epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
iteration = "0" * (len(str(n_iterations)) - len(str(j)))
if args.verbose:
print('[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f'
% (epoch, i, iteration, j, loss, overlap))
print(
"[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f"
% (epoch, i, iteration, j, loss, overlap)
)
epoch = '0' * (len(str(args.n_epochs)) - len(str(i)))
epoch = "0" * (len(str(args.n_epochs)) - len(str(i)))
loss = total_loss / (j + 1)
overlap = total_overlap / (j + 1)
t_forward = s_forward / (j + 1)
t_backward = s_backward / (j + 1)
print('[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs'
% (epoch, i, loss, overlap, t_forward, t_backward))
print(
"[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs"
% (epoch, i, loss, overlap, t_forward, t_backward)
)
overlap_list = test()
overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list])
print('[epoch %s%d]overlap: %s' % (epoch, i, overlap_str))
overlap_str = " - ".join(["%.3f" % overlap for overlap in overlap_list])
print("[epoch %s%d]overlap: %s" % (epoch, i, overlap_str))
......@@ -2,55 +2,55 @@ import torch as th
import torch.nn.functional as F
GCN_CONFIG = {
'extra_args': [16, 1, F.relu, 0.5],
'lr': 1e-2,
'weight_decay': 5e-4,
"extra_args": [16, 1, F.relu, 0.5],
"lr": 1e-2,
"weight_decay": 5e-4,
}
GAT_CONFIG = {
'extra_args': [8, 1, [8] * 1 + [1], F.elu, 0.6, 0.6, 0.2, False],
'lr': 0.005,
'weight_decay': 5e-4,
"extra_args": [8, 1, [8] * 1 + [1], F.elu, 0.6, 0.6, 0.2, False],
"lr": 0.005,
"weight_decay": 5e-4,
}
GRAPHSAGE_CONFIG = {
'extra_args': [16, 1, F.relu, 0.5, 'gcn'],
'lr': 1e-2,
'weight_decay': 5e-4,
"extra_args": [16, 1, F.relu, 0.5, "gcn"],
"lr": 1e-2,
"weight_decay": 5e-4,
}
APPNP_CONFIG = {
'extra_args': [64, 1, F.relu, 0.5, 0.5, 0.1, 10],
'lr': 1e-2,
'weight_decay': 5e-4,
"extra_args": [64, 1, F.relu, 0.5, 0.5, 0.1, 10],
"lr": 1e-2,
"weight_decay": 5e-4,
}
TAGCN_CONFIG = {
'extra_args': [16, 1, F.relu, 0.5],
'lr': 1e-2,
'weight_decay': 5e-4,
"extra_args": [16, 1, F.relu, 0.5],
"lr": 1e-2,
"weight_decay": 5e-4,
}
AGNN_CONFIG = {
'extra_args': [32, 2, 1.0, True, 0.5],
'lr': 1e-2,
'weight_decay': 5e-4,
"extra_args": [32, 2, 1.0, True, 0.5],
"lr": 1e-2,
"weight_decay": 5e-4,
}
SGC_CONFIG = {
'extra_args': [None, 2, False],
'lr': 0.2,
'weight_decay': 5e-6,
"extra_args": [None, 2, False],
"lr": 0.2,
"weight_decay": 5e-6,
}
GIN_CONFIG = {
'extra_args': [16, 1, 0, True],
'lr': 1e-2,
'weight_decay': 5e-6,
"extra_args": [16, 1, 0, True],
"lr": 1e-2,
"weight_decay": 5e-6,
}
CHEBNET_CONFIG = {
'extra_args': [32, 1, 2, True],
'lr': 1e-2,
'weight_decay': 5e-4,
"extra_args": [32, 1, 2, True],
"lr": 1e-2,
"weight_decay": 5e-4,
}
......@@ -31,7 +31,7 @@ def laplacian(W, normalized=True):
def rescale_L(L, lmax=2):
"""Rescale Laplacian eigenvalues to [-1,1]"""
M, M = L.shape
I = scipy.sparse.identity(M, format='csr', dtype=L.dtype)
I = scipy.sparse.identity(M, format="csr", dtype=L.dtype)
L /= lmax * 2
L -= I
return L
......@@ -39,7 +39,9 @@ def rescale_L(L, lmax=2):
def lmax_L(L):
"""Compute largest Laplacian eigenvalue"""
return scipy.sparse.linalg.eigsh(L, k=1, which='LM', return_eigenvectors=False)[0]
return scipy.sparse.linalg.eigsh(
L, k=1, which="LM", return_eigenvectors=False
)[0]
# graph coarsening with Heavy Edge Matching
......@@ -57,7 +59,11 @@ def coarsen(A, levels):
A = A.tocsr()
A.eliminate_zeros()
Mnew, Mnew = A.shape
print('Layer {0}: M_{0} = |V| = {1} nodes ({2} added), |E| = {3} edges'.format(i, Mnew, Mnew - M, A.nnz // 2))
print(
"Layer {0}: M_{0} = |V| = {1} nodes ({2} added), |E| = {3} edges".format(
i, Mnew, Mnew - M, A.nnz // 2
)
)
L = laplacian(A, normalized=True)
laplacians.append(L)
......@@ -95,7 +101,7 @@ def HEM(W, levels, rid=None):
graphs = []
graphs.append(W)
print('Heavy Edge Matching coarsening with Xavier version')
print("Heavy Edge Matching coarsening with Xavier version")
for _ in range(levels):
......@@ -183,7 +189,9 @@ def HEM_one_level(rr, cc, vv, rid, weights):
# First approach
if 2 == 1:
tval = vv[rs + jj] * (1.0 / weights[tid] + 1.0 / weights[nid])
tval = vv[rs + jj] * (
1.0 / weights[tid] + 1.0 / weights[nid]
)
# Second approach
if 1 == 1:
......@@ -192,7 +200,7 @@ def HEM_one_level(rr, cc, vv, rid, weights):
Wjj = vv[rowstart[nid]]
di = weights[tid]
dj = weights[nid]
tval = (2. * Wij + Wii + Wjj) * 1. / (di + dj + 1e-9)
tval = (2.0 * Wij + Wii + Wjj) * 1.0 / (di + dj + 1e-9)
if tval > wmax:
wmax = tval
......@@ -247,7 +255,7 @@ def compute_perm(parents):
# Sanity checks.
for i, indices_layer in enumerate(indices):
M = M_last * 2 ** i
M = M_last * 2**i
# Reduction by 2 at each layer (binary tree).
assert len(indices[0] == M)
# The new ordering does not omit an indice.
......@@ -256,8 +264,9 @@ def compute_perm(parents):
return indices[::-1]
assert (compute_perm([np.array([4, 1, 1, 2, 2, 3, 0, 0, 3]), np.array([2, 1, 0, 1, 0])])
== [[3, 4, 0, 9, 1, 2, 5, 8, 6, 7, 10, 11], [2, 4, 1, 3, 0, 5], [0, 1, 2]])
assert compute_perm(
[np.array([4, 1, 1, 2, 2, 3, 0, 0, 3]), np.array([2, 1, 0, 1, 0])]
) == [[3, 4, 0, 9, 1, 2, 5, 8, 6, 7, 10, 11], [2, 4, 1, 3, 0, 5], [0, 1, 2]]
def perm_adjacency(A, indices):
......
......@@ -2,6 +2,8 @@ import torch as th
"""Compute x,y coordinate for nodes in the graph"""
eps = 1e-8
def get_coordinates(graphs, grid_side, coarsening_levels, perm):
rst = []
for l in range(coarsening_levels + 1):
......@@ -10,21 +12,25 @@ def get_coordinates(graphs, grid_side, coarsening_levels, perm):
cnt = eps
x_accum = 0
y_accum = 0
for j in range(i * 2 ** l, (i + 1) * 2 ** l):
if perm[j] < grid_side ** 2:
x_accum += (perm[j] // grid_side)
y_accum += (perm[j] % grid_side)
for j in range(i * 2**l, (i + 1) * 2**l):
if perm[j] < grid_side**2:
x_accum += perm[j] // grid_side
y_accum += perm[j] % grid_side
cnt += 1
xs.append(x_accum / cnt)
ys.append(y_accum / cnt)
rst.append(th.cat([th.tensor(xs).view(-1, 1), th.tensor(ys).view(-1, 1)], -1))
rst.append(
th.cat([th.tensor(xs).view(-1, 1), th.tensor(ys).view(-1, 1)], -1)
)
return rst
"""Cartesian coordinate to polar coordinate"""
def z2polar(edges):
z = edges.dst['xy'] - edges.src['xy']
z = edges.dst["xy"] - edges.src["xy"]
rho = th.norm(z, dim=-1, p=2)
x, y = z.unbind(dim=-1)
phi = th.atan2(y, x)
return {'u': th.cat([rho.unsqueeze(-1), phi.unsqueeze(-1)], -1)}
return {"u": th.cat([rho.unsqueeze(-1), phi.unsqueeze(-1)], -1)}