Unverified Commit 9e358dfe authored by dddg617's avatar dddg617 Committed by GitHub
Browse files

[NN] HeteroLinear and HeteroEmbedding (#3678)



* modify hetero

* modify rst document

* update hetero

* update hetero

* update hetero

* update hetero

* Update

* Update

* Update

* Update

* 20220216

* Update

* Update

* Fix
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
Co-authored-by: default avatarShelkerX <925089962@qq.com>
parent e9c3c0e8
...@@ -299,12 +299,24 @@ Heterogeneous Graph Convolution Module ...@@ -299,12 +299,24 @@ Heterogeneous Graph Convolution Module
---------------------------------------- ----------------------------------------
HeteroGraphConv HeteroGraphConv
~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.HeteroGraphConv .. autoclass:: dgl.nn.pytorch.HeteroGraphConv
:members: :members:
:show-inheritance: :show-inheritance:
HeteroLinear
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.HeteroLinear
:members:
:show-inheritance:
HeteroEmbedding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.pytorch.HeteroEmbedding
:members:
:show-inheritance:
.. _apinn-pytorch-util: .. _apinn-pytorch-util:
Utility Modules Utility Modules
......
...@@ -4,7 +4,7 @@ import torch as th ...@@ -4,7 +4,7 @@ import torch as th
import torch.nn as nn import torch.nn as nn
from ...base import DGLError from ...base import DGLError
__all__ = ['HeteroGraphConv'] __all__ = ['HeteroGraphConv', 'HeteroLinear', 'HeteroEmbedding']
class HeteroGraphConv(nn.Module): class HeteroGraphConv(nn.Module):
r"""A generic module for computing convolution on heterogeneous graphs. r"""A generic module for computing convolution on heterogeneous graphs.
...@@ -250,3 +250,129 @@ def get_aggregate_fn(agg): ...@@ -250,3 +250,129 @@ def get_aggregate_fn(agg):
return _stack_agg_func return _stack_agg_func
else: else:
return partial(_agg_func, fn=fn) return partial(_agg_func, fn=fn)
class HeteroLinear(nn.Module):
    """Apply linear transformations on heterogeneous inputs.

    One ``torch.nn.Linear`` is kept per input type; the string form of each
    key indexes the underlying ``nn.ModuleDict``.

    Parameters
    ----------
    in_size : dict[key, int]
        Input feature size for heterogeneous inputs. A key can be a string or a tuple of strings.
    out_size : int
        Output feature size, shared by all types.

    Examples
    --------
    >>> import dgl
    >>> import torch
    >>> from dgl.nn import HeteroLinear

    >>> layer = HeteroLinear({'user': 1, ('user', 'follows', 'user'): 2}, 3)
    >>> in_feats = {'user': torch.randn(2, 1), ('user', 'follows', 'user'): torch.randn(3, 2)}
    >>> out_feats = layer(in_feats)
    >>> print(out_feats['user'].shape)
    torch.Size([2, 3])
    >>> print(out_feats[('user', 'follows', 'user')].shape)
    torch.Size([3, 3])
    """
    def __init__(self, in_size, out_size):
        super(HeteroLinear, self).__init__()

        # ModuleDict keys must be strings, so tuple keys are stringified.
        self.linears = nn.ModuleDict({
            str(key): nn.Linear(key_in_size, out_size)
            for key, key_in_size in in_size.items()
        })

    def forward(self, feat):
        """Forward function

        Parameters
        ----------
        feat : dict[key, Tensor]
            Heterogeneous input features. It maps keys to features.

        Returns
        -------
        dict[key, Tensor]
            Transformed features.
        """
        return {key: self.linears[str(key)](key_feat)
                for key, key_feat in feat.items()}
class HeteroEmbedding(nn.Module):
    """Create a heterogeneous embedding table.

    It internally contains multiple ``torch.nn.Embedding`` with different
    dictionary sizes, one per input type.

    Parameters
    ----------
    num_embeddings : dict[key, int]
        Size of the dictionaries. A key can be a string or a tuple of strings.
    embedding_dim : int
        Size of each embedding vector.

    Examples
    --------
    >>> import dgl
    >>> import torch
    >>> from dgl.nn import HeteroEmbedding

    >>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)

    >>> # Get the heterogeneous embedding table
    >>> embeds = layer.weight
    >>> print(embeds['user'].shape)
    torch.Size([2, 4])
    >>> print(embeds[('user', 'follows', 'user')].shape)
    torch.Size([3, 4])

    >>> # Get the embeddings for a subset
    >>> input_ids = {'user': torch.LongTensor([0]),
    ...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}
    >>> embeds = layer(input_ids)
    >>> print(embeds['user'].shape)
    torch.Size([1, 4])
    >>> print(embeds[('user', 'follows', 'user')].shape)
    torch.Size([2, 4])
    """
    def __init__(self, num_embeddings, embedding_dim):
        super(HeteroEmbedding, self).__init__()

        self.embeds = nn.ModuleDict()
        # ModuleDict keys must be strings; remember the original (possibly
        # tuple) key for each stringified one so `weight` can restore it.
        self.raw_keys = dict()
        for key, key_num_rows in num_embeddings.items():
            name = str(key)
            self.embeds[name] = nn.Embedding(key_num_rows, embedding_dim)
            self.raw_keys[name] = key

    @property
    def weight(self):
        """Get the heterogeneous embedding table

        Returns
        -------
        dict[key, Tensor]
            Heterogeneous embedding table
        """
        return {self.raw_keys[name]: module.weight
                for name, module in self.embeds.items()}

    def forward(self, input_ids):
        """Forward function

        Parameters
        ----------
        input_ids : dict[key, Tensor]
            The row IDs to retrieve embeddings. It maps a key to key-specific IDs.

        Returns
        -------
        dict[key, Tensor]
            The retrieved embeddings.
        """
        return {key: self.embeds[str(key)](key_ids)
                for key, key_ids in input_ids.items()}
...@@ -788,7 +788,7 @@ def test_gin_conv(g, idtype, aggregator_type): ...@@ -788,7 +788,7 @@ def test_gin_conv(g, idtype, aggregator_type):
th.save(gin, tmp_buffer) th.save(gin, tmp_buffer)
assert h.shape == (g.number_of_dst_nodes(), 12) assert h.shape == (g.number_of_dst_nodes(), 12)
gin = nn.GINConv(None, aggregator_type) gin = nn.GINConv(None, aggregator_type)
th.save(gin, tmp_buffer) th.save(gin, tmp_buffer)
gin = gin.to(ctx) gin = gin.to(ctx)
...@@ -1246,6 +1246,35 @@ def test_hetero_conv(agg, idtype): ...@@ -1246,6 +1246,35 @@ def test_hetero_conv(agg, idtype):
{'user': uf, 'game': gf, 'store': sf[0:0]})) {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'} assert set(h.keys()) == {'user', 'game'}
@pytest.mark.parametrize('out_dim', [1, 2, 100])
def test_hetero_linear(out_dim):
    """HeteroLinear should map each typed input to the shared output size."""
    sizes = {'user': 1, ('user', 'follows', 'user'): 2}
    layer = nn.HeteroLinear(sizes, out_dim)
    layer = layer.to(F.ctx())

    inputs = {
        'user': F.randn((2, 1)),
        ('user', 'follows', 'user'): F.randn((3, 2))
    }
    outputs = layer(inputs)
    assert outputs['user'].shape == (2, out_dim)
    assert outputs[('user', 'follows', 'user')].shape == (3, out_dim)
@pytest.mark.parametrize('out_dim', [1, 2, 100])
def test_hetero_embedding(out_dim):
    """HeteroEmbedding should expose full tables and support ID lookups."""
    layer = nn.HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, out_dim)
    layer = layer.to(F.ctx())

    # Full per-type tables via the `weight` property.
    table = layer.weight
    assert table['user'].shape == (2, out_dim)
    assert table[('user', 'follows', 'user')].shape == (3, out_dim)

    # Row lookups for a subset of IDs.
    ids = {
        'user': F.tensor([0], dtype=F.int64),
        ('user', 'follows', 'user'): F.tensor([0, 2], dtype=F.int64)
    }
    lookup = layer(ids)
    assert lookup['user'].shape == (1, out_dim)
    assert lookup[('user', 'follows', 'user')].shape == (2, out_dim)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2]) @pytest.mark.parametrize('out_dim', [1, 2])
...@@ -1348,13 +1377,13 @@ def test_ke_score_funcs(): ...@@ -1348,13 +1377,13 @@ def test_ke_score_funcs():
score_func(h_src, h_dst, rels).shape == (num_edges) score_func(h_src, h_dst, rels).shape == (num_edges)
def test_twirls(): def test_twirls():
g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
feat = th.ones(6, 10) feat = th.ones(6, 10)
conv = nn.TWIRLSConv(10, 2, 128, prop_step = 64) conv = nn.TWIRLSConv(10, 2, 128, prop_step = 64)
res = conv(g , feat) res = conv(g , feat)
assert ( res.size() == (6,2) ) assert ( res.size() == (6,2) )
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment