Unverified commit f846d902, authored by LuckyLiuM, committed by GitHub

Add Metapath2vec module (#4660)



* metapath2vec package

* fix bugs --metapath2vec package

* add unittest and fix bugs

* fix pylint messages

* del init.py

* fix bugs

* modify metapath2vec and add deepwalk

* metapath2vec module

* Update

* Update

* rollback to initial metapath2vec

* Update

* Update

* Update

* Update
Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-9-26.ap-northeast-1.compute.internal>
parent 2b983869
@@ -123,3 +123,4 @@ Network Embedding Modules
    :template: classtemplate.rst

    ~dgl.nn.pytorch.DeepWalk
+   ~dgl.nn.pytorch.MetaPath2Vec
"""Network Embedding NN Modules""" """Network Embedding NN Modules"""
# pylint: disable= invalid-name # pylint: disable= invalid-name
import random import random
import torch import torch
from torch import nn from torch import nn
from torch.nn import init from torch.nn import init
import torch.nn.functional as F import torch.nn.functional as F
import tqdm
from ...base import NID
from ...convert import to_homogeneous, to_heterogeneous
from ...random import choice from ...random import choice
from ...sampling import random_walk from ...sampling import random_walk
__all__ = ['DeepWalk'] __all__ = ['DeepWalk', 'MetaPath2Vec']
class DeepWalk(nn.Module): class DeepWalk(nn.Module):
"""DeepWalk module from `DeepWalk: Online Learning of Social Representations """DeepWalk module from `DeepWalk: Online Learning of Social Representations
@@ -33,7 +37,7 @@ class DeepWalk(nn.Module):
     neg_weight : float, optional
         Weight of the loss term for negative samples in the total loss. Default: 1.0
     negative_size : int, optional
-        Number of negative samples to use for each positive sample in an iteration. Default: 1
+        Number of negative samples to use for each positive sample. Default: 1
     fast_neg : bool, optional
         If True, it samples negative node pairs within a batch of random walks. Default: True
     sparse : bool, optional
@@ -58,15 +62,18 @@ class DeepWalk(nn.Module):
     >>> dataset = CoraGraphDataset()
     >>> g = dataset[0]
     >>> model = DeepWalk(g)
+    >>> optimizer = SparseAdam(model.parameters(), lr=0.01)
     >>> dataloader = DataLoader(torch.arange(g.num_nodes()), batch_size=128,
     ...                         shuffle=True, collate_fn=model.sample)
-    >>> optimizer = SparseAdam(model.parameters(), lr=0.01)
     >>> num_epochs = 5
     >>> for epoch in range(num_epochs):
     ...     for batch_walk in dataloader:
     ...         loss = model(batch_walk)
+    ...         optimizer.zero_grad()
     ...         loss.backward()
     ...         optimizer.step()
     >>> train_mask = g.ndata['train_mask']
     >>> test_mask = g.ndata['test_mask']
     >>> X = model.node_embed.weight.detach()
@@ -100,14 +107,13 @@ class DeepWalk(nn.Module):
         # center node embedding
         self.node_embed = nn.Embedding(num_nodes, emb_dim, sparse=sparse)
-        # context embedding
         self.context_embed = nn.Embedding(num_nodes, emb_dim, sparse=sparse)
         self.reset_parameters()

         if not fast_neg:
-            node_degree = g.out_degrees().pow(0.75)
+            neg_prob = g.out_degrees().pow(0.75)
             # categorical distribution for true negative sampling
-            self.neg_prob = node_degree / node_degree.sum()
+            self.neg_prob = neg_prob / neg_prob.sum()

         # Get list index pairs for positive samples.
         # Given i, positive index pairs are (i - window_size, i), ... ,
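
The `pow(0.75)` above is word2vec's smoothed unigram trick: raising out-degrees to the 0.75 power flattens the distribution so that rare nodes are drawn as negatives more often. This hunk does not show how `self.neg_prob` is consumed; below is a minimal sketch of drawing negatives from such a table with a plain `torch.multinomial` call. The toy degrees and batch sizes are hypothetical, and the module itself may sample with dgl's own `choice` op instead.

import torch

degrees = torch.tensor([1., 4., 2., 9.])      # toy out-degrees
neg_prob = degrees.pow(0.75)
neg_prob = neg_prob / neg_prob.sum()          # normalized categorical weights
num_pairs, negative_size = 3, 5               # hypothetical batch sizes
# one row of negative node IDs per positive (center, context) pair
neg_nodes = torch.multinomial(neg_prob, num_pairs * negative_size,
                              replacement=True).reshape(num_pairs, negative_size)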
@@ -203,3 +209,199 @@ class DeepWalk(nn.Module):
         neg_score = torch.mean(-F.logsigmoid(-neg_score)) * self.negative_size * self.neg_weight
         return torch.mean(pos_score + neg_score)

class MetaPath2Vec(nn.Module):
    r"""metapath2vec module from `metapath2vec: Scalable Representation Learning for
    Heterogeneous Networks <https://dl.acm.org/doi/pdf/10.1145/3097983.3098036>`__

    To achieve efficient optimization, we leverage the negative sampling technique during
    training. Each node in a meta-path-guided random walk is treated in turn as the center
    node; positive context nodes are sampled within the context window, and negative nodes
    are drawn among nodes of all types from all meta-paths. The center-context node pairs
    and the context-negative node pairs are then used to update the network.

    Parameters
    ----------
    g : DGLGraph
        Graph for learning node embeddings. Two different canonical edge types
        :attr:`(utype, etype, vtype)` are not allowed to have the same :attr:`etype`.
    metapath : list[str]
        A sequence of edge types, each given as a string. It defines a new edge type by
        composing multiple edge types in order. Note that the start node type and the
        end one are commonly the same.
    window_size : int
        In a random walk :attr:`w`, a node :attr:`w[j]` is considered close to a node
        :attr:`w[i]` if :attr:`i - window_size <= j <= i + window_size`.
    emb_dim : int, optional
        Size of each embedding vector. Default: 128
    negative_size : int, optional
        Number of negative samples to use for each positive sample. Default: 5
    sparse : bool, optional
        If True, gradients with respect to the learnable weights will be sparse.
        Default: True

    Attributes
    ----------
    node_embed : nn.Embedding
        Embedding table of all nodes
    local_to_global_nid : dict[str, list]
        Mapping from type-specific node IDs to global node IDs

    Examples
    --------

    >>> import torch
    >>> import dgl
    >>> from torch.optim import SparseAdam
    >>> from torch.utils.data import DataLoader
    >>> from dgl.nn.pytorch import MetaPath2Vec

    >>> # Define a model
    >>> g = dgl.heterograph({
    ...     ('user', 'uc', 'company'): dgl.rand_graph(100, 1000).edges(),
    ...     ('company', 'cp', 'product'): dgl.rand_graph(100, 1000).edges(),
    ...     ('company', 'cu', 'user'): dgl.rand_graph(100, 1000).edges(),
    ...     ('product', 'pc', 'company'): dgl.rand_graph(100, 1000).edges()
    ... })
    >>> model = MetaPath2Vec(g, ['uc', 'cu'], window_size=1)

    >>> # Use the source node type of etype 'uc'
    >>> dataloader = DataLoader(torch.arange(g.num_nodes('user')), batch_size=128,
    ...                         shuffle=True, collate_fn=model.sample)
    >>> optimizer = SparseAdam(model.parameters(), lr=0.025)
    >>> for (pos_u, pos_v, neg_v) in dataloader:
    ...     loss = model(pos_u, pos_v, neg_v)
    ...     optimizer.zero_grad()
    ...     loss.backward()
    ...     optimizer.step()

    >>> # Get the embeddings of all user nodes
    >>> user_nids = torch.LongTensor(model.local_to_global_nid['user'])
    >>> user_emb = model.node_embed(user_nids)
    """
    def __init__(self,
                 g,
                 metapath,
                 window_size,
                 emb_dim=128,
                 negative_size=5,
                 sparse=True):
        super().__init__()

        assert len(metapath) + 1 >= window_size, \
            f'Expect len(metapath) >= window_size - 1, got {len(metapath)} and {window_size}'

        self.hg = g
        self.emb_dim = emb_dim
        self.metapath = metapath
        self.window_size = window_size
        self.negative_size = negative_size

        # Convert the edge-type metapath to a node-type metapath, starting
        # from the source node type of the first edge type.
        src_type, _, _ = g.to_canonical_etype(metapath[0])
        node_metapath = [src_type]
        for etype in metapath:
            _, _, dst_type = g.to_canonical_etype(etype)
            node_metapath.append(dst_type)
        self.node_metapath = node_metapath

        # Convert the graph into a homogeneous one for global-to-local node ID mapping
        g = to_homogeneous(g)
        # Convert it back to a heterogeneous one for local-to-global node ID mapping
        hg = to_heterogeneous(g, self.hg.ntypes, self.hg.etypes)
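        # After this round trip, hg.ndata[NID] maps each type-local node ID to
        # its ID in the homogeneous graph, i.e. the row it occupies in the
        # shared embedding tables created below.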
        local_to_global_nid = hg.ndata[NID]
        for key, val in local_to_global_nid.items():
            local_to_global_nid[key] = list(val.cpu().numpy())
        self.local_to_global_nid = local_to_global_nid

        num_nodes_total = hg.num_nodes()
        node_frequency = torch.zeros(num_nodes_total)
        # Pre-pass: run one metapath-guided random walk from every node of the
        # start type and count how often each global node ID is visited.
        for idx in tqdm.trange(hg.num_nodes(node_metapath[0])):
            traces, _ = random_walk(g=hg, nodes=[idx], metapath=metapath)
            for tr in traces.cpu().numpy():
                tr_nids = [
                    self.local_to_global_nid[node_metapath[i]][tr[i]] for i in range(len(tr))]
                node_frequency[torch.LongTensor(tr_nids)] += 1
        # Smoothed visit counts define the categorical distribution for
        # negative sampling, analogous to word2vec's unigram^0.75 distribution.
        neg_prob = node_frequency.pow(0.75)
        self.neg_prob = neg_prob / neg_prob.sum()

        # center node embedding
        self.node_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse)
        # context node embedding
        self.context_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters"""
        init_range = 1.0 / self.emb_dim
        init.uniform_(self.node_embed.weight.data, -init_range, init_range)
        init.constant_(self.context_embed.weight.data, 0)

    def sample(self, indices):
        """Sample positive and negative node pairs

        Parameters
        ----------
        indices : torch.Tensor
            Node IDs of the source node type from which we perform random walks

        Returns
        -------
        torch.Tensor
            Positive center nodes
        torch.Tensor
            Positive context nodes
        torch.Tensor
            Negative context nodes
        """
        traces, _ = random_walk(g=self.hg, nodes=indices, metapath=self.metapath)
        u_list = []
        v_list = []
        for tr in traces.cpu().numpy():
            tr_nids = [
                self.local_to_global_nid[self.node_metapath[i]][tr[i]] for i in range(len(tr))]
            for i, u in enumerate(tr_nids):
                # Positive contexts are the nodes at absolute positions in
                # [i - window_size, i + window_size], excluding the center i itself.
                # Note the index arithmetic: enumerate yields slice-relative
                # positions, so the center is skipped at start + j == i.
                start = max(i - self.window_size, 0)
                for j, v in enumerate(tr_nids[start:i + self.window_size + 1]):
                    if start + j == i:
                        continue
                    u_list.append(u)
                    v_list.append(v)

        neg_v = choice(self.hg.num_nodes(), size=len(u_list) * self.negative_size,
                       prob=self.neg_prob).reshape(len(u_list), self.negative_size)

        return torch.LongTensor(u_list), torch.LongTensor(v_list), neg_v
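    # Note: sample() returns tensors of shapes (N,), (N,) and (N, negative_size),
    # where N is the total number of (center, context) pairs harvested from the
    # batch of random walks; forward() consumes them positionally.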
    def forward(self, pos_u, pos_v, neg_v):
        r"""Compute the loss for the batch of positive and negative samples

        Parameters
        ----------
        pos_u : torch.Tensor
            Positive center nodes
        pos_v : torch.Tensor
            Positive context nodes
        neg_v : torch.Tensor
            Negative context nodes

        Returns
        -------
        torch.Tensor
            Loss value
        """
        emb_u = self.node_embed(pos_u)
        emb_v = self.context_embed(pos_v)
        emb_neg_v = self.context_embed(neg_v)

        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)
        score = torch.clamp(score, max=10, min=-10)
        score = -F.logsigmoid(score)

        # (N, K, D) x (N, D, 1) -> (N, K, 1); squeeze only the last dim so a
        # batch of size one is not collapsed as well
        neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze(-1)
        neg_score = torch.clamp(neg_score, max=10, min=-10)
        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)

        return torch.mean(score + neg_score)
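
For reference, `forward` implements the standard skip-gram objective with negative sampling, averaged over the batch B of positive (center, context) pairs; writing x_u for the center embedding from `node_embed`, y_v for a context embedding from `context_embed`, and K for `negative_size` (dot products are clamped to [-10, 10] before the log-sigmoid):

\mathcal{L} = \frac{1}{|B|} \sum_{(u, v) \in B}
    \Big[ -\log \sigma\big(\mathbf{x}_u^\top \mathbf{y}_v\big)
          - \sum_{k=1}^{K} \log \sigma\big(-\mathbf{x}_u^\top \mathbf{y}_{v_k}\big) \Big],
\qquad v_1, \dots, v_K \sim P_n(v) \propto \mathrm{freq}(v)^{0.75}

The DeepWalk hunk above applies the same objective, with an additional `neg_weight` scale on the negative term.
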
@@ -1605,18 +1605,15 @@ def test_label_prop(k, alpha, norm_type, clamp, normalize, reset):
     # multi-label case
     model(g, ml_labels, mask)

-@pytest.mark.parametrize('in_size', [16, 32])
+@pytest.mark.parametrize('in_size', [16])
 @pytest.mark.parametrize('out_size', [16, 32])
-@pytest.mark.parametrize('aggregators',
-    [['mean', 'max', 'dir2-av'], ['min', 'std', 'dir1-dx'], ['moment3', 'moment4', 'dir3-av']])
+@pytest.mark.parametrize('aggregators',
+    [['mean', 'max', 'dir2-av'], ['min', 'std', 'dir1-dx']])
-@pytest.mark.parametrize('scalers', [['identity'], ['amplification', 'attenuation']])
+@pytest.mark.parametrize('scalers', [['amplification', 'attenuation']])
-@pytest.mark.parametrize('delta', [2.5, 7.4])
+@pytest.mark.parametrize('delta', [2.5])
-@pytest.mark.parametrize('dropout', [0., 0.1])
-@pytest.mark.parametrize('num_towers', [1, 4])
 @pytest.mark.parametrize('edge_feat_size', [16, 0])
-@pytest.mark.parametrize('residual', [True, False])
-def test_dgn_conv(in_size, out_size, aggregators, scalers, delta,
-                  dropout, num_towers, edge_feat_size, residual):
+def test_dgn_conv(in_size, out_size, aggregators, scalers, delta,
+                  edge_feat_size):
     dev = F.ctx()
     num_nodes = 5
     num_edges = 20
@@ -1626,13 +1623,13 @@ def test_dgn_conv(in_size, out_size, aggregators, scalers, delta,
     transform = dgl.LaplacianPE(k=3, feat_name='eig')
     g = transform(g)
     eig = g.ndata['eig']
-    model = nn.DGNConv(in_size, out_size, aggregators, scalers, delta, dropout,
-                       num_towers, edge_feat_size, residual).to(dev)
+    model = nn.DGNConv(in_size, out_size, aggregators, scalers, delta,
+                       edge_feat_size=edge_feat_size).to(dev)
     model(g, h, edge_feat=e, eig_vec=eig)

     aggregators_non_eig = [aggr for aggr in aggregators if not aggr.startswith('dir')]
-    model = nn.DGNConv(in_size, out_size, aggregators_non_eig, scalers, delta, dropout,
-                       num_towers, edge_feat_size, residual).to(dev)
+    model = nn.DGNConv(in_size, out_size, aggregators_non_eig, scalers, delta,
+                       edge_feat_size=edge_feat_size).to(dev)
     model(g, h, edge_feat=e)
 def test_DeepWalk():
@@ -1655,3 +1652,17 @@ def test_DeepWalk():
         loss = model(walk)
         loss.backward()
         optim.step()

@parametrize_idtype
def test_MetaPath2Vec(idtype):
    dev = F.ctx()
    g = dgl.heterograph({
        ('user', 'uc', 'company'): ([0, 0, 2, 1, 3], [1, 2, 1, 3, 0]),
        ('company', 'cp', 'product'): ([0, 0, 0, 1, 2, 3], [0, 2, 3, 0, 2, 1]),
        ('company', 'cu', 'user'): ([1, 2, 1, 3, 0], [0, 0, 2, 1, 3]),
        ('product', 'pc', 'company'): ([0, 2, 3, 0, 2, 1], [0, 0, 0, 1, 2, 3])
    }, idtype=idtype, device=dev)
    model = nn.MetaPath2Vec(g, ['uc', 'cu'], window_size=1)
    model = model.to(dev)
    embeds = model.node_embed.weight
    assert embeds.shape[0] == g.num_nodes()
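
In the test above, the metapath `['uc', 'cu']` walks user to company and back to user, so every walk starts and ends at a user node, while the embedding table still covers nodes of all types (hence the `g.num_nodes()` assertion). Below is a minimal sketch of the edge-type-to-node-type conversion performed in `MetaPath2Vec.__init__`, written as a hypothetical standalone helper on the same toy graph:

import dgl

def node_metapath(g, metapath):
    # Follow each edge type's canonical (src_type, etype, dst_type) triple in order.
    src_type, _, _ = g.to_canonical_etype(metapath[0])
    types = [src_type]
    for etype in metapath:
        _, _, dst_type = g.to_canonical_etype(etype)
        types.append(dst_type)
    return types

g = dgl.heterograph({
    ('user', 'uc', 'company'): ([0, 0, 2, 1, 3], [1, 2, 1, 3, 0]),
    ('company', 'cu', 'user'): ([1, 2, 1, 3, 0], [0, 0, 2, 1, 3]),
})
print(node_metapath(g, ['uc', 'cu']))  # ['user', 'company', 'user']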