Commit 0a6d720c authored by GaiYu0's avatar GaiYu0
Browse files

fix build failure

parent 885f5e73
......@@ -40,11 +40,11 @@ class GNNModule(nn.Module):
def aggregate(self, g, z):
z_list = []
g.set_n_repr(z)
g.update_all(fn.copy_src(), fn.sum(), batchable=True)
g.update_all(fn.copy_src(), fn.sum())
z_list.append(g.get_n_repr())
for i in range(self.radius - 1):
for j in range(2 ** i):
g.update_all(fn.copy_src(), fn.sum(), batchable=True)
g.update_all(fn.copy_src(), fn.sum())
z_list.append(g.get_n_repr())
return z_list
......@@ -54,7 +54,7 @@ class GNNModule(nn.Module):
x_list = [theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))]
g.set_e_repr(y)
g.update_all(fn.copy_edge(), fn.sum(), batchable=True)
g.update_all(fn.copy_edge(), fn.sum())
yx = g.get_n_repr()
x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_list) + self.theta_y(yx)
......@@ -62,7 +62,7 @@ class GNNModule(nn.Module):
y_list = [gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))]
lg.set_n_repr(xy)
lg.update_all(fn.copy_src(), fn.sum(), batchable=True)
lg.update_all(fn.copy_src(), fn.sum())
xy = lg.get_n_repr()
y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_list) + self.gamma_x(xy)
y = self.bn_y(y[:, :self.out_feats] + F.relu(y[:, self.out_feats:]))
......
......@@ -16,7 +16,7 @@ import utils
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int,
help='Batch size', default=4)
help='Batch size', default=1)
parser.add_argument('--gpu', type=int,
help='GPU', default=-1)
parser.add_argument('--n-communities', type=int,
......
......@@ -9,7 +9,7 @@ import networkx as nx
from torch.utils.data import Dataset
from .. import backend as F
from ..batch import batch
from ..batched_graph import batch
from ..graph import DGLGraph
from ..utils import Index
......
......@@ -1263,16 +1263,6 @@ class DGLGraph(object):
"""
return DGLGraph(self._graph.line_graph(backtracking, sorted))
def _line_graph(self, backtracking=True, sorted=False):
"""Return the line graph of this graph.
Returns
-------
DGLGraph
The line graph of this graph.
"""
return DGLGraph(self._graph._line_graph(backtracking, sorted))
def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict:
return attr_dict[__REPR__]
......
......@@ -536,31 +536,6 @@ class GraphIndex(object):
GraphIndex
The line graph of this graph.
"""
m = self.number_of_edges()
ctx = F.get_context(F.ones(1))
inc = F.to_scipy_sparse(self.incidence_matrix(True, sorted).get(ctx))
adj = inc.transpose().dot(inc).tocoo()
if backtracking:
adj.data[adj.data >= 0] = 0
else:
adj.data[adj.data != -1] = 0
adj.eliminate_zeros()
u, v, _ = self.edges(sorted=sorted)
u = u.tousertensor()
v = v.tousertensor()
src = F.gather_row(v, F.tensor(adj.row, dtype=F.int64))
dst = F.gather_row(u, F.tensor(adj.col, dtype=F.int64))
dat = F.tensor(adj.data)
dat[src != dst] = 0
adj.data = dat.numpy()
adj.eliminate_zeros()
lg = create_graph_index()
lg.from_scipy_sparse_matrix(adj)
return lg
def _line_graph(self, backtracking=True, sorted=False):
handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking)
return GraphIndex(handle)
......
import argparse
import time
import dgl
import dgl.backend as F
import igraph
import networkx as nx
import numpy as np
import scipy.sparse as sp
parser = argparse.ArgumentParser()
parser.add_argument('--bt', action='store_true', help='BackTracking')
parser.add_argument('--n-nodes', type=int, help='Number of NODES')
args = parser.parse_args()
n = args.n_nodes
a = sp.random(n, n, 1 / n, data_rvs=lambda n: np.ones(n))
b = sp.triu(a, 1) + sp.triu(a, 1).transpose()
g = dgl.DGLGraph()
g.from_scipy_sparse_matrix(b)
N = 10
t0 = time.time()
for i in range(N):
lg_sparse = g.line_graph(args.bt)
g._graph._cache.clear()
t = (time.time() - t0) / N
print('dgl.DGLGraph.line_graph: %f' % t)
t0 = time.time()
for i in range(N):
lg = g._line_graph(args.bt)
g._graph._cache.clear()
t = (time.time() - t0) / N
print('dgl.DGLGraph._line_graph: %f' % t)
g = igraph.Graph()
g.add_vertices(n)
g.add_edges(list(zip(a.row.tolist(), a.col.tolist())))
'''
t0 = time.time()
for i in range(N):
lg = g.linegraph()
t = (time.time() - t0) / N
print('igraph.Graph._line_graph: %f' % t)
'''
t = 0
for i in range(N):
g = igraph.Graph()
g.add_vertices(n)
g.add_edges(list(zip(a.row.tolist(), a.col.tolist())))
t0 = time.time()
lg = g.linegraph()
t += time.time() - t0
t /= N
print('igraph.Graph.linegraph: %f' % t)
import dgl
import dgl.backend as F
import networkx as nx
import numpy as np
import scipy as sp
N = 1000
a = sp.sparse.random(N, N, 1 / N, data_rvs=lambda n: np.ones(n))
b = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()
g_nx = nx.from_scipy_sparse_matrix(b, create_using=nx.DiGraph())
g_dgl = dgl.DGLGraph()
g_dgl.from_scipy_sparse_matrix(b)
h_nx = g_dgl.to_networkx()
g_nodes = set(g_nx.nodes)
h_nodes = set(h_nx.nodes)
assert h_nodes.issubset(g_nodes)
assert all(g_nx.in_degree(x) == g_nx.out_degree(x) == 0
for x in g_nodes.difference(h_nodes))
assert g_nx.edges == h_nx.edges
nx_adj = nx.adjacency_matrix(g_nx)
nx_inc = nx.incidence_matrix(g_nx, edgelist=sorted(g_nx.edges()))
nx_oriented = nx.incidence_matrix(g_nx, edgelist=sorted(g_nx.edges()), oriented=True)
ctx = F.get_context(F.ones((1,)))
dgl_adj = F.to_scipy_sparse(g_dgl.adjacency_matrix().get(ctx)).transpose()
dgl_inc = F.to_scipy_sparse(g_dgl.incidence_matrix().get(ctx))
dgl_oriented = F.to_scipy_sparse(g_dgl.incidence_matrix(oriented=True).get(ctx))
assert abs(nx_adj - dgl_adj).max() == 0
assert abs(nx_inc - dgl_inc).max() == 0
assert abs(nx_oriented - dgl_oriented).max() == 0
import dgl
import dgl.backend as F
import networkx as nx
import numpy as np
import scipy as sp
N = 10000
a = sp.sparse.random(N, N, 1 / N, data_rvs=lambda n: np.ones(n))
b = sp.sparse.triu(a, 1) + sp.sparse.triu(a, 1).transpose()
g = dgl.DGLGraph()
g.from_scipy_sparse_matrix(b)
backtracking = True
lg_sparse = g.line_graph(backtracking)
lg_cpp = g._line_graph(backtracking)
assert lg_sparse.number_of_nodes() == lg_cpp.number_of_nodes()
assert lg_sparse.number_of_edges() == lg_cpp.number_of_edges()
src_sparse, dst_sparse, _ = lg_sparse.edges(sorted=True)
src_cpp, dst_cpp, _ = lg_cpp.edges(sorted=True)
assert (src_sparse == src_cpp).all()
assert (dst_sparse == dst_cpp).all()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment