"test/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "1c3eedc49f02fcd457ee4fac2de97d284bc6bae4"
Unverified commit 230b886e, authored by Mufei Li, committed by GitHub

[Bug fix] Misc Fix for Transforms and NN Modules (#4038)

* Update module.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update module.py

* Update

* Update

* Update
parent 744896e2
@@ -403,8 +403,8 @@ class LabelPropagation(nn.Module):
    Propagation <http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf>`__

    .. math::

        \mathbf{Y}^{(t+1)} = \alpha \tilde{A} \mathbf{Y}^{(t)} + (1 - \alpha) \mathbf{Y}^{(0)}

    where unlabeled data is initially set to zero and inferred from labeled data via
    propagation. :math:`\alpha` is a weight parameter for balancing between updated labels
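For reference, a minimal dense-matrix sketch of this update rule, assuming a
precomputed normalized adjacency `A_tilde` (illustrative only; the DGL module
implements it with sparse message passing):

import torch

def label_propagation_step(A_tilde, Y_t, Y_0, alpha):
    # One step of Y^(t+1) = alpha * A_tilde @ Y^(t) + (1 - alpha) * Y^(0).
    # Rows of Y_0 for unlabeled nodes start at zero and are filled in by
    # repeated propagation from the labeled rows.
    return alpha * torch.matmul(A_tilde, Y_t) + (1 - alpha) * Y_0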
...
@@ -121,11 +121,11 @@ class RowFeatNormalizer(BaseTransform):
    Subtraction will make all values non-negative. If all values are negative, after
    normalisation, the sum of each row of the feature tensor will be 1.
    node_feat_names : list[str], optional
        The names of the node feature tensors to be row-normalized. Default: `None`,
        which will not normalize any node feature tensor.
    edge_feat_names : list[str], optional
        The names of the edge feature tensors to be row-normalized. Default: `None`,
        which will not normalize any edge feature tensor.

    Example
    -------
@@ -138,51 +138,40 @@ class RowFeatNormalizer(BaseTransform):
    Case1: Row normalize features of a homogeneous graph.

    >>> transform = RowFeatNormalizer(subtract_min=True,
    ...                               node_feat_names=['h'], edge_feat_names=['w'])
    >>> g = dgl.rand_graph(5, 20)
    >>> g.ndata['h'] = torch.randn((g.num_nodes(), 5))
    >>> g.edata['w'] = torch.randn((g.num_edges(), 5))
    >>> g = transform(g)
    >>> print(g.ndata['h'].sum(1))
    tensor([1., 1., 1., 1., 1.])
    >>> print(g.edata['w'].sum(1))
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1.])
    Case2: Row normalize features of a heterogeneous graph.

    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),
    ...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))
    ... })
    >>> g.ndata['h'] = {'game': torch.randn(2, 5), 'player': torch.randn(3, 5)}
    >>> g.edata['w'] = {
    ...     ('user', 'follows', 'user'): torch.randn(2, 5),
    ...     ('player', 'plays', 'game'): torch.randn(2, 5)
    ... }
    >>> g = transform(g)
    >>> print(g.ndata['h']['game'].sum(1), g.ndata['h']['player'].sum(1))
    tensor([1., 1.]) tensor([1., 1., 1.])
    >>> print(g.edata['w'][('user', 'follows', 'user')].sum(1),
    ...       g.edata['w'][('player', 'plays', 'game')].sum(1))
    tensor([1., 1.]) tensor([1., 1.])
    """
    def __init__(self, subtract_min=False, node_feat_names=None, edge_feat_names=None):
        self.node_feat_names = [] if node_feat_names is None else node_feat_names
        self.edge_feat_names = [] if edge_feat_names is None else edge_feat_names
        self.subtract_min = subtract_min

    def row_normalize(self, feat):
@@ -208,12 +197,6 @@ class RowFeatNormalizer(BaseTransform):
        return feat
    def __call__(self, g):
        for node_feat_name in self.node_feat_names:
            if isinstance(g.ndata[node_feat_name], torch.Tensor):
                g.ndata[node_feat_name] = self.row_normalize(g.ndata[node_feat_name])
@@ -233,21 +216,19 @@ class RowFeatNormalizer(BaseTransform):
        return g
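For intuition, a standalone sketch of the per-tensor row normalization this
class performs; `row_normalize` here is an illustrative re-implementation, not
the module's internals, and the `eps` guard is an added assumption:

import torch

def row_normalize(feat, subtract_min=False, eps=1e-12):
    if subtract_min:
        # Shift by the minimum so every entry is non-negative.
        feat = feat - feat.min()
    # Divide each row by its sum so every row sums to 1.
    return feat / feat.sum(dim=1, keepdim=True).clamp(min=eps)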
class FeatMask(BaseTransform):
    r"""Randomly mask columns of the node and edge feature tensors, as described in `Graph
    Contrastive Learning with Augmentations <https://arxiv.org/abs/2010.13902>`__.

    Parameters
    ----------
    p : float, optional
        Probability of masking a column of a feature tensor. Default: `0.5`.
    node_feat_names : list[str], optional
        The names of the node feature tensors to be masked. Default: `None`, which will
        not mask any node feature tensor.
    edge_feat_names : list[str], optional
        The names of the edge feature tensors to be masked. Default: `None`, which will
        not mask any edge feature tensor.

    Example
    -------
@@ -260,7 +241,7 @@ class FeatMask(BaseTransform):
    Case1: Mask node and edge feature tensors of a homogeneous graph.

    >>> transform = FeatMask(node_feat_names=['h'], edge_feat_names=['w'])
    >>> g = dgl.rand_graph(5, 10)
    >>> g.ndata['h'] = torch.ones((g.num_nodes(), 10))
    >>> g.edata['w'] = torch.ones((g.num_edges(), 10))
@@ -268,53 +249,48 @@ class FeatMask(BaseTransform):
    >>> g = transform(g)
    >>> print(g.ndata['h'])
    tensor([[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.]])
    >>> print(g.edata['w'])
    tensor([[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.]])
    Case2: Mask node and edge feature tensors of a heterogeneous graph.

    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),
    ...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))
    ... })
    >>> g.ndata['h'] = {'game': torch.ones(2, 5), 'player': torch.ones(3, 5)}
    >>> g.edata['w'] = {('user', 'follows', 'user'): torch.ones(2, 5)}
    >>> print(g.ndata['h']['game'])
    tensor([[1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]])
    >>> print(g.edata['w'][('user', 'follows', 'user')])
    tensor([[1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]])
    >>> g = transform(g)
    >>> print(g.ndata['h']['game'])
    tensor([[1., 1., 0., 1., 0.],
            [1., 1., 0., 1., 0.]])
    >>> print(g.edata['w'][('user', 'follows', 'user')])
    tensor([[0., 1., 0., 1., 0.],
            [0., 1., 0., 1., 0.]])
    """
    def __init__(self, p=0.5, node_feat_names=None, edge_feat_names=None):
        self.p = p
        self.node_feat_names = [] if node_feat_names is None else node_feat_names
        self.edge_feat_names = [] if edge_feat_names is None else edge_feat_names
        self.dist = Bernoulli(p)

    def __call__(self, g):
@@ -322,12 +298,6 @@ class FeatMask(BaseTransform):
        if self.p == 0:
            return g
        for node_feat_name in self.node_feat_names:
            if isinstance(g.ndata[node_feat_name], torch.Tensor):
                feat_mask = self.dist.sample(torch.Size([g.ndata[node_feat_name].shape[-1], ]))
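The sampling line above draws one Bernoulli variable per feature column. A
self-contained sketch of the complete masking step (the helper name is
illustrative, not part of the module):

import torch
from torch.distributions import Bernoulli

def mask_columns(feat, p=0.5):
    # A draw of 1 (probability p) means "zero out this column".
    feat_mask = Bernoulli(p).sample(torch.Size([feat.shape[-1]]))
    return feat * (1 - feat_mask)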
@@ -1592,6 +1562,7 @@ class SIGNDiffusion(BaseTransform):
        for i in range(1, self.k + 1):
            g.ndata[self.out_feat_name + '_' + str(i)] = feat_list[i - 1]
        return g

    def raw(self, g):
        use_eweight = False
...
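The added `return g` matters because callers are expected to use the graph a
transform hands back; the updated tests below rebind accordingly. A minimal
sketch of that calling pattern:

# Rebinding works whether the transform mutates `g` in place or returns a
# modified copy; before this fix, the code path missing `return g` handed
# back None.
g = transform(g)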
@@ -2370,54 +2370,56 @@ def test_module_sign(g):
    # raw
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='raw')
    g = transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(adj, g.ndata['h']))

    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='raw')
    g = transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(weight_adj, g.ndata['h']))

    # rw
    adj_rw = torch.matmul(torch.diag(1 / adj.sum(dim=1)), adj)
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='rw')
    g = transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(adj_rw, g.ndata['h']))

    weight_adj_rw = torch.matmul(torch.diag(1 / weight_adj.sum(dim=1)), weight_adj)
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='rw')
    g = transform(g)
    assert torch.allclose(g.ndata['out_feat_1'], torch.matmul(weight_adj_rw, g.ndata['h']))

    # gcn
    raw_eweight = g.edata['scalar_w']
    gcn_norm = dgl.GCNNorm()
    g = gcn_norm(g)
    adj_gcn = adj.clone()
    adj_gcn[dst, src] = g.edata.pop('w')
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='gcn')
    g = transform(g)
    target = torch.matmul(adj_gcn, g.ndata['h'])
    assert torch.allclose(g.ndata['out_feat_1'], target)

    gcn_norm = dgl.GCNNorm('scalar_w')
    g = gcn_norm(g)
    weight_adj_gcn = weight_adj.clone()
    weight_adj_gcn[dst, src] = g.edata['scalar_w']
    g.edata['scalar_w'] = raw_eweight
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h',
                                  eweight_name='scalar_w', diffuse_op='gcn')
    g = transform(g)
    target = torch.matmul(weight_adj_gcn, g.ndata['h'])
    assert torch.allclose(g.ndata['out_feat_1'], target)

    # ppr
    alpha = 0.2
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='ppr', alpha=alpha)
    g = transform(g)
    target = (1 - alpha) * torch.matmul(adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
    assert torch.allclose(g.ndata['out_feat_1'], target)

    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w',
                                  diffuse_op='ppr', alpha=alpha)
    g = transform(g)
    target = (1 - alpha) * torch.matmul(weight_adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
    assert torch.allclose(g.ndata['out_feat_1'], target)
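For reference, a dense sketch of the PPR-style diffusion the last block checks,
mirroring the test's `target` computation (the function name is illustrative):

import torch

def ppr_diffuse(adj_norm, feat, alpha, k=1):
    # k steps of h <- (1 - alpha) * adj_norm @ h + alpha * h0, where
    # adj_norm is a (possibly weighted) GCN-normalized adjacency matrix.
    h0 = h = feat
    for _ in range(k):
        h = (1 - alpha) * torch.matmul(adj_norm, h) + alpha * h0
    return h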
@@ -2425,7 +2427,8 @@ def test_module_sign(g):

@parametrize_idtype
def test_module_row_feat_normalizer(idtype):
    # Case1: Normalize features of a homogeneous graph.
    transform = dgl.RowFeatNormalizer(subtract_min=True,
                                      node_feat_names=['h'], edge_feat_names=['w'])
    g = dgl.rand_graph(5, 5, idtype=idtype, device=F.ctx())
    g.ndata['h'] = F.randn((g.num_nodes(), 128))
    g.edata['w'] = F.randn((g.num_edges(), 128))
@@ -2436,7 +2439,8 @@ def test_module_row_feat_normalizer(idtype):
    assert F.allclose(g.edata['w'].sum(1), F.tensor([1.0, 1.0, 1.0, 1.0, 1.0]))

    # Case2: Normalize features of a heterogeneous graph.
    transform = dgl.RowFeatNormalizer(subtract_min=True,
                                      node_feat_names=['h', 'h2'], edge_feat_names=['w'])
    g = dgl.heterograph({
        ('user', 'follows', 'user'): (F.tensor([1, 2]), F.tensor([3, 4])),
        ('player', 'plays', 'game'): (F.tensor([2, 2]), F.tensor([1, 1]))
@@ -2460,7 +2464,7 @@ def test_module_row_feat_normalizer(idtype):

@parametrize_idtype
def test_module_feat_mask(idtype):
    # Case1: Mask node and edge feature tensors of a homogeneous graph.
    transform = dgl.FeatMask(node_feat_names=['h'], edge_feat_names=['w'])
    g = dgl.rand_graph(5, 20, idtype=idtype, device=F.ctx())
    g.ndata['h'] = F.ones((g.num_nodes(), 10))
    g.edata['w'] = F.ones((g.num_edges(), 20))
@@ -2471,7 +2475,6 @@ def test_module_feat_mask(idtype):
    assert g.edata['w'].shape == (g.num_edges(), 20)

    # Case2: Mask node and edge feature tensors of a heterogeneous graph.
    g = dgl.heterograph({
        ('user', 'follows', 'user'): (F.tensor([1, 2]), F.tensor([3, 4])),
        ('player', 'plays', 'game'): (F.tensor([2, 2]), F.tensor([1, 1]))
...