Unverified Commit 366cc7eb authored by Jinjing Zhou, committed by GitHub

[Pickle] Fix HeteroGraphConv pickle problem (#2761)

* fix pickle problem

* lint

* add pickle tests

* fix

* fix

* fix

* fix

* fix for windows
parent 337b1559
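
Context for the fix: Python's pickle serializes functions by reference to an importable qualified name, so the lambdas and nested functions that get_aggregate_fn previously returned made any HeteroGraphConv holding them unpicklable (th.save included, since it pickles under the hood). The diff below moves the aggregators to module level and binds the reduce function with functools.partial, both of which pickle cleanly. A minimal standalone sketch of the failure and the fix (illustration only, not DGL code):

import pickle
from functools import partial
import torch as th

# A lambda has no importable qualified name, so pickle rejects it.
agg = lambda inputs, dim: th.max(inputs, dim=dim)[0]
try:
    pickle.dumps(agg)
except (pickle.PicklingError, AttributeError) as err:
    print('lambda is not picklable:', err)

# A module-level function pickles by reference, and a functools.partial
# wrapping it pickles as well.
def _max_reduce_func(inputs, dim):
    return th.max(inputs, dim=dim)[0]

fn = pickle.loads(pickle.dumps(partial(_max_reduce_func, dim=0)))
print(fn(th.arange(6.).view(2, 3)))  # tensor([3., 4., 5.])
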
"""Heterograph NN modules""" """Heterograph NN modules"""
from functools import partial
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
from ...base import DGLError
__all__ = ['HeteroGraphConv'] __all__ = ['HeteroGraphConv']
...@@ -196,6 +198,29 @@ class HeteroGraphConv(nn.Module): ...@@ -196,6 +198,29 @@ class HeteroGraphConv(nn.Module):
rsts[nty] = self.agg_fn(alist, nty) rsts[nty] = self.agg_fn(alist, nty)
return rsts return rsts
def _max_reduce_func(inputs, dim):
return th.max(inputs, dim=dim)[0]
def _min_reduce_func(inputs, dim):
return th.min(inputs, dim=dim)[0]
def _sum_reduce_func(inputs, dim):
return th.sum(inputs, dim=dim)
def _mean_reduce_func(inputs, dim):
return th.mean(inputs, dim=dim)
def _stack_agg_func(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
return th.stack(inputs, dim=1)
def _agg_func(inputs, dsttype, fn): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
stacked = th.stack(inputs, dim=0)
return fn(stacked, dim=0)
def get_aggregate_fn(agg): def get_aggregate_fn(agg):
"""Internal function to get the aggregation function for node data """Internal function to get the aggregation function for node data
generated from different relations. generated from different relations.
...@@ -213,28 +238,19 @@ def get_aggregate_fn(agg): ...@@ -213,28 +238,19 @@ def get_aggregate_fn(agg):
and returns one aggregated tensor. and returns one aggregated tensor.
""" """
if agg == 'sum': if agg == 'sum':
fn = th.sum fn = _sum_reduce_func
elif agg == 'max': elif agg == 'max':
fn = lambda inputs, dim: th.max(inputs, dim=dim)[0] fn = _max_reduce_func
elif agg == 'min': elif agg == 'min':
fn = lambda inputs, dim: th.min(inputs, dim=dim)[0] fn = _min_reduce_func
elif agg == 'mean': elif agg == 'mean':
fn = th.mean fn = _mean_reduce_func
elif agg == 'stack': elif agg == 'stack':
fn = None # will not be called fn = None # will not be called
else: else:
raise DGLError('Invalid cross type aggregator. Must be one of ' raise DGLError('Invalid cross type aggregator. Must be one of '
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg) '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
if agg == 'stack': if agg == 'stack':
def stack_agg(inputs, dsttype): # pylint: disable=unused-argument return _stack_agg_func
if len(inputs) == 0:
return None
return th.stack(inputs, dim=1)
return stack_agg
else: else:
def aggfn(inputs, dsttype): # pylint: disable=unused-argument return partial(_agg_func, fn=fn)
if len(inputs) == 0:
return None
stacked = th.stack(inputs, dim=0)
return fn(stacked, dim=0)
return aggfn
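
With the aggregators at module level, a HeteroGraphConv survives a pickle round trip, which is what the test-suite changes below exercise via th.save. A small sketch under the same assumptions (module choices arbitrary; no graph is needed just to serialize):

import io
import pickle
import torch as th
import dgl.nn.pytorch as dglnn

# 'max' used to resolve to a lambda; it now resolves to a picklable
# partial over the module-level _max_reduce_func.
conv = dglnn.HeteroGraphConv(
    {'follows': dglnn.GraphConv(2, 3),
     'sells': dglnn.GraphConv(2, 3)},
    aggregate='max')

conv2 = pickle.loads(pickle.dumps(conv))  # plain pickle round trip
buf = io.BytesIO()
th.save(conv, buf)                        # th.save pickles the module
buf.seek(0)
conv3 = th.load(buf)                      # recent PyTorch may need weights_only=False
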
+import io
 import torch as th
 import networkx as nx
 import dgl

@@ -8,9 +9,12 @@ import pytest
 from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
 from test_utils import parametrize_dtype
 from copy import deepcopy
+import pickle
 import scipy as sp

+tmp_buffer = io.BytesIO()

 def _AXWb(A, X, W, b):
     X = th.matmul(X, W)
     Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)

@@ -25,6 +29,11 @@ def test_graph_conv0(out_dim):
     conv = nn.GraphConv(5, out_dim, norm='none', bias=True)
     conv = conv.to(ctx)
     print(conv)
+    # test pickle
+    th.save(conv, tmp_buffer)
     # test#1: basic
     h0 = F.ones((3, 5))
     h1 = conv(g, h0)

@@ -119,6 +128,10 @@ def test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):
 def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):
     g = g.astype(idtype).to(F.ctx())
     conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
+    # test pickle
+    th.save(conv, tmp_buffer)
     ext_w = F.randn((5, out_dim)).to(F.ctx())
     nsrc = g.number_of_src_nodes()
     ndst = g.number_of_dst_nodes()

@@ -141,6 +154,10 @@ def test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):
     # Test a pair of tensor inputs
     g = g.astype(idtype).to(F.ctx())
     conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(F.ctx())
+    # test pickle
+    th.save(conv, tmp_buffer)
     ext_w = F.randn((5, out_dim)).to(F.ctx())
     nsrc = g.number_of_src_nodes()
     ndst = g.number_of_dst_nodes()

@@ -175,6 +192,9 @@ def test_tagconv(out_dim):
     conv = nn.TAGConv(5, out_dim, bias=True)
     conv = conv.to(ctx)
     print(conv)
+    # test pickle
+    th.save(conv, tmp_buffer)
     # test#1: basic
     h0 = F.ones((3, 5))

@@ -231,6 +251,9 @@ def test_glob_att_pool():
     gap = gap.to(ctx)
     print(gap)
+    # test pickle
+    th.save(gap, tmp_buffer)
     # test#1: basic
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = gap(g, h0)

@@ -347,6 +370,10 @@ def test_rgcn(O):
     I = 10
     rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
+    # test pickle
+    th.save(rgc_basis, tmp_buffer)
     rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
     rgc_basis_low.weight = rgc_basis.weight
     rgc_basis_low.w_comp = rgc_basis.w_comp

@@ -509,6 +536,10 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
     feat = F.randn((g.number_of_nodes(), 5))
     gat = gat.to(ctx)
     h = gat(g, feat)
+    # test pickle
+    th.save(gat, tmp_buffer)
     assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
     _, a = gat(g, feat, get_attention=True)
     assert a.shape == (g.number_of_edges(), num_heads, 1)

@@ -536,6 +567,8 @@ def test_sage_conv(idtype, g, aggre_type):
     sage = nn.SAGEConv(5, 10, aggre_type)
     feat = F.randn((g.number_of_nodes(), 5))
     sage = sage.to(F.ctx())
+    # test pickle
+    th.save(sage, tmp_buffer)
     h = sage(g, feat)
     assert h.shape[-1] == 10

@@ -583,6 +616,10 @@ def test_sgc_conv(g, idtype, out_dim):
     g = g.astype(idtype).to(ctx)
     # not cached
     sgc = nn.SGConv(5, out_dim, 3)
+    # test pickle
+    th.save(sgc, tmp_buffer)
     feat = F.randn((g.number_of_nodes(), 5))
     sgc = sgc.to(ctx)

@@ -605,6 +642,9 @@ def test_appnp_conv(g, idtype):
     appnp = nn.APPNPConv(10, 0.1)
     feat = F.randn((g.number_of_nodes(), 5))
     appnp = appnp.to(ctx)
+    # test pickle
+    th.save(appnp, tmp_buffer)
     h = appnp(g, feat)
     assert h.shape[-1] == 5

@@ -622,6 +662,10 @@ def test_gin_conv(g, idtype, aggregator_type):
     feat = F.randn((g.number_of_nodes(), 5))
     gin = gin.to(ctx)
     h = gin(g, feat)
+    # test pickle
+    th.save(h, tmp_buffer)
     assert h.shape == (g.number_of_nodes(), 12)

 @parametrize_dtype
@@ -784,6 +828,10 @@ def test_edge_conv(g, idtype, out_dim):
     ctx = F.ctx()
     edge_conv = nn.EdgeConv(5, out_dim).to(ctx)
     print(edge_conv)
+    # test pickle
+    th.save(edge_conv, tmp_buffer)
     h0 = F.randn((g.number_of_nodes(), 5))
     h1 = edge_conv(g, h0)
     assert h1.shape == (g.number_of_nodes(), out_dim)

@@ -811,6 +859,10 @@ def test_dotgat_conv(g, idtype, out_dim, num_heads):
     dotgat = nn.DotGatConv(5, out_dim, num_heads)
     feat = F.randn((g.number_of_nodes(), 5))
     dotgat = dotgat.to(ctx)
+    # test pickle
+    th.save(dotgat, tmp_buffer)
     h = dotgat(g, feat)
     assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
     _, a = dotgat(g, feat, get_attention=True)

@@ -919,6 +971,7 @@ def test_atomic_conv(g, idtype):
     dist = F.randn((g.number_of_edges(), 1))
     h = aconv(g, feat, dist)
     # current we only do shape check
     assert h.shape[-1] == 4

@@ -968,6 +1021,10 @@ def test_hetero_conv(agg, idtype):
         'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
        agg)
     conv = conv.to(F.ctx())
+    # test pickle
+    th.save(conv, tmp_buffer)
     uf = F.randn((4, 2))
     gf = F.randn((4, 4))
     sf = F.randn((2, 3))
...
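
The tests above only assert that th.save into the shared tmp_buffer does not raise. A slightly stronger check (a sketch, not part of this commit) would rewind a fresh buffer, load the module back, and compare parameters:

buf = io.BytesIO()          # fresh buffer, unlike the shared tmp_buffer
th.save(conv, buf)
buf.seek(0)
conv_loaded = th.load(buf)  # weights_only=False on recent PyTorch
for p, q in zip(conv.parameters(), conv_loaded.parameters()):
    assert th.equal(p, q)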