Commit 5cda368d authored by HQ, committed by Minjie Wang

[Model] SBM hotfix (#137)

* [Model] SBM hotfix

* [Model] remove backend in data
parent 02eb463a
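For context: the core of the hotfix is a small `from_np` decorator in the training script, so that `step` and `inference` can accept the numpy arrays the dataset's `collate_fn` now returns. Below is a minimal, runnable sketch of that pattern, not the example's actual training code: the `scale` function and its inputs are invented for illustration, and the unused `*args` on the diff's outer `from_np` signature is dropped here.

```python
import numpy as np
import torch as th

def from_np(f):
    """Convert any numpy.ndarray positional arguments to torch tensors
    (via th.from_numpy) before calling the wrapped function."""
    def wrap(*args):
        new = [th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args]
        return f(*new)
    return wrap

@from_np
def scale(x, factor):
    # x arrives as a torch tensor even if the caller passed a numpy array;
    # non-array arguments such as factor pass through unchanged.
    return x * factor

print(scale(np.arange(3, dtype=np.float32), 2.0))  # tensor([0., 2., 4.])
```

In the diff below, the decorated `step` and `inference` additionally move the converted tensors to the training device themselves, which is why the explicit `.to(dev)` calls disappear from `test()` and the epoch loop.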
@@ -6,6 +6,7 @@ import networkx as nx
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np
 
 class GNNModule(nn.Module):
     def __init__(self, in_feats, out_feats, radius):
...
@@ -11,6 +11,7 @@ import time
 import argparse
 from itertools import permutations
+import numpy as np
 import torch as th
 import torch.nn.functional as F
 import torch.optim as optim
@@ -57,8 +58,18 @@ def compute_overlap(z_list):
         overlap_list.append(overlap)
     return sum(overlap_list) / len(overlap_list)
 
+def from_np(f, *args):
+    def wrap(*args):
+        new = [th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args]
+        return f(*new)
+    return wrap
+
+@from_np
 def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
     """ One step of training. """
+    deg_g = deg_g.to(dev)
+    deg_lg = deg_lg.to(dev)
+    pm_pd = pm_pd.to(dev)
     t0 = time.time()
     z = model(g, lg, deg_g, deg_lg, pm_pd)
     t_forward = time.time() - t0
@@ -75,6 +86,15 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
     return loss, overlap, t_forward, t_backward
 
+@from_np
+def inference(g, lg, deg_g, deg_lg, pm_pd):
+    deg_g = deg_g.to(dev)
+    deg_lg = deg_lg.to(dev)
+    pm_pd = pm_pd.to(dev)
+    z = model(g, lg, deg_g, deg_lg, pm_pd)
+    return z
+
 def test():
     p_list =[6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
     q_list =[0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
@@ -84,10 +104,7 @@ def test():
         dataset = SBMMixture(N, args.n_nodes, K, pq=[[p, q]] * N)
         loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)
         g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))
-        deg_g = deg_g.to(dev)
-        deg_lg = deg_lg.to(dev)
-        pm_pd = pm_pd.to(dev)
-        z = model(g, lg, deg_g, deg_lg, pm_pd)
+        z = inference(g, lg, deg_g, deg_lg, pm_pd)
         overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
     return overlap_list
@@ -95,9 +112,6 @@ n_iterations = args.n_graphs // args.batch_size
 for i in range(args.n_epochs):
     total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
     for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
-        deg_g = deg_g.to(dev)
-        deg_lg = deg_lg.to(dev)
-        pm_pd = pm_pd.to(dev)
         loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd)
         total_loss += loss
...
@@ -8,7 +8,6 @@ import numpy.random as npr
 import scipy as sp
 import networkx as nx
 
-from .. import backend as F
 from ..batched_graph import batch
 from ..graph import DGLGraph
 from ..utils import Index
@@ -94,7 +93,7 @@ class SBMMixture:
             g.from_scipy_sparse_matrix(adj)
         self._lgs = [g.line_graph(backtracking=False) for g in self._gs]
         in_degrees = lambda g: g.in_degrees(
-            Index(F.arange(0, g.number_of_nodes()))).unsqueeze(1).float()
+            Index(np.arange(0, g.number_of_nodes()))).unsqueeze(1).float()
         self._g_degs = [in_degrees(g) for g in self._gs]
         self._lg_degs = [in_degrees(lg) for lg in self._lgs]
         self._pm_pds = list(zip(*[g.edges() for g in self._gs]))[0]
@@ -118,7 +117,7 @@ class SBMMixture:
         g, lg, deg_g, deg_lg, pm_pd = zip(*x)
         g_batch = batch(g)
         lg_batch = batch(lg)
-        degg_batch = F.pack(deg_g)
-        deglg_batch = F.pack(deg_lg)
-        pm_pd_batch = F.pack([x + i * self._n_nodes for i, x in enumerate(pm_pd)])
+        degg_batch = np.concatenate(deg_g, axis=0)
+        deglg_batch = np.concatenate(deg_lg, axis=0)
+        pm_pd_batch = np.concatenate([x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0)
         return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch
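For the second bullet ("remove backend in data"): `collate_fn` now batches with plain numpy instead of the internal `backend`/`F` shim. Here is a self-contained sketch of the offset-and-concatenate step used for `pm_pd`; the two toy index arrays and `n_nodes = 4` are invented for illustration.

```python
import numpy as np

# Graph i's indices are shifted by i * n_nodes so node ids stay distinct in the
# batched graph, then everything is concatenated along axis 0, mirroring the
# np.concatenate line in collate_fn above.
n_nodes = 4
pm_pd = [np.array([0, 1, 2]), np.array([0, 2, 3])]  # hypothetical per-graph indices

pm_pd_batch = np.concatenate([x + i * n_nodes for i, x in enumerate(pm_pd)], axis=0)
print(pm_pd_batch)  # [0 1 2 4 6 7]
```

The degree columns are packed the same way with `np.concatenate(..., axis=0)`, and `np.arange` replaces `F.arange` when building the node `Index`, so the data module no longer needs to import the framework backend.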