Unverified Commit 9fe5092c authored by Hongzhi (Steve), Chen, committed by GitHub

lintrunner (#5326)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 9836f78e
import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device="cpu"):
        super(DGLRoutingLayer, self).__init__()
        self.batch_size = batch_size
        self.g = init_graph(in_nodes, out_nodes, f_size, device=device)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))
        self.device = device

    def forward(self, u_hat, routing_num=1):
        self.g.edata["u_hat"] = u_hat
        batch_size = self.batch_size

        # step 2 (line 5): weight the prediction vectors u_hat by the
        # coupling coefficients c and aggregate them on the output capsules
        def cap_message(edges):
            if batch_size:
                return {"m": edges.data["c"].unsqueeze(1) * edges.data["u_hat"]}
            else:
                return {"m": edges.data["c"] * edges.data["u_hat"]}

        def cap_reduce(nodes):
            return {"s": th.sum(nodes.mailbox["m"], dim=1)}

        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            edges_b = self.g.edata["b"].view(self.in_nodes, self.out_nodes)
            self.g.edata["c"] = F.softmax(edges_b, dim=1).view(-1, 1)
            # execute steps 1 & 2
            self.g.update_all(message_func=cap_message, reduce_func=cap_reduce)
            # step 3 (line 6): squash the aggregated vectors
            if self.batch_size:
                self.g.nodes[self.out_indx].data["v"] = squash(
                    self.g.nodes[self.out_indx].data["s"], dim=2
                )
            else:
                self.g.nodes[self.out_indx].data["v"] = squash(
                    self.g.nodes[self.out_indx].data["s"], dim=1
                )
            # step 4 (line 7): update the routing logits b by the agreement
            # between the predictions and the output capsule vectors
            v = th.cat(
                [self.g.nodes[self.out_indx].data["v"]] * self.in_nodes, dim=0
            )
            if self.batch_size:
                self.g.edata["b"] = self.g.edata["b"] + (
                    self.g.edata["u_hat"] * v
                ).mean(dim=1).sum(dim=1, keepdim=True)
            else:
                self.g.edata["b"] = self.g.edata["b"] + (
                    self.g.edata["u_hat"] * v
                ).sum(dim=1, keepdim=True)

def squash(s, dim=1):
    # squashing non-linearity: preserves the vector's direction and maps
    # its norm into [0, 1)
    sq = th.sum(s**2, dim=dim, keepdim=True)
    s_norm = th.sqrt(sq)
    s = (sq / (1.0 + sq)) * (s / s_norm)
    return s


def init_graph(in_nodes, out_nodes, f_size, device="cpu"):
    # build a complete bipartite graph from input capsules to output capsules
    g = dgl.DGLGraph()
    g.set_n_initializer(dgl.frame.zero_initializer)
    all_nodes = in_nodes + out_nodes
    g.add_nodes(all_nodes)
    in_indx = list(range(in_nodes))
    out_indx = list(range(in_nodes, in_nodes + out_nodes))
    # add edges using edge broadcasting
    for u in in_indx:
        g.add_edges(u, out_indx)
    g = g.to(device)
    g.edata["b"] = th.zeros(in_nodes * out_nodes, 1).to(device)
    return g
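
# A minimal smoke test for the batched path (an illustrative sketch, not part
# of this commit; it assumes a DGL version that still supports the mutable
# dgl.DGLGraph used above). With batch_size > 0, u_hat is expected to be
# shaped (in_nodes * out_nodes, batch_size, f_size), matching the
# unsqueeze(1) in cap_message and the dim=2 squash.
if __name__ == "__main__":
    in_nodes, out_nodes, f_size, batch = 20, 10, 4, 8
    layer = DGLRoutingLayer(in_nodes, out_nodes, f_size, batch_size=batch)
    u_hat = th.randn(in_nodes * out_nodes, batch, f_size)
    layer(u_hat, routing_num=3)
    v = layer.g.nodes[layer.out_indx].data["v"]
    assert v.shape == (out_nodes, batch, f_size)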

import dgl
import torch as th
import torch.nn as nn
from DGLRoutingLayer import DGLRoutingLayer
from torch.nn import functional as F

g = dgl.DGLGraph()
g.graph_data = {}
in_nodes = 20
out_nodes = 10
g.graph_data["in_nodes"] = in_nodes
g.graph_data["out_nodes"] = out_nodes
all_nodes = in_nodes + out_nodes
g.add_nodes(all_nodes)
in_indx = list(range(in_nodes))
out_indx = list(range(in_nodes, in_nodes + out_nodes))
g.graph_data["in_indx"] = in_indx
g.graph_data["out_indx"] = out_indx
# add edges using edge broadcasting
for u in out_indx:
    g.add_edges(in_indx, u)
# init states
f_size = 4
g.ndata["v"] = th.zeros(all_nodes, f_size)
g.edata["u_hat"] = th.randn(in_nodes * out_nodes, f_size)
g.edata["b"] = th.randn(in_nodes * out_nodes, 1)

# DGLRoutingLayer builds its own routing graph internally, so it is
# constructed from the node counts and the feature size rather than from the
# graph above; the prediction vectors are passed to each forward call.
routing_layer = DGLRoutingLayer(in_nodes, out_nodes, f_size)
u_hat = g.edata["u_hat"]

entropy_list = []
for i in range(15):
    routing_layer(u_hat)
    # entropy of the coupling distribution; it shrinks as routing converges
    dist_matrix = routing_layer.g.edata["c"].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=0)
    entropy_list.append(entropy.data.numpy())
    std = dist_matrix.std(dim=0)
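
# Optional visualization (a sketch, assuming matplotlib and numpy are
# available): the mean entropy of the coupling distribution should fall
# across routing iterations as each input capsule commits to fewer outputs.
import matplotlib.pyplot as plt
import numpy as np

stacked = np.stack(entropy_list)  # (iterations, out_nodes)
plt.plot(stacked.mean(axis=1))
plt.xlabel("routing iteration")
plt.ylabel("mean coupling entropy")
plt.savefig("routing_entropy.png")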

import dgl
import torch
from dgl.data import CiteseerGraphDataset, CoraGraphDataset


def load_data(args):
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        g = g.int().to(args.gpu)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    g = dgl.add_self_loop(g)
    return g, features, labels, train_mask, test_mask, data.num_classes, cuda
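
# Example call (an illustrative sketch; `args` normally comes from the
# training script's argparse parser -- only the two fields read above are
# needed, and the dataset is downloaded on first use):
from types import SimpleNamespace

args = SimpleNamespace(dataset="cora", gpu=-1)
g, features, labels, train_mask, test_mask, n_classes, cuda = load_data(args)
print(g.num_nodes(), features.shape, n_classes, cuda)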

def svd_feature(features, d=200):
    """Reduce node features to at most d dimensions via truncated SVD,
    to avoid the curse of dimensionality."""
    if features.shape[1] <= d:
        return features
    # torch.svd returns V (not its transpose); only U and S are needed here
    U, S, V = torch.svd(features)
    res = torch.mm(U[:, 0:d], torch.diag(S[0:d]))
    return res
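
# Example (illustrative): project 1433-dimensional features (Cora's input
# size) onto the top-200 right singular directions. Note that torch.svd
# computes the reduced SVD, so with fewer than d rows the output keeps only
# min(num_rows, d) columns.
feats = torch.randn(500, 1433)
reduced = svd_feature(feats)
print(reduced.shape)  # torch.Size([500, 200])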

def process_classids(labels_temp):
    """Remap class ids to a contiguous range starting at 0.

    Input: a label tensor with unseen classes already removed
    Output: the same tensor with the remaining classes reordered
    """
    labeldict = {}
    num = 0
    for i in labels_temp:
        labeldict[int(i)] = 1
    labellist = sorted(labeldict)
    for label in labellist:
        labeldict[int(label)] = num
        num = num + 1
    for i in range(labels_temp.numel()):
        labels_temp[i] = labeldict[int(labels_temp[i])]
    return labels_temp
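
# Example (illustrative): labels drawn from {2, 5, 9} are remapped in place
# to the contiguous ids {0, 1, 2}.
labels = torch.tensor([5, 2, 9, 2])
print(process_classids(labels))  # tensor([1, 0, 2, 0])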