"...text-generation-inference.git" did not exist on "8094ecfc9ef22c838fa7d49db4af8301539619e3"
Commit 2bff8339 authored by xiang song(charlie.song), committed by Zihao Ye

[Test] Provide a framework-agnostic API to test nn modules on both the CPU and CUDA side. (#775)

* upd

* fix edgebatch edges

* add test

* trigger

* Update README.md for pytorch PinSage example.

Add a note that the PinSage model example under
example/pytorch/recommendation only works with Python 3.6+,
as its dataset loader depends on the stanfordnlp package,
which works only with Python 3.6+.

* Provide a framework-agnostic API to test nn modules on both the CPU and CUDA side.

1. make dgl.nn.xxx framework agnostic
2. make test.backend include the dgl.nn modules
3. modify test_edge_softmax in test/mxnet/test_nn.py and
   test/pytorch/test_nn.py so it works on both CPU and GPU

* Fix style

* Delete unused code

* Make the agnostic tests depend only on tests/backend

1. clear all agnostic-related code out of dgl.nn
2. make test_graph_conv agnostic to CPU/GPU

* Fix code style

* fix

* doc

* Make all test code under tests.mxnet/pytorch.test_nn.py
work on both CPU and GPU.

* Fix syntax

* Remove rand
parent be936da8
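The core idea of the change: the test suites import a small shared shim from tests/backend as F, and every tensor helper in it (ones, randn, sum, allclose, ...) creates or compares data on a single default context, so the same test body runs unchanged on CPU or GPU. A minimal sketch of that pattern follows; it assumes a PyTorch backend, and everything except ctx() and gpu_ctx() (which do appear in the diff below) is illustrative rather than DGL's actual tests/backend code.

# Hedged sketch of a framework-agnostic test shim (PyTorch flavour, names illustrative).
import torch as th

_default_context = th.device('cuda:0') if th.cuda.is_available() else th.device('cpu')

def ctx():
    # The single context every test helper uses.
    return _default_context

def gpu_ctx():
    # True when the default test context is a GPU.
    return _default_context.type == 'cuda'

def ones(shape):
    # Allocate on the default context so tests never hard-code a device.
    return th.ones(shape, device=_default_context)

def allclose(a, b, rtol=1e-4, atol=1e-4):
    return th.allclose(a.float(), b.float(), rtol=rtol, atol=atol)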
"""MXNet modules for graph global pooling.""" """MXNet modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, C0103, W0235 # pylint: disable= no-member, arguments-differ, C0103, W0235
import mxnet as mx
from mxnet import gluon, nd from mxnet import gluon, nd
from mxnet.gluon import nn from mxnet.gluon import nn
...@@ -191,13 +190,6 @@ class GlobalAttentionPooling(nn.Block): ...@@ -191,13 +190,6 @@ class GlobalAttentionPooling(nn.Block):
self.gate_nn = gate_nn self.gate_nn = gate_nn
self.feat_nn = feat_nn self.feat_nn = feat_nn
self._reset_parameters()
def _reset_parameters(self):
self.gate_nn.initialize(mx.init.Xavier())
if self.feat_nn:
self.feat_nn.initialize(mx.init.Xavier())
def forward(self, feat, graph): def forward(self, feat, graph):
r"""Compute global attention pooling. r"""Compute global attention pooling.
...@@ -265,10 +257,6 @@ class Set2Set(nn.Block): ...@@ -265,10 +257,6 @@ class Set2Set(nn.Block):
with self.name_scope(): with self.name_scope():
self.lstm = gluon.rnn.LSTM( self.lstm = gluon.rnn.LSTM(
self.input_dim, num_layers=n_layers, input_size=self.output_dim) self.input_dim, num_layers=n_layers, input_size=self.output_dim)
self._reset_parameters()
def _reset_parameters(self):
self.lstm.initialize(mx.init.Xavier())
def forward(self, feat, graph): def forward(self, feat, graph):
r"""Compute set2set pooling. r"""Compute set2set pooling.
......
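Dropping _reset_parameters from the Gluon blocks above relies on Gluon's deferred initialization: parameters are created when the caller runs initialize(ctx=...), which also decides the device, and that is exactly what the updated MXNet tests now do. A short usage sketch (module choice and sizes are just an example):

# Example only: initialize a Gluon-based DGL module on the context the test picks.
import mxnet as mx
import dgl.nn.mxnet as nn

ctx = mx.cpu(0)                  # or mx.gpu(0) when a GPU is available
s2s = nn.Set2Set(5, 3, 3)        # hidden size 5, 3 iterations, 3 layers
s2s.initialize(ctx=ctx)          # parameters are created on ctx; no in-module init needed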
@@ -178,9 +178,10 @@ class GlobalAttentionPooling(nn.Module):
         super(GlobalAttentionPooling, self).__init__()
         self.gate_nn = gate_nn
         self.feat_nn = feat_nn
-        self._reset_parameters()
+        self.reset_parameters()

-    def _reset_parameters(self):
+    def reset_parameters(self):
+        """Reinitialize learnable parameters."""
         for p in self.gate_nn.parameters():
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
@@ -256,12 +257,11 @@ class Set2Set(nn.Module):
         self.n_iters = n_iters
         self.n_layers = n_layers
         self.lstm = th.nn.LSTM(self.output_dim, self.input_dim, n_layers)
-        self._reset_parameters()
+        self.reset_parameters()

-    def _reset_parameters(self):
-        for p in self.lstm.parameters():
-            if p.dim() > 1:
-                nn.init.xavier_uniform_(p)
+    def reset_parameters(self):
+        """Reinitialize learnable parameters."""
+        self.lstm.reset_parameters()

     def forward(self, feat, graph):
         r"""Compute set2set pooling.
@@ -342,9 +342,10 @@ class MultiHeadAttention(nn.Module):
         self.dropa = nn.Dropout(dropouta)
         self.norm_in = nn.LayerNorm(d_model)
         self.norm_inter = nn.LayerNorm(d_model)
-        self._reset_parameters()
+        self.reset_parameters()

-    def _reset_parameters(self):
+    def reset_parameters(self):
+        """Reinitialize learnable parameters."""
         for p in self.parameters():
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
@@ -441,9 +442,10 @@ class InducedSetAttentionBlock(nn.Module):
         self.mha = nn.ModuleList([
             MultiHeadAttention(d_model, num_heads, d_head, d_ff,
                                dropouth=dropouth, dropouta=dropouta) for _ in range(2)])
-        self._reset_parameters()
+        self.reset_parameters()

-    def _reset_parameters(self):
+    def reset_parameters(self):
+        """Reinitialize learnable parameters."""
         nn.init.xavier_uniform_(self.inducing_points)

     def forward(self, feat, lengths):
@@ -492,9 +494,10 @@ class PMALayer(nn.Module):
             nn.Dropout(dropouth),
             nn.Linear(d_ff, d_model)
         )
-        self._reset_parameters()
+        self.reset_parameters()

-    def _reset_parameters(self):
+    def reset_parameters(self):
+        """Reinitialize learnable parameters."""
         nn.init.xavier_uniform_(self.seed_vectors)

     def forward(self, feat, lengths):
...
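On the PyTorch side the initializer is only renamed from _reset_parameters to the public reset_parameters (matching torch.nn conventions) and given a docstring, and Set2Set now delegates to the LSTM's own reset_parameters() instead of re-implementing Xavier initialization. The pattern in isolation, as a hedged sketch rather than the DGL source:

# Minimal illustration of the reset_parameters convention used above.
import torch.nn as nn

class TinyPool(nn.Module):
    def __init__(self, in_feats):
        super(TinyPool, self).__init__()
        self.gate = nn.Linear(in_feats, 1)
        self.lstm = nn.LSTM(in_feats, in_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        for p in self.gate.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        self.lstm.reset_parameters()   # built-in modules can reset themselves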
 from dgl.backend import *
+from dgl.nn import *
 from . import backend_unittest
 import os
 import importlib
@@ -34,6 +35,12 @@ _context_dict = {
 }

 _default_context = _context_dict[_default_context_str]

+def ctx():
+    return _default_context
+
+def gpu_ctx():
+    return (_default_context_str == 'gpu')
+
 def zeros(shape, dtype=float32, ctx=_default_context):
     return _zeros(shape, dtype, ctx)
...
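With ctx() and gpu_ctx() exported from tests/backend, a test only needs those two calls and never touches mx.cpu or th.cuda directly. A sketch of the intended call pattern as it would look inside tests/pytorch/test_nn.py (where `import backend` resolves to tests/backend):

# How a test picks its device through the shared backend shim.
import backend as F              # tests/backend
import dgl.nn.pytorch as nn

conv = nn.GraphConv(5, 2)
if F.gpu_ctx():                  # default test context is a GPU
    conv.cuda()                  # move the module's parameters to the GPU
h0 = F.ones((3, 5))              # helpers allocate on F.ctx(), matching the module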
@@ -3,6 +3,7 @@ import networkx as nx
 import numpy as np
 import dgl
 import dgl.nn.mxnet as nn
+import backend as F
 from mxnet import autograd, gluon

 def check_close(a, b):
@@ -15,19 +16,19 @@ def _AXWb(A, X, W, b):
 def test_graph_conv():
     g = dgl.DGLGraph(nx.path_graph(3))
-    adj = g.adjacency_matrix()
-    ctx = mx.cpu(0)
+    ctx = F.ctx()
+    adj = g.adjacency_matrix(ctx=ctx)

     conv = nn.GraphConv(5, 2, norm=False, bias=True)
     conv.initialize(ctx=ctx)
     # test#1: basic
-    h0 = mx.nd.ones((3, 5))
+    h0 = F.ones((3, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
     # test#2: more-dim
-    h0 = mx.nd.ones((3, 5, 5))
+    h0 = F.ones((3, 5, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
@@ -37,12 +38,12 @@ def test_graph_conv():
     conv.initialize(ctx=ctx)
     # test#3: basic
-    h0 = mx.nd.ones((3, 5))
+    h0 = F.ones((3, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     # test#4: basic
-    h0 = mx.nd.ones((3, 5, 5))
+    h0 = F.ones((3, 5, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
@@ -52,54 +53,58 @@ def test_graph_conv():
     with autograd.train_mode():
         # test#3: basic
-        h0 = mx.nd.ones((3, 5))
+        h0 = F.ones((3, 5))
         h1 = conv(h0, g)
         assert len(g.ndata) == 0
         assert len(g.edata) == 0
         # test#4: basic
-        h0 = mx.nd.ones((3, 5, 5))
+        h0 = F.ones((3, 5, 5))
         h1 = conv(h0, g)
         assert len(g.ndata) == 0
         assert len(g.edata) == 0

     # test not override features
-    g.ndata["h"] = 2 * mx.nd.ones((3, 1))
+    g.ndata["h"] = 2 * F.ones((3, 1))
     h1 = conv(h0, g)
     assert len(g.ndata) == 1
     assert len(g.edata) == 0
     assert "h" in g.ndata
-    check_close(g.ndata['h'], 2 * mx.nd.ones((3, 1)))
+    check_close(g.ndata['h'], 2 * F.ones((3, 1)))

 def test_set2set():
     g = dgl.DGLGraph(nx.path_graph(10))
+    ctx = F.ctx()

     s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
+    s2s.initialize(ctx=ctx)
     print(s2s)

     # test#1: basic
-    h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
+    h0 = F.randn((g.number_of_nodes(), 5))
     h1 = s2s(h0, g)
     assert h1.shape[0] == 10 and h1.ndim == 1

     # test#2: batched graph
     bg = dgl.batch([g, g, g])
-    h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
+    h0 = F.randn((bg.number_of_nodes(), 5))
     h1 = s2s(h0, bg)
     assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2

 def test_glob_att_pool():
     g = dgl.DGLGraph(nx.path_graph(10))
+    ctx = F.ctx()

     gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))
+    gap.initialize(ctx=ctx)
     print(gap)
     # test#1: basic
-    h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
+    h0 = F.randn((g.number_of_nodes(), 5))
     h1 = gap(h0, g)
     assert h1.shape[0] == 10 and h1.ndim == 1

     # test#2: batched graph
     bg = dgl.batch([g, g, g, g])
-    h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
+    h0 = F.randn((bg.number_of_nodes(), 5))
     h1 = gap(h0, bg)
     assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
@@ -113,48 +118,47 @@ def test_simple_pool():
     print(sum_pool, avg_pool, max_pool, sort_pool)

     # test#1: basic
-    h0 = mx.nd.random.randn(g.number_of_nodes(), 5)
+    h0 = F.randn((g.number_of_nodes(), 5))
     h1 = sum_pool(h0, g)
-    check_close(h1, mx.nd.sum(h0, 0))
+    check_close(h1, F.sum(h0, 0))
     h1 = avg_pool(h0, g)
-    check_close(h1, mx.nd.mean(h0, 0))
+    check_close(h1, F.mean(h0, 0))
     h1 = max_pool(h0, g)
-    check_close(h1, mx.nd.max(h0, 0))
+    check_close(h1, F.max(h0, 0))
     h1 = sort_pool(h0, g)
     assert h1.shape[0] == 10 * 5 and h1.ndim == 1

     # test#2: batched graph
     g_ = dgl.DGLGraph(nx.path_graph(5))
     bg = dgl.batch([g, g_, g, g_, g])
-    h0 = mx.nd.random.randn(bg.number_of_nodes(), 5)
+    h0 = F.randn((bg.number_of_nodes(), 5))
     h1 = sum_pool(h0, bg)
-    truth = mx.nd.stack(mx.nd.sum(h0[:15], 0),
-                        mx.nd.sum(h0[15:20], 0),
-                        mx.nd.sum(h0[20:35], 0),
-                        mx.nd.sum(h0[35:40], 0),
-                        mx.nd.sum(h0[40:55], 0), axis=0)
+    truth = mx.nd.stack(F.sum(h0[:15], 0),
+                        F.sum(h0[15:20], 0),
+                        F.sum(h0[20:35], 0),
+                        F.sum(h0[35:40], 0),
+                        F.sum(h0[40:55], 0), axis=0)
     check_close(h1, truth)
     h1 = avg_pool(h0, bg)
-    truth = mx.nd.stack(mx.nd.mean(h0[:15], 0),
-                        mx.nd.mean(h0[15:20], 0),
-                        mx.nd.mean(h0[20:35], 0),
-                        mx.nd.mean(h0[35:40], 0),
-                        mx.nd.mean(h0[40:55], 0), axis=0)
+    truth = mx.nd.stack(F.mean(h0[:15], 0),
+                        F.mean(h0[15:20], 0),
+                        F.mean(h0[20:35], 0),
+                        F.mean(h0[35:40], 0),
+                        F.mean(h0[40:55], 0), axis=0)
     check_close(h1, truth)
     h1 = max_pool(h0, bg)
-    truth = mx.nd.stack(mx.nd.max(h0[:15], 0),
-                        mx.nd.max(h0[15:20], 0),
-                        mx.nd.max(h0[20:35], 0),
-                        mx.nd.max(h0[35:40], 0),
-                        mx.nd.max(h0[40:55], 0), axis=0)
+    truth = mx.nd.stack(F.max(h0[:15], 0),
+                        F.max(h0[15:20], 0),
+                        F.max(h0[20:35], 0),
+                        F.max(h0[35:40], 0),
+                        F.max(h0[40:55], 0), axis=0)
     check_close(h1, truth)
     h1 = sort_pool(h0, bg)
     assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2

 def uniform_attention(g, shape):
     a = mx.nd.ones(shape)
     target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
@@ -163,7 +167,7 @@ def uniform_attention(g, shape):
 def test_edge_softmax():
     # Basic
     g = dgl.DGLGraph(nx.path_graph(3))
-    edata = mx.nd.ones((g.number_of_edges(), 1))
+    edata = F.ones((g.number_of_edges(), 1))
     a = nn.edge_softmax(g, edata)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
@@ -171,7 +175,7 @@ def test_edge_softmax():
                 1e-4, 1e-4)

     # Test higher dimension case
-    edata = mx.nd.ones((g.number_of_edges(), 3, 1))
+    edata = F.ones((g.number_of_edges(), 3, 1))
     a = nn.edge_softmax(g, edata)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
...
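Note the shape convention the shim imposes: mx.nd.random.randn(n, 5) takes positional dimensions, while the tests now call F.randn((n, 5)) with a single shape tuple so the same line also works for PyTorch. A plausible MXNet-side wrapper, given as an assumption about tests/backend rather than its verified contents:

# Hypothetical shape-tuple randn helper for the MXNet test backend.
import mxnet as mx

_default_context = mx.cpu(0)     # or mx.gpu(0) for GPU test runs

def randn(shape):
    # mx.nd.random.randn expects *shape; the shim accepts one tuple.
    return mx.nd.random.randn(*shape, ctx=_default_context)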
@@ -2,6 +2,7 @@ import torch as th
 import networkx as nx
 import dgl
 import dgl.nn.pytorch as nn
+import backend as F
 from copy import deepcopy

 import numpy as np
@@ -14,43 +15,50 @@ def _AXWb(A, X, W, b):
 def test_graph_conv():
     g = dgl.DGLGraph(nx.path_graph(3))
-    adj = g.adjacency_matrix()
+    ctx = F.ctx()
+    adj = g.adjacency_matrix(ctx=ctx)

     conv = nn.GraphConv(5, 2, norm=False, bias=True)
+    if F.gpu_ctx():
+        conv.cuda()
     print(conv)
     # test#1: basic
-    h0 = th.ones((3, 5))
+    h0 = F.ones((3, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
-    assert th.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
+    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
     # test#2: more-dim
-    h0 = th.ones((3, 5, 5))
+    h0 = F.ones((3, 5, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
-    assert th.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))
+    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

     conv = nn.GraphConv(5, 2)
+    if F.gpu_ctx():
+        conv.cuda()
     # test#3: basic
-    h0 = th.ones((3, 5))
+    h0 = F.ones((3, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     # test#4: basic
-    h0 = th.ones((3, 5, 5))
+    h0 = F.ones((3, 5, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0

     conv = nn.GraphConv(5, 2)
+    if F.gpu_ctx():
+        conv.cuda()
     # test#3: basic
-    h0 = th.ones((3, 5))
+    h0 = F.ones((3, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     # test#4: basic
-    h0 = th.ones((3, 5, 5))
+    h0 = F.ones((3, 5, 5))
     h1 = conv(h0, g)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
@@ -59,16 +67,18 @@ def test_graph_conv():
     old_weight = deepcopy(conv.weight.data)
     conv.reset_parameters()
     new_weight = conv.weight.data
-    assert not th.allclose(old_weight, new_weight)
+    assert not F.allclose(old_weight, new_weight)

 def test_set2set():
     g = dgl.DGLGraph(nx.path_graph(10))

     s2s = nn.Set2Set(5, 3, 3) # hidden size 5, 3 iters, 3 layers
+    if F.gpu_ctx():
+        s2s.cuda()
     print(s2s)

     # test#1: basic
-    h0 = th.rand(g.number_of_nodes(), 5)
+    h0 = F.randn((g.number_of_nodes(), 5))
     h1 = s2s(h0, g)
     assert h1.shape[0] == 10 and h1.dim() == 1
@@ -76,7 +86,7 @@ def test_set2set():
     g1 = dgl.DGLGraph(nx.path_graph(11))
     g2 = dgl.DGLGraph(nx.path_graph(5))
     bg = dgl.batch([g, g1, g2])
-    h0 = th.rand(bg.number_of_nodes(), 5)
+    h0 = F.randn((bg.number_of_nodes(), 5))
     h1 = s2s(h0, bg)
     assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2
@@ -84,16 +94,18 @@ def test_glob_att_pool():
     g = dgl.DGLGraph(nx.path_graph(10))

     gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))
+    if F.gpu_ctx():
+        gap.cuda()
     print(gap)
     # test#1: basic
-    h0 = th.rand(g.number_of_nodes(), 5)
+    h0 = F.randn((g.number_of_nodes(), 5))
     h1 = gap(h0, g)
     assert h1.shape[0] == 10 and h1.dim() == 1

     # test#2: batched graph
     bg = dgl.batch([g, g, g, g])
-    h0 = th.rand(bg.number_of_nodes(), 5)
+    h0 = F.randn((bg.number_of_nodes(), 5))
     h1 = gap(h0, bg)
     assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2
@@ -107,44 +119,44 @@ def test_simple_pool():
     print(sum_pool, avg_pool, max_pool, sort_pool)

     # test#1: basic
-    h0 = th.rand(g.number_of_nodes(), 5)
+    h0 = F.randn((g.number_of_nodes(), 5))
     h1 = sum_pool(h0, g)
-    assert th.allclose(h1, th.sum(h0, 0))
+    assert F.allclose(h1, F.sum(h0, 0))
     h1 = avg_pool(h0, g)
-    assert th.allclose(h1, th.mean(h0, 0))
+    assert F.allclose(h1, F.mean(h0, 0))
     h1 = max_pool(h0, g)
-    assert th.allclose(h1, th.max(h0, 0)[0])
+    assert F.allclose(h1, F.max(h0, 0))
     h1 = sort_pool(h0, g)
     assert h1.shape[0] == 10 * 5 and h1.dim() == 1

     # test#2: batched graph
     g_ = dgl.DGLGraph(nx.path_graph(5))
     bg = dgl.batch([g, g_, g, g_, g])
-    h0 = th.rand(bg.number_of_nodes(), 5)
+    h0 = F.randn((bg.number_of_nodes(), 5))
     h1 = sum_pool(h0, bg)
-    truth = th.stack([th.sum(h0[:15], 0),
-                      th.sum(h0[15:20], 0),
-                      th.sum(h0[20:35], 0),
-                      th.sum(h0[35:40], 0),
-                      th.sum(h0[40:55], 0)], 0)
-    assert th.allclose(h1, truth)
+    truth = th.stack([F.sum(h0[:15], 0),
+                      F.sum(h0[15:20], 0),
+                      F.sum(h0[20:35], 0),
+                      F.sum(h0[35:40], 0),
+                      F.sum(h0[40:55], 0)], 0)
+    assert F.allclose(h1, truth)
     h1 = avg_pool(h0, bg)
-    truth = th.stack([th.mean(h0[:15], 0),
-                      th.mean(h0[15:20], 0),
-                      th.mean(h0[20:35], 0),
-                      th.mean(h0[35:40], 0),
-                      th.mean(h0[40:55], 0)], 0)
-    assert th.allclose(h1, truth)
+    truth = th.stack([F.mean(h0[:15], 0),
+                      F.mean(h0[15:20], 0),
+                      F.mean(h0[20:35], 0),
+                      F.mean(h0[35:40], 0),
+                      F.mean(h0[40:55], 0)], 0)
+    assert F.allclose(h1, truth)
     h1 = max_pool(h0, bg)
-    truth = th.stack([th.max(h0[:15], 0)[0],
-                      th.max(h0[15:20], 0)[0],
-                      th.max(h0[20:35], 0)[0],
-                      th.max(h0[35:40], 0)[0],
-                      th.max(h0[40:55], 0)[0]], 0)
-    assert th.allclose(h1, truth)
+    truth = th.stack([F.max(h0[:15], 0),
+                      F.max(h0[15:20], 0),
+                      F.max(h0[20:35], 0),
+                      F.max(h0[35:40], 0),
+                      F.max(h0[40:55], 0)], 0)
+    assert F.allclose(h1, truth)
     h1 = sort_pool(h0, bg)
     assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
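One backend difference worth flagging here: th.max(h0, 0) returns a (values, indices) tuple while mx.nd.max(h0, 0) returns only the values, so the assertions above now go through F.max(h0, 0) and let the shim hide that difference. A possible PyTorch-side wrapper (an assumption, not the verified tests/backend code):

# Possible max helper that lets both backends share one assertion.
import torch as th

def max(x, dim):                 # intentionally shadows the builtin, as a backend module would
    # Return values only, mirroring mx.nd.max semantics.
    return th.max(x, dim)[0]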
@@ -155,10 +167,14 @@ def test_set_trans():
     st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'sab')
     st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, 'isab', 3)
     st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
+    if F.gpu_ctx():
+        st_enc_0.cuda()
+        st_enc_1.cuda()
+        st_dec.cuda()
     print(st_enc_0, st_enc_1, st_dec)

     # test#1: basic
-    h0 = th.rand(g.number_of_nodes(), 50)
+    h0 = F.randn((g.number_of_nodes(), 50))
     h1 = st_enc_0(h0, g)
     assert h1.shape == h0.shape
     h1 = st_enc_1(h0, g)
@@ -170,7 +186,7 @@ def test_set_trans():
     g1 = dgl.DGLGraph(nx.path_graph(5))
     g2 = dgl.DGLGraph(nx.path_graph(10))
     bg = dgl.batch([g, g1, g2])
-    h0 = th.rand(bg.number_of_nodes(), 50)
+    h0 = F.randn((bg.number_of_nodes(), 50))
     h1 = st_enc_0(h0, bg)
     assert h1.shape == h0.shape
     h1 = st_enc_1(h0, bg)
@@ -187,18 +203,18 @@ def uniform_attention(g, shape):
 def test_edge_softmax():
     # Basic
     g = dgl.DGLGraph(nx.path_graph(3))
-    edata = th.ones(g.number_of_edges(), 1)
+    edata = F.ones((g.number_of_edges(), 1))
     a = nn.edge_softmax(g, edata)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
-    assert th.allclose(a, uniform_attention(g, a.shape))
+    assert F.allclose(a, uniform_attention(g, a.shape))

     # Test higher dimension case
-    edata = th.ones(g.number_of_edges(), 3, 1)
+    edata = F.ones((g.number_of_edges(), 3, 1))
     a = nn.edge_softmax(g, edata)
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
-    assert th.allclose(a, uniform_attention(g, a.shape))
+    assert F.allclose(a, uniform_attention(g, a.shape))

     # Test both forward and backward with PyTorch built-in softmax.
     g = dgl.DGLGraph()
@@ -208,10 +224,10 @@ def test_edge_softmax():
         for j in range(30):
             g.add_edge(i, j)

-    score = th.rand(900, 1)
+    score = F.randn((900, 1))
     score.requires_grad_()
-    grad = th.rand(900, 1)
-    y = th.softmax(score.view(30, 30), dim=0).view(-1, 1)
+    grad = F.randn((900, 1))
+    y = F.softmax(score.view(30, 30), dim=0).view(-1, 1)
     y.backward(grad)
     grad_score = score.grad
     score.grad.zero_()
@@ -219,10 +235,10 @@ def test_edge_softmax():
     assert len(g.ndata) == 0
     assert len(g.edata) == 0
     # check forward
-    assert th.allclose(y_dgl, y)
+    assert F.allclose(y_dgl, y)
     y_dgl.backward(grad)
     # checkout gradient
-    assert th.allclose(score.grad, grad_score)
+    assert F.allclose(score.grad, grad_score)
     print(score.grad[:10], grad_score[:10])

     # Test 2
@@ -231,10 +247,10 @@ def test_edge_softmax():
         return dgl.DGLGraph(arr, readonly=True)

     g = generate_rand_graph(50)
-    a1 = th.randn(g.number_of_edges(), 1).requires_grad_()
+    a1 = F.randn((g.number_of_edges(), 1)).requires_grad_()
     a2 = a1.clone().detach().requires_grad_()
     g.edata['s'] = a1
-    g.group_apply_edges('dst', lambda edges: {'ss':th.softmax(edges.data['s'], 1)})
+    g.group_apply_edges('dst', lambda edges: {'ss':F.softmax(edges.data['s'], 1)})
     g.edata['ss'].sum().backward()

     builtin_sm = nn.edge_softmax(g, a2)
@@ -242,7 +258,7 @@ def test_edge_softmax():
     print(a1.grad - a2.grad)
     assert len(g.ndata) == 0
     assert len(g.edata) == 2
-    assert th.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
+    assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend

 if __name__ == '__main__':
...
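For reference, edge_softmax normalizes edge scores with a softmax grouped by destination node, which is why the test can check it both against th.softmax applied per destination column of the 900 scores viewed as a 30-by-30 matrix and against group_apply_edges('dst', ...). A tiny worked example on the 3-node path graph used above, relying only on calls that already appear in the diff:

# With identical scores on every edge, each edge weight becomes 1 / in-degree(dst).
import networkx as nx
import torch as th
import dgl
import dgl.nn.pytorch as nn

g = dgl.DGLGraph(nx.path_graph(3))           # 4 directed edges: 0->1, 1->0, 1->2, 2->1
edata = th.ones((g.number_of_edges(), 1))
a = nn.edge_softmax(g, edata)
# nodes 0 and 2 each receive one edge (weight 1.0); node 1 receives two (0.5 each)
print(a.view(-1))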