Unverified Commit 65b0b9e8 authored by Mufei Li, committed by GitHub

[Dataset & Transform] Synthetic Datasets for Explainability and SIGNDiffusion Transform (#3982)

parent 03024f95
@@ -47,6 +47,11 @@ Datasets for node classification/regression tasks

    FraudDataset
    FraudYelpDataset
    FraudAmazonDataset
    BAShapeDataset
    BACommunityDataset
    TreeCycleDataset
    TreeGridDataset
    BA2MotifDataset

Edge Prediction Datasets
---------------------------------------
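For orientation (this sketch is not part of the diff itself): the new synthetic datasets are loaded like any other DGL node-classification dataset, and the constructor arguments below mirror test_explain_syn further down in this commit.

    from dgl.data import BAShapeDataset

    # The seed argument fixes the random graph construction, so two datasets
    # built with the same seed have identical edges (see test_explain_syn below).
    dataset = BAShapeDataset(seed=0)
    g = dataset[0]                    # a single synthetic graph
    assert dataset.num_classes == 4   # 4 classes, as in the BA-Shapes setup from GNNExplainer
    labels = g.ndata['label']
    feats = g.ndata['feat']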
@@ -32,3 +32,4 @@ dgl.transforms

    LaplacianPE
    FeatMask
    RowFeatNormalizer
    SIGNDiffusion
@@ -31,6 +31,7 @@ from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset
from .fakenews import FakeNewsDataset
from .csv_dataset import CSVDataset
from .adapter import AsNodePredDataset, AsLinkPredDataset
from .synthetic import BAShapeDataset, BACommunityDataset, TreeCycleDataset, TreeGridDataset, BA2MotifDataset
def register_data_args(parser):
    parser.add_argument(
@@ -17,7 +17,7 @@ class DGLDataset(object):

    1. Check whether there is a dataset cache on disk
       (already processed and stored on the disk) by
       invoking ``has_cache()``. If true, go to step 5.
    2. Call ``download()`` to download the data.
    2. Call ``download()`` to download the data if ``url`` is not None.
    3. Call ``process()`` to process the data.
    4. Call ``save()`` to save the processed dataset on disk and go to step 6.
    5. Call ``load()`` to load the processed dataset from disk.
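To make the pipeline above concrete, here is a minimal, hypothetical ``DGLDataset`` subclass; the class name and file layout are invented for illustration, while ``dgl.save_graphs``/``dgl.load_graphs`` and the ``save_path`` attribute are existing DGL APIs.

    import os
    import dgl
    from dgl import save_graphs, load_graphs
    from dgl.data import DGLDataset

    class ToyDataset(DGLDataset):
        """Hypothetical dataset illustrating the loading pipeline."""
        def __init__(self, raw_dir=None, force_reload=False):
            # url=None, so step 2 (download) becomes a no-op
            super().__init__(name='toy', url=None, raw_dir=raw_dir,
                             force_reload=force_reload)

        def has_cache(self):                                  # step 1
            return os.path.exists(os.path.join(self.save_path, 'graph.bin'))

        def process(self):                                    # step 3
            self._g = dgl.rand_graph(10, 30)

        def save(self):                                       # step 4
            os.makedirs(self.save_path, exist_ok=True)
            save_graphs(os.path.join(self.save_path, 'graph.bin'), [self._g])

        def load(self):                                       # step 5
            self._g = load_graphs(os.path.join(self.save_path, 'graph.bin'))[0][0]

        def __getitem__(self, idx):
            return self._g

        def __len__(self):
            return 1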
@@ -31,7 +31,7 @@ class DGLDataset(object):

    name : str
        Name of the dataset
    url : str
        Url to download the raw dataset
        Url to download the raw dataset. Default: None
    raw_dir : str
        Specifying the directory that will store the
        downloaded data or the directory that
@@ -313,6 +313,7 @@ class DGLBuiltinDataset(DGLDataset):
    def download(self):
        r""" Automatically download data and extract it.
        """
        if self.url is not None:
            zip_file_path = os.path.join(self.raw_dir, self.name + '.zip')
            download(self.url, path=zip_file_path)
            extract_archive(zip_file_path, self.raw_path)
@@ -14,13 +14,14 @@
# limitations under the License.
#
"""Modules for transform"""
# pylint: disable= no-member, arguments-differ, invalid-name
# pylint: disable= no-member, arguments-differ, invalid-name, missing-function-docstring
from scipy.linalg import expm
from .. import convert
from .. import backend as F
from .. import function as fn
from ..base import DGLError
from . import functional
try:
@@ -50,7 +51,8 @@ __all__ = [
    'NodeShuffle',
    'DropNode',
    'DropEdge',
    'AddEdge'
    'AddEdge',
    'SIGNDiffusion'
]
def update_graph_structure(g, data_dict, copy_edata=True):
@@ -1492,3 +1494,181 @@ class AddEdge(BaseTransform):
        dst = F.randint([num_edges_to_add], idtype, device, low=0, high=g.num_nodes(vtype))
        g.add_edges(src, dst, etype=c_etype)
        return g
class SIGNDiffusion(BaseTransform):
    r"""The diffusion operator from `SIGN: Scalable Inception Graph Neural Networks
    <https://arxiv.org/abs/2004.11198>`__

    It performs node feature diffusion with :math:`TX, \cdots, T^{k}X`, where :math:`T`
    is a diffusion matrix and :math:`X` is the input node features.

    Specifically, this module provides four options for :math:`T`.

    **raw**: raw adjacency matrix :math:`A`

    **rw**: random walk (row-normalized) adjacency matrix :math:`D^{-1}A`, where
    :math:`D` is the degree matrix

    **gcn**: symmetrically normalized adjacency matrix used by
    `GCN <https://arxiv.org/abs/1609.02907>`__, :math:`D^{-1/2}AD^{-1/2}`

    **ppr**: approximate personalized PageRank used by
    `APPNP <https://arxiv.org/abs/1810.05997>`__

    .. math::
        H^{0} &= X

        H^{l+1} &= (1-\alpha)\left(D^{-1/2}AD^{-1/2} H^{l}\right) + \alpha X

    This module only works for homogeneous graphs.

    Parameters
    ----------
    k : int
        The number of diffusion steps, i.e. the maximum power of :math:`T`
        applied to the input node features.
    in_feat_name : str, optional
        :attr:`g.ndata[{in_feat_name}]` should store the input node features. Default: 'feat'
    out_feat_name : str, optional
        :attr:`g.ndata[{out_feat_name}_i]` will store the result of diffusing
        the input node features :math:`i` times. Default: 'out_feat'
    eweight_name : str, optional
        Name to retrieve edge weights from :attr:`g.edata`. Default: None,
        treating the graph as unweighted.
    diffuse_op : str, optional
        The diffusion operator to use, which can be 'raw', 'rw', 'gcn', or 'ppr'.
        Default: 'raw'
    alpha : float, optional
        Restart probability if :attr:`diffuse_op` is :attr:`'ppr'`,
        which commonly lies in :math:`[0.05, 0.2]`. Default: 0.2

    Example
    -------

    >>> import dgl
    >>> import torch
    >>> from dgl import SIGNDiffusion

    >>> transform = SIGNDiffusion(k=2, eweight_name='w')
    >>> num_nodes = 5
    >>> num_edges = 20
    >>> g = dgl.rand_graph(num_nodes, num_edges)
    >>> g.ndata['feat'] = torch.randn(num_nodes, 10)
    >>> g.edata['w'] = torch.randn(num_edges)
    >>> transform(g)
    Graph(num_nodes=5, num_edges=20,
          ndata_schemes={'feat': Scheme(shape=(10,), dtype=torch.float32),
                         'out_feat_1': Scheme(shape=(10,), dtype=torch.float32),
                         'out_feat_2': Scheme(shape=(10,), dtype=torch.float32)}
          edata_schemes={'w': Scheme(shape=(), dtype=torch.float32)})
    """
    def __init__(self,
                 k,
                 in_feat_name='feat',
                 out_feat_name='out_feat',
                 eweight_name=None,
                 diffuse_op='raw',
                 alpha=0.2):
        self.k = k
        self.in_feat_name = in_feat_name
        self.out_feat_name = out_feat_name
        self.eweight_name = eweight_name
        self.diffuse_op = diffuse_op
        self.alpha = alpha

        if diffuse_op == 'raw':
            self.diffuse = self.raw
        elif diffuse_op == 'rw':
            self.diffuse = self.rw
        elif diffuse_op == 'gcn':
            self.diffuse = self.gcn
        elif diffuse_op == 'ppr':
            self.diffuse = self.ppr
        else:
            raise DGLError(
                "Expect diffuse_op to be from ['raw', 'rw', 'gcn', 'ppr'], "
                "got {}".format(diffuse_op))

    def __call__(self, g):
        feat_list = self.diffuse(g)
        for i in range(1, self.k + 1):
            g.ndata[self.out_feat_name + '_' + str(i)] = feat_list[i - 1]
        return g

    def raw(self, g):
        use_eweight = False
        if (self.eweight_name is not None) and self.eweight_name in g.edata:
            use_eweight = True

        feat_list = []
        with g.local_scope():
            if use_eweight:
                message_func = fn.u_mul_e(self.in_feat_name, self.eweight_name, 'm')
            else:
                message_func = fn.copy_u(self.in_feat_name, 'm')
            for _ in range(self.k):
                g.update_all(message_func, fn.sum('m', self.in_feat_name))
                feat_list.append(g.ndata[self.in_feat_name])
        return feat_list

    def rw(self, g):
        use_eweight = False
        if (self.eweight_name is not None) and self.eweight_name in g.edata:
            use_eweight = True

        feat_list = []
        with g.local_scope():
            g.ndata['h'] = g.ndata[self.in_feat_name]
            if use_eweight:
                message_func = fn.u_mul_e('h', self.eweight_name, 'm')
                reduce_func = fn.sum('m', 'h')
                # Compute the diagonal entries of D from the weighted A
                g.update_all(fn.copy_e(self.eweight_name, 'm'), fn.sum('m', 'z'))
            else:
                message_func = fn.copy_u('h', 'm')
                reduce_func = fn.mean('m', 'h')

            for _ in range(self.k):
                g.update_all(message_func, reduce_func)
                if use_eweight:
                    g.ndata['h'] = g.ndata['h'] / F.reshape(g.ndata['z'], (g.num_nodes(), 1))
                feat_list.append(g.ndata['h'])
        return feat_list

    def gcn(self, g):
        feat_list = []
        with g.local_scope():
            if self.eweight_name is None:
                eweight_name = 'w'
                if eweight_name in g.edata:
                    g.edata.pop(eweight_name)
            else:
                eweight_name = self.eweight_name
            transform = GCNNorm(eweight_name=eweight_name)
            transform(g)

            for _ in range(self.k):
                g.update_all(fn.u_mul_e(self.in_feat_name, eweight_name, 'm'),
                             fn.sum('m', self.in_feat_name))
                feat_list.append(g.ndata[self.in_feat_name])
        return feat_list

    def ppr(self, g):
        feat_list = []
        with g.local_scope():
            if self.eweight_name is None:
                eweight_name = 'w'
                if eweight_name in g.edata:
                    g.edata.pop(eweight_name)
            else:
                eweight_name = self.eweight_name
            transform = GCNNorm(eweight_name=eweight_name)
            transform(g)

            in_feat = g.ndata[self.in_feat_name]
            for _ in range(self.k):
                g.update_all(fn.u_mul_e(self.in_feat_name, eweight_name, 'm'),
                             fn.sum('m', self.in_feat_name))
                g.ndata[self.in_feat_name] = (1 - self.alpha) * g.ndata[self.in_feat_name] +\
                    self.alpha * in_feat
                feat_list.append(g.ndata[self.in_feat_name])
        return feat_list
@@ -203,6 +203,64 @@ def test_reddit():
    g2 = data.RedditDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_explain_syn():
    dataset = data.BAShapeDataset()
    assert dataset.num_classes == 4
    g = dataset[0]
    assert 'label' in g.ndata
    assert 'feat' in g.ndata
    g1 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BACommunityDataset()
    assert dataset.num_classes == 8
    g = dataset[0]
    assert 'label' in g.ndata
    assert 'feat' in g.ndata
    g1 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeCycleDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert 'label' in g.ndata
    assert 'feat' in g.ndata
    g1 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeGridDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert 'label' in g.ndata
    assert 'feat' in g.ndata
    g1 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BA2MotifDataset()
    assert dataset.num_classes == 2
    g, label = dataset[0]
    assert 'feat' in g.ndata
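Unlike the node-level datasets above, BA2MotifDataset is for graph classification: each item is a (graph, label) pair, as the last few assertions show. A hedged sketch of a typical consumption pattern; GraphDataLoader is DGL's standard graph-batching loader, and the batch size is an arbitrary choice for illustration.

    from dgl.data import BA2MotifDataset
    from dgl.dataloading import GraphDataLoader

    dataset = BA2MotifDataset()
    # GraphDataLoader batches (graph, label) pairs into (batched_graph, label_tensor)
    loader = GraphDataLoader(dataset, batch_size=32, shuffle=True)
    for batched_graph, labels in loader:
        feats = batched_graph.ndata['feat']
        # forward pass of a graph classifier would go here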
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_extract_archive():
@@ -24,6 +24,8 @@ import dgl.partition
import backend as F
import unittest
import math
import pytest
from test_utils.graph_cases import get_cases
from utils import parametrize_dtype
from test_heterograph import create_test_heterograph3, create_test_heterograph4, create_test_heterograph5
@@ -2350,6 +2352,75 @@ def test_module_laplacian_pe(idtype):
    else:
        assert F.allclose(new_g.ndata['lappe'].abs(), tgt)
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature']))
def test_module_sign(g):
    import torch

    ctx = F.ctx()
    g = g.to(ctx)
    adj = g.adj(transpose=True, scipy_fmt='coo').todense()
    adj = torch.tensor(adj).float().to(ctx)
    weight_adj = g.adj(transpose=True, scipy_fmt='coo').astype(float).todense()
    weight_adj = torch.tensor(weight_adj).float().to(ctx)
    src, dst = g.edges()
    src, dst = src.long(), dst.long()
    weight_adj[dst, src] = g.edata['scalar_w']

    # raw
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='raw')
    transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(adj, g.ndata['h']))

    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='raw')
    transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(weight_adj, g.ndata['h']))

    # rw
    adj_rw = torch.matmul(torch.diag(1 / adj.sum(dim=1)), adj)
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='rw')
    transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(adj_rw, g.ndata['h']))

    weight_adj_rw = torch.matmul(torch.diag(1 / weight_adj.sum(dim=1)), weight_adj)
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='rw')
    transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(weight_adj_rw, g.ndata['h']))

    # gcn
    raw_eweight = g.edata['scalar_w']
    gcn_norm = dgl.GCNNorm()
    gcn_norm(g)
    adj_gcn = adj.clone()
    adj_gcn[dst, src] = g.edata.pop('w')
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='gcn')
    transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(adj_gcn, g.ndata['h']))

    gcn_norm = dgl.GCNNorm('scalar_w')
    gcn_norm(g)
    weight_adj_gcn = weight_adj.clone()
    weight_adj_gcn[dst, src] = g.edata['scalar_w']
    g.edata['scalar_w'] = raw_eweight
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h',
                                  eweight_name='scalar_w', diffuse_op='gcn')
    transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(weight_adj_gcn, g.ndata['h']))

    # ppr
    alpha = 0.2
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='ppr', alpha=alpha)
    transform(g)
    target = (1 - alpha) * torch.matmul(adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
    assert torch.allclose(g.ndata['out_feat_1'], target)

    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w',
                                  diffuse_op='ppr', alpha=alpha)
    transform(g)
    target = (1 - alpha) * torch.matmul(weight_adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
    assert torch.allclose(g.ndata['out_feat_1'], target)
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@parametrize_dtype
def test_module_row_feat_normalizer(idtype):
@@ -2416,8 +2487,6 @@ def test_module_feat_mask(idtype):
    assert g.edata['w'][('user', 'follows', 'user')].shape == (2, 5)
    assert g.edata['w'][('player', 'plays', 'game')].shape == (2, 5)

if __name__ == '__main__':
    test_partition_with_halo()
    test_module_heat_kernel(F.int32)