Unverified commit 6066fee9, authored by Quan (Andy) Gan and committed by GitHub

[Model][Feature] PinSage & Random Walk with Restart (#453)

* random walk traces generation

* remove outdated comments

* oops put in the wrong place

* explicit inline

* moving rand_r to util

* pinsage-like model on movielens

* the code runs now

* support cuda

* using readonly graph

* moving random walk to public function

* per-thread seed and openmp support

* pinsage-like model on movielens

* the code runs now

* support cuda

* using readonly graph

* using C random walk

* removing profile decorators

* param initialization

* no grad

* leaky relu fixes everything

* train and save

* WIP

* WIP

* WIP

* seems to work

* evaluation output

* swapping order of val/test and train

* debug

* hyperparam tuning

* prior/training dataset split changes

* random walk reorg

* random walk with restart

* signed comparison fix

* migrating random walk to nodeflow

* Revert "migrating random walk to nodeflow"

This reverts commit f2565347cced7c912a58a529b257c033d9f375b7.

* add README and remove dataset

* new endpoint

* lint

* lint x2

* oops forgot test

* including bpr - better for baseline

* addressing fixes

* throwing random walks out from SamplerOp class

* forgot to move RandomWalk; why did this even work?

* removing legacy garbage

* add todo

* address comments

* stupid bug fix

* call ndarrayvector converter to handle traces
parent 2dff1aba
# PinSage model
NOTE: this version does not use NodeFlow yet.

First, download and extract the MovieLens-1M dataset from https://dgl.ai.s3.us-east-2.amazonaws.com/dataset/ml-1m.tar.gz.
Then run the following to train PinSage on MovieLens-1M:
```bash
python3 train.py --opt Adam --lr 1e-3 --sched none --sgd-switch 25
```
One can also incorporate user and movie features into training:
```bash
python3 train.py --opt Adam --lr 1e-3 --sched none --sgd-switch 25 --use-feature
```
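The argument parser in `train.py` also exposes the loss function, the number of negative samples, the hard-negative probability, and the learning-rate schedule; the combination below is illustrative rather than tuned:
```bash
python3 train.py --opt Adam --lr 1e-3 --sched decay --loss bpr --n-negs 4 --hard-neg-prob 0.5
```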
Currently, the best PinSage model on MovieLens-1M reaches a mean reciprocal rank of
0.032298±0.048078 on the validation set (and 0.033695±0.051963 on the test set for the same model).
For comparison, the Implicit Factorization Model from Spotlight reaches 0.034572±0.041653 on the test set.
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import tqdm
from rec.model.pinsage import PinSage
from rec.datasets.movielens import MovieLens
from rec.utils import cuda
from rec.adabound import AdaBound
from dgl import DGLGraph
import argparse
import pickle
import os
parser = argparse.ArgumentParser()
parser.add_argument('--opt', type=str, default='SGD')
parser.add_argument('--lr', type=float, default=1)
parser.add_argument('--sched', type=str, default='none')
parser.add_argument('--layers', type=int, default=2)
parser.add_argument('--use-feature', action='store_true')
parser.add_argument('--sgd-switch', type=int, default=-1)
parser.add_argument('--n-negs', type=int, default=1)
parser.add_argument('--loss', type=str, default='hinge')
parser.add_argument('--hard-neg-prob', type=float, default=0)
args = parser.parse_args()
print(args)
cache_file = 'ml.pkl'
if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
ml = pickle.load(f)
else:
ml = MovieLens('./ml-1m')
with open(cache_file, 'wb') as f:
pickle.dump(ml, f)
g = ml.g
neighbors = ml.user_neighbors + ml.movie_neighbors
n_hidden = 100
n_layers = args.layers
batch_size = 256
margin = 0.9
n_negs = args.n_negs
hard_neg_prob = args.hard_neg_prob
sched_lambda = {
'none': lambda epoch: 1,
'decay': lambda epoch: max(0.98 ** epoch, 1e-4),
}
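# `diff` (computed in runtrain below) is score(negative) - score(positive), so training
# drives it negative. 'hinge' penalizes pairs with diff > -margin; 'bpr' reduces to
# sigmoid(diff).mean(), a smooth surrogate of the BPR log-loss.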
loss_func = {
'hinge': lambda diff: (diff + margin).clamp(min=0).mean(),
'bpr': lambda diff: (1 - torch.sigmoid(-diff)).mean(),
}
model = cuda(PinSage(
g.number_of_nodes(),
[n_hidden] * (n_layers + 1),
20,
0.5,
10,
use_feature=args.use_feature,
G=g,
))
opt = getattr(torch.optim, args.opt)(model.parameters(), lr=args.lr)
sched = torch.optim.lr_scheduler.LambdaLR(opt, sched_lambda[args.sched])
def forward(model, g_prior, nodeset, train=True):
if train:
return model(g_prior, nodeset)
else:
with torch.no_grad():
return model(g_prior, nodeset)
def filter_nid(nids, nid_from):
nids = [nid.numpy() for nid in nids]
nid_from = nid_from.numpy()
np_mask = np.logical_and(*[np.isin(nid, nid_from) for nid in nids])
return [torch.from_numpy(nid[np_mask]) for nid in nids]
def runtrain(g_prior_edges, g_train_edges, train):
global opt
if train:
model.train()
else:
model.eval()
g_prior_src, g_prior_dst = g.find_edges(g_prior_edges)
g_prior = DGLGraph()
g_prior.add_nodes(g.number_of_nodes())
g_prior.add_edges(g_prior_src, g_prior_dst)
g_prior.ndata.update({k: cuda(v) for k, v in g.ndata.items()})
edge_batches = g_train_edges[torch.randperm(g_train_edges.shape[0])].split(batch_size)
with tqdm.tqdm(edge_batches) as tq:
sum_loss = 0
sum_acc = 0
count = 0
for batch_id, batch in enumerate(tq):
count += batch.shape[0]
src, dst = g.find_edges(batch)
dst_neg = []
for i in range(len(dst)):
if np.random.rand() < hard_neg_prob:
nb = torch.LongTensor(neighbors[dst[i].item()])
mask = ~(g.has_edges_between(nb, src[i].item()).byte())
dst_neg.append(np.random.choice(nb[mask].numpy(), n_negs))
else:
dst_neg.append(np.random.randint(
len(ml.user_ids), len(ml.user_ids) + len(ml.movie_ids), n_negs))
dst_neg = torch.LongTensor(dst_neg)
dst = dst.view(-1, 1).expand_as(dst_neg).flatten()
src = src.view(-1, 1).expand_as(dst_neg).flatten()
dst_neg = dst_neg.flatten()
mask = (g_prior.in_degrees(dst_neg) > 0) & \
(g_prior.in_degrees(dst) > 0) & \
(g_prior.in_degrees(src) > 0)
src = src[mask]
dst = dst[mask]
dst_neg = dst_neg[mask]
if len(src) == 0:
continue
nodeset = cuda(torch.cat([src, dst, dst_neg]))
src_size, dst_size, dst_neg_size = \
src.shape[0], dst.shape[0], dst_neg.shape[0]
h_src, h_dst, h_dst_neg = (
forward(model, g_prior, nodeset, train)
.split([src_size, dst_size, dst_neg_size]))
diff = (h_src * (h_dst_neg - h_dst)).sum(1)
loss = loss_func[args.loss](diff)
acc = (diff < 0).sum()
assert loss.item() == loss.item()   # fail fast if the loss went NaN
grad_sqr_norm = 0
if train:
opt.zero_grad()
loss.backward()
for name, p in model.named_parameters():
assert (p.grad != p.grad).sum() == 0   # no NaN gradients
grad_sqr_norm += p.grad.norm().item() ** 2
opt.step()
sum_loss += loss.item()
sum_acc += acc.item() / n_negs
avg_loss = sum_loss / (batch_id + 1)
avg_acc = sum_acc / count
tq.set_postfix({'loss': '%.6f' % loss.item(),
'avg_loss': '%.3f' % avg_loss,
'avg_acc': '%.3f' % avg_acc,
'grad_norm': '%.6f' % np.sqrt(grad_sqr_norm)})
return avg_loss, avg_acc
def runtest(g_prior_edges, validation=True):
model.eval()
n_users = len(ml.users.index)
n_items = len(ml.movies.index)
g_prior_src, g_prior_dst = g.find_edges(g_prior_edges)
g_prior = DGLGraph()
g_prior.add_nodes(g.number_of_nodes())
g_prior.add_edges(g_prior_src, g_prior_dst)
g_prior.ndata.update({k: cuda(v) for k, v in g.ndata.items()})
hs = []
with torch.no_grad():
with tqdm.trange(n_users + n_items) as tq:
for node_id in tq:
nodeset = cuda(torch.LongTensor([node_id]))
h = forward(model, g_prior, nodeset, False)
hs.append(h)
h = torch.cat(hs, 0)
rr = []
with torch.no_grad():
with tqdm.trange(n_users) as tq:
for u_nid in tq:
uid = ml.user_ids[u_nid]
pids_exclude = ml.ratings[
(ml.ratings['user_id'] == uid) &
(ml.ratings['train'] | ml.ratings['test' if validation else 'valid'])
]['movie_id'].values
pids_candidate = ml.ratings[
(ml.ratings['user_id'] == uid) &
ml.ratings['valid' if validation else 'test']
]['movie_id'].values
pids = np.setdiff1d(ml.movie_ids, pids_exclude)
p_nids = np.array([ml.movie_ids_invmap[pid] for pid in pids])
p_nids_candidate = np.array([ml.movie_ids_invmap[pid] for pid in pids_candidate])
dst = torch.from_numpy(p_nids) + n_users
src = torch.zeros_like(dst).fill_(u_nid)
h_dst = h[dst]
h_src = h[src]
score = (h_src * h_dst).sum(1)
score_sort_idx = score.sort(descending=True)[1].cpu().numpy()
rank_map = {v: i for i, v in enumerate(p_nids[score_sort_idx])}
rank_candidates = np.array([rank_map[p_nid] for p_nid in p_nids_candidate])
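# ranks are 0-based, so the reciprocal rank of each held-out movie is 1 / (rank + 1)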
rank = 1 / (rank_candidates + 1)
rr.append(rank.mean())
tq.set_postfix({'rank': rank.mean()})
return np.array(rr)
def train():
global opt, sched
best_mrr = 0
for epoch in range(500):
ml.refresh_mask()
g_prior_edges = g.filter_edges(lambda edges: edges.data['prior'])
g_train_edges = g.filter_edges(lambda edges: edges.data['train'] & ~edges.data['inv'])
g_prior_train_edges = g.filter_edges(
lambda edges: edges.data['prior'] | edges.data['train'])
print('Epoch %d validation' % epoch)
with torch.no_grad():
valid_mrr = runtest(g_prior_train_edges, True)
if best_mrr < valid_mrr.mean():
best_mrr = valid_mrr.mean()
torch.save(model.state_dict(), 'model.pt')
print(pd.Series(valid_mrr).describe())
print('Epoch %d test' % epoch)
with torch.no_grad():
test_mrr = runtest(g_prior_train_edges, False)
print(pd.Series(test_mrr).describe())
print('Epoch %d train' % epoch)
runtrain(g_prior_edges, g_train_edges, True)
if epoch == args.sgd_switch:
opt = torch.optim.SGD(model.parameters(), lr=0.6)
sched = torch.optim.lr_scheduler.LambdaLR(opt, sched_lambda['decay'])
elif epoch < args.sgd_switch:
sched.step()
if __name__ == '__main__':
train()
import pandas as pd
import dgl
import os
import torch
import numpy as np
import scipy.sparse as sp
import time
from functools import partial
from .. import randomwalk
import stanfordnlp
import re
import tqdm
import string
class MovieLens(object):
def __init__(self, directory):
'''
directory: path to movielens directory which should have the three
files:
users.dat
movies.dat
ratings.dat
'''
self.directory = directory
users = []
movies = []
ratings = []
# read users
with open(os.path.join(directory, 'users.dat')) as f:
for l in f:
id_, gender, age, occupation, zip_ = l.strip().split('::')
users.append({
'id': int(id_),
'gender': gender,
'age': age,
'occupation': occupation,
'zip': zip_,
})
self.users = pd.DataFrame(users).set_index('id').astype('category')
# read movies
with open(os.path.join(directory, 'movies.dat'), encoding='latin1') as f:
for l in f:
id_, title, genres = l.strip().split('::')
genres_set = set(genres.split('|'))
# extract year
assert re.match(r'.*\([0-9]{4}\)$', title)
year = title[-5:-1]
title = title[:-6].strip()
data = {'id': int(id_), 'title': title, 'year': year}
for g in genres_set:
data[g] = True
movies.append(data)
self.movies = (
pd.DataFrame(movies)
.set_index('id')
.fillna(False)
.astype({'year': 'category'}))
self.genres = self.movies.columns[self.movies.dtypes == bool]
# read ratings
with open(os.path.join(directory, 'ratings.dat')) as f:
for l in f:
user_id, movie_id, rating, timestamp = [int(_) for _ in l.split('::')]
ratings.append({
'user_id': user_id,
'movie_id': movie_id,
'rating': rating,
'timestamp': timestamp,
})
ratings = pd.DataFrame(ratings)
movie_count = ratings['movie_id'].value_counts()
movie_count.name = 'movie_count'
ratings = ratings.join(movie_count, on='movie_id')
self.ratings = ratings
# drop users and movies which do not exist in ratings
self.users = self.users[self.users.index.isin(self.ratings['user_id'])]
self.movies = self.movies[self.movies.index.isin(self.ratings['movie_id'])]
self.data_split()
self.build_graph()
self.find_neighbors(0.2, 2000, 1000)
def split_user(self, df, filter_counts=False):
df_new = df.copy()
df_new['prob'] = 0
if filter_counts:
df_new_sub = (df_new['movie_count'] >= 10).nonzero()[0]
else:
df_new_sub = df_new['train'].nonzero()[0]
prob = np.linspace(0, 1, df_new_sub.shape[0], endpoint=False)
np.random.shuffle(prob)
df_new['prob'].iloc[df_new_sub] = prob
return df_new
def data_split(self):
self.ratings = self.ratings.groupby('user_id', group_keys=False).apply(
partial(self.split_user, filter_counts=True))
self.ratings['train'] = self.ratings['prob'] <= 0.8
self.ratings['valid'] = (self.ratings['prob'] > 0.8) & (self.ratings['prob'] <= 0.9)
self.ratings['test'] = self.ratings['prob'] > 0.9
self.ratings.drop(['prob'], axis=1, inplace=True)
def build_graph(self):
user_ids = list(self.users.index)
movie_ids = list(self.movies.index)
user_ids_invmap = {id_: i for i, id_ in enumerate(user_ids)}
movie_ids_invmap = {id_: i for i, id_ in enumerate(movie_ids)}
self.user_ids = user_ids
self.movie_ids = movie_ids
self.user_ids_invmap = user_ids_invmap
self.movie_ids_invmap = movie_ids_invmap
g = dgl.DGLGraph()
g.add_nodes(len(user_ids) + len(movie_ids))
# user features
for user_column in self.users.columns:
udata = torch.zeros(g.number_of_nodes(), dtype=torch.int64)
# 0 for padding
udata[:len(user_ids)] = \
torch.LongTensor(self.users[user_column].cat.codes.values.astype('int64') + 1)
g.ndata[user_column] = udata
# movie genre
movie_genres = torch.from_numpy(self.movies[self.genres].values.astype('float32'))
g.ndata['genre'] = torch.zeros(g.number_of_nodes(), len(self.genres))
g.ndata['genre'][len(user_ids):len(user_ids) + len(movie_ids)] = movie_genres
# movie year
g.ndata['year'] = torch.zeros(g.number_of_nodes(), dtype=torch.int64)
# 0 for padding
g.ndata['year'][len(user_ids):len(user_ids) + len(movie_ids)] = \
torch.LongTensor(self.movies['year'].cat.codes.values.astype('int64') + 1)
# movie title
nlp = stanfordnlp.Pipeline(use_gpu=False, processors='tokenize,lemma')
vocab = set()
title_words = []
for t in tqdm.tqdm(self.movies['title'].values):
doc = nlp(t)
words = set()
for s in doc.sentences:
words.update(w.lemma.lower() for w in s.words
if not re.fullmatch(r'['+string.punctuation+']+', w.lemma))
vocab.update(words)
title_words.append(words)
vocab = list(vocab)
vocab_invmap = {w: i for i, w in enumerate(vocab)}
# bag-of-words
g.ndata['title'] = torch.zeros(g.number_of_nodes(), len(vocab))
for i, tw in enumerate(tqdm.tqdm(title_words)):
    # offset past the user rows, consistent with the genre/year features above
    g.ndata['title'][len(user_ids) + i, [vocab_invmap[w] for w in tw]] = 1
self.vocab = vocab
self.vocab_invmap = vocab_invmap
rating_user_vertices = [user_ids_invmap[id_] for id_ in self.ratings['user_id'].values]
rating_movie_vertices = [movie_ids_invmap[id_] + len(user_ids)
for id_ in self.ratings['movie_id'].values]
self.rating_user_vertices = rating_user_vertices
self.rating_movie_vertices = rating_movie_vertices
g.add_edges(
rating_user_vertices,
rating_movie_vertices,
data={'inv': torch.zeros(self.ratings.shape[0], dtype=torch.uint8)})
g.add_edges(
rating_movie_vertices,
rating_user_vertices,
data={'inv': torch.ones(self.ratings.shape[0], dtype=torch.uint8)})
self.g = g
def find_neighbors(self, restart_prob, max_nodes, top_T):
# TODO: replace with more efficient PPR estimation
neighbor_probs, neighbors = randomwalk.random_walk_distribution_topt(
self.g, self.g.nodes(), restart_prob, max_nodes, top_T)
self.user_neighbors = []
for i in range(len(self.user_ids)):
user_neighbor = neighbors[i]
self.user_neighbors.append(user_neighbor.tolist())
self.movie_neighbors = []
for i in range(len(self.user_ids), len(self.user_ids) + len(self.movie_ids)):
movie_neighbor = neighbors[i]
self.movie_neighbors.append(movie_neighbor.tolist())
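# Each epoch draws a fresh rotation over 5 complementary folds of a user's
# training edges: 80% become 'prior' edges (the graph PinSage propagates over)
# and the remaining 20% become the supervised 'train' edges.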
def generate_mask(self):
while True:
ratings = self.ratings.groupby('user_id', group_keys=False).apply(self.split_user)
prior_prob = ratings['prob'].values
for i in range(5):
train_mask = (prior_prob >= 0.2 * i) & (prior_prob < 0.2 * (i + 1))
prior_mask = ~train_mask
train_mask &= ratings['train'].values
prior_mask &= ratings['train'].values
yield prior_mask, train_mask
def refresh_mask(self):
if not hasattr(self, 'masks'):
self.masks = self.generate_mask()
prior_mask, train_mask = next(self.masks)
valid_tensor = torch.from_numpy(self.ratings['valid'].values.astype('uint8'))
test_tensor = torch.from_numpy(self.ratings['test'].values.astype('uint8'))
train_tensor = torch.from_numpy(train_mask.astype('uint8'))
prior_tensor = torch.from_numpy(prior_mask.astype('uint8'))
edge_data = {
'prior': prior_tensor,
'valid': valid_tensor,
'test': test_tensor,
'train': train_tensor,
}
self.g.edges[self.rating_user_vertices, self.rating_movie_vertices].data.update(edge_data)
self.g.edges[self.rating_movie_vertices, self.rating_user_vertices].data.update(edge_data)
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from .. import randomwalk
from ..utils import cuda
def create_embeddings(n_nodes, n_features):
return nn.Parameter(torch.randn(n_nodes, n_features))
def mix_embeddings(h, ndata, emb, proj):
'''Combine node-specific trainable embedding ``h`` with categorical inputs
(projected by ``emb``) and numeric inputs (projected by ``proj``).
'''
e = []
for key, value in ndata.items():
if value.dtype == torch.int64:
e.append(emb[key](value))
elif value.dtype == torch.float32:
e.append(proj[key](value))
return h + torch.stack(e, 0).sum(0)
def get_embeddings(h, nodeset):
return h[nodeset]
def put_embeddings(h, nodeset, new_embeddings):
n_nodes = nodeset.shape[0]
n_features = h.shape[1]
return h.scatter(0, nodeset[:, None].expand(n_nodes, n_features), new_embeddings)
def safediv(a, b):
b = torch.where(b == 0, torch.ones_like(b), b)   # avoid division by zero
return a / b
def init_weight(w, func_name, nonlinearity):
getattr(nn.init, func_name)(w, gain=nn.init.calculate_gain(nonlinearity))
def init_bias(w):
nn.init.constant_(w, 0)
class PinSageConv(nn.Module):
def __init__(self, in_features, out_features, hidden_features):
super(PinSageConv, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.hidden_features = hidden_features
self.Q = nn.Linear(in_features, hidden_features)
self.W = nn.Linear(in_features + hidden_features, out_features)
init_weight(self.Q.weight, 'xavier_uniform_', 'leaky_relu')
init_weight(self.W.weight, 'xavier_uniform_', 'leaky_relu')
init_bias(self.Q.bias)
init_bias(self.W.bias)
def forward(self, h, nodeset, nb_nodes, nb_weights):
'''
h: node embeddings (num_total_nodes, in_features), or a container
of the node embeddings (for distributed computing)
nodeset: node IDs in this minibatch (num_nodes,)
nb_nodes: neighbor node IDs of each node in nodeset (num_nodes, num_neighbors)
nb_weights: weight of each neighbor node (num_nodes, num_neighbors)
return: new node embeddings (num_nodes, out_features)
'''
n_nodes, T = nb_nodes.shape
h_nodeset = get_embeddings(h, nodeset) # (n_nodes, in_features)
h_neighbors = get_embeddings(h, nb_nodes.view(-1)).view(n_nodes, T, self.in_features)
h_neighbors = F.leaky_relu(self.Q(h_neighbors))
# importance pooling: average neighbor embeddings weighted by random-walk visit frequency
h_agg = safediv(
(nb_weights[:, :, None] * h_neighbors).sum(1),
nb_weights.sum(1, keepdim=True))
h_concat = torch.cat([h_nodeset, h_agg], 1)
h_new = F.leaky_relu(self.W(h_concat))
h_new = safediv(h_new, h_new.norm(dim=1, keepdim=True))   # l2-normalize the output
return h_new
class PinSage(nn.Module):
'''
Performs a multi-layer PinSage convolution.
G: DGLGraph
feature_sizes: the dimensionality of input/hidden/output features
T: number of neighbors we pick for each node
restart_prob: restart probability
max_nodes: max number of nodes visited for each seed
'''
def __init__(self, num_nodes, feature_sizes, T, restart_prob, max_nodes,
use_feature=False, G=None):
super(PinSage, self).__init__()
self.T = T
self.restart_prob = restart_prob
self.max_nodes = max_nodes
self.in_features = feature_sizes[0]
self.out_features = feature_sizes[-1]
self.n_layers = len(feature_sizes) - 1
self.convs = nn.ModuleList()
for i in range(self.n_layers):
self.convs.append(PinSageConv(
feature_sizes[i], feature_sizes[i+1], feature_sizes[i+1]))
self.h = create_embeddings(num_nodes, self.in_features)
self.use_feature = use_feature
if use_feature:
self.emb = nn.ModuleDict()
self.proj = nn.ModuleDict()
for key, scheme in G.node_attr_schemes().items():
if scheme.dtype == torch.int64:
self.emb[key] = nn.Embedding(
G.ndata[key].max().item() + 1,
self.in_features,
padding_idx=0)
elif scheme.dtype == torch.float32:
self.proj[key] = nn.Sequential(
nn.Linear(scheme.shape[0], self.in_features),
nn.LeakyReLU(),
)
def forward(self, G, nodeset):
'''
Given a complete embedding matrix h and a list of node IDs, return
the output embeddings of these node IDs.
nodeset: node IDs in this minibatch (num_nodes,)
return: new node embeddings (num_nodes, out_features)
'''
if self.use_feature:
h = mix_embeddings(self.h, G.ndata, self.emb, self.proj)
else:
h = self.h
nodeflow = randomwalk.random_walk_nodeflow(
G, nodeset, self.n_layers, self.restart_prob, self.max_nodes, self.T)
for i, (nodeset, nb_weights, nb_nodes) in enumerate(nodeflow):
new_embeddings = self.convs[i](h, nodeset, nb_nodes, nb_weights)
h = put_embeddings(h, nodeset, new_embeddings)
h_new = get_embeddings(h, nodeset)
return h_new
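The following is an illustrative sketch (not part of this commit) of exercising `PinSage` end-to-end on a toy graph; the hyperparameter values are assumptions:

```python
import torch
import dgl
from rec.model.pinsage import PinSage

# Tiny bidirected path graph; the random-walk sampler requires every node
# to have at least one successor.
g = dgl.DGLGraph([(0, 1), (1, 0), (1, 2), (2, 1)], readonly=True)

# Two PinSage layers with 16-dimensional features, T=3 neighbors per node,
# restart probability 0.5, and at most 10 visited nodes per seed.
model = PinSage(g.number_of_nodes(), [16, 16, 16], 3, 0.5, 10)
h = model(g, torch.LongTensor([0, 2]))   # output embeddings of shape (2, 16)
```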
import torch
import dgl
from ..utils import cuda
from collections import Counter
def random_walk_sampler(G, nodeset, restart_prob, max_nodes):
'''
G: DGLGraph
nodeset: 1D CPU Tensor of node IDs
restart_prob: float
max_nodes: int
return: list[list[Tensor]]
'''
traces = dgl.contrib.sampling.bipartite_single_sided_random_walk_with_restart(
G, nodeset, restart_prob, max_nodes)
return traces
# Note: this function is not friendly to giant graphs since we use a matrix
# with size (num_nodes_in_nodeset, num_nodes_in_graph).
def random_walk_distribution(G, nodeset, restart_prob, max_nodes):
n_nodes = nodeset.shape[0]
n_available_nodes = G.number_of_nodes()
traces = random_walk_sampler(G, nodeset, restart_prob, max_nodes)
visited_counts = torch.zeros(n_nodes, n_available_nodes)
for i in range(n_nodes):
visited_nodes = torch.cat(traces[i])
visited_counts[i].scatter_add_(0, visited_nodes, torch.ones_like(visited_nodes, dtype=torch.float32))
return visited_counts
def random_walk_distribution_topt(G, nodeset, restart_prob, max_nodes, top_T):
'''
returns the top T important neighbors of each node in nodeset, as well as
the weights of the neighbors.
'''
visited_prob = random_walk_distribution(G, nodeset, restart_prob, max_nodes)
weights, nodes = visited_prob.topk(top_T, 1)
weights = weights / weights.sum(1, keepdim=True)
return weights, nodes
def random_walk_nodeflow(G, nodeset, n_layers, restart_prob, max_nodes, top_T):
'''
returns a list of triplets (
"active" node IDs whose embeddings are computed at the i-th layer (num_nodes,)
weight of each neighboring node of each "active" node on the i-th layer (num_nodes, top_T)
neighboring node IDs for each "active" node on the i-th layer (num_nodes, top_T)
)
'''
dev = nodeset.device
nodeset = nodeset.cpu()
nodeflow = []
cur_nodeset = nodeset
for i in reversed(range(n_layers)):
nb_weights, nb_nodes = random_walk_distribution_topt(G, cur_nodeset, restart_prob, max_nodes, top_T)
nodeflow.insert(0, (cur_nodeset.to(dev), nb_weights.to(dev), nb_nodes.to(dev)))
cur_nodeset = torch.cat([nb_nodes.view(-1), cur_nodeset]).unique()
return nodeflow
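For reference, a sketch (again not part of the commit) of how the helpers above compose on a toy graph:

```python
import torch
import dgl
from rec import randomwalk  # assumed import path

g = dgl.DGLGraph([(0, 1), (1, 0), (1, 2), (2, 1)], readonly=True)
seeds = torch.LongTensor([0, 1, 2])

# Normalized visit frequencies of each seed's top-2 random-walk neighbors.
weights, nodes = randomwalk.random_walk_distribution_topt(g, seeds, 0.5, 10, 2)

# Per-layer (nodeset, neighbor weights, neighbor IDs) triplets for a 2-layer model.
flow = randomwalk.random_walk_nodeflow(g, seeds, 2, 0.5, 10, 2)
```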
import torch
def cuda(x):
if torch.cuda.is_available():
return x.cuda()
else:
return x
@@ -23,6 +23,8 @@ typedef dgl::runtime::NDArray FloatArray;
struct Subgraph;
struct NodeFlow;
const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1);
/*!
* \brief This class references data in std::vector.
*
......
@@ -8,14 +8,6 @@
#include "c_runtime_api.h"
#ifdef _MSC_VER
// rand in MS compiler works well in multi-threading.
static inline int rand_r(unsigned *seed) {
return rand();
}
#define _CRT_RAND_S
#endif
namespace dgl {
namespace runtime {
......
@@ -8,13 +8,42 @@
#include <vector>
#include <string>
#include <cstdlib>
#include <ctime>
#include "graph_interface.h"
#include "nodeflow.h"
#ifdef _MSC_VER
// rand in MS compiler works well in multi-threading.
inline int rand_r(unsigned *seed) {
return rand();
}
inline unsigned int randseed() {
unsigned int seed = time(nullptr);
srand(seed); // need to set seed manually since there's no rand_r
return seed;
}
#define _CRT_RAND_S
#else
inline unsigned int randseed() {
return time(nullptr);
}
#endif
namespace dgl {
class ImmutableGraph;
struct RandomWalkTraces {
/*! \brief number of traces generated for each seed */
IdArray trace_counts;
/*! \brief length of each trace, concatenated */
IdArray trace_lengths;
/*! \brief the vertices, concatenated */
IdArray vertices;
};
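// For example, if seed 0 produced traces {5, 6} and {7} while seed 1 produced
// {8}, then trace_counts = [2, 1], trace_lengths = [2, 1, 1] and
// vertices = [5, 6, 7, 8].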
class SamplerOp {
public:
/*!
@@ -49,19 +78,77 @@ class SamplerOp {
const std::vector<dgl_id_t>& seeds,
const std::string &neigh_type,
IdArray layer_sizes);
+};
/*!
* \brief Batch-generate random walk traces
* \param seeds The array of starting vertex IDs
* \param num_traces The number of traces to generate for each seed
* \param num_hops The number of hops for each trace
* \return a flat ID array with shape (num_seeds, num_traces, num_hops + 1)
*/
-static IdArray RandomWalk(const GraphInterface *gptr,
+IdArray RandomWalk(const GraphInterface *gptr,
                   IdArray seeds,
                   int num_traces,
                   int num_hops);
-};
/*!
* \brief Batch-generate random walk traces with restart
*
* Stop generating traces once max_frequent_visited_nodes nodes have each been
* visited max_visit_counts times.
*
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed once at
* least \c max_frequent_visited_nodes nodes have each been visited at least
* \c max_visit_counts times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces instance.
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTraces RandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
/*!
* \brief Batch-generate random walk traces with restart on a bipartite graph, walking two
* hops at a time.
*
* Since it is walking on a bipartite graph, the vertices of a trace will always stay on the
* same side.
*
* Stop generating traces once max_frequent_visited_nodes nodes have each been
* visited max_visit_counts times.
*
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed once at
* least \c max_frequent_visited_nodes nodes have each been visited at least
* \c max_visit_counts times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces instance.
*
* \note Doesn't verify whether the graph is indeed a bipartite graph
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
} // namespace dgl
......
from ... import utils
from ... import backend as F
from ..._ffi.function import _init_api
-__all__ = ['random_walk']
+__all__ = ['random_walk',
+           'random_walk_with_restart',
+           'bipartite_single_sided_random_walk_with_restart',
+           ]
def random_walk(g, seeds, num_traces, num_hops):
@@ -11,7 +15,7 @@ def random_walk(g, seeds, num_traces, num_hops):
Parameters
----------
g : DGLGraph
-        The graph. Must be readonly.
+        The graph.
seeds : Tensor
The node ID tensor from which the random walk traces starts.
num_traces : int
@@ -28,4 +32,136 @@ def random_walk(g, seeds, num_traces, num_hops):
traces[i, j, 0] are always starting nodes (i.e. seed[i]).
"""
-    return g._graph.random_walk(utils.toindex(seeds), num_traces, num_hops)
+    if len(seeds) == 0:
+        return utils.toindex([]).tousertensor()
+    seeds = utils.toindex(seeds).todgltensor()
+    traces = _CAPI_DGLRandomWalk(g._graph._handle, seeds, num_traces, num_hops)
+    return F.zerocopy_from_dlpack(traces.to_dlpack())
def _split_traces(traces):
"""Splits the flattened RandomWalkTraces structure into list of list
of tensors.
Parameters
----------
traces : PackedFunc object of RandomWalkTraces structure
Returns
-------
traces : list[list[Tensor]]
traces[i][j] is the j-th trace generated for i-th seed.
"""
trace_counts = F.zerocopy_to_numpy(
F.zerocopy_from_dlpack(traces(0).to_dlpack())).tolist()
trace_lengths = F.zerocopy_from_dlpack(traces(1).to_dlpack())
trace_vertices = F.zerocopy_from_dlpack(traces(2).to_dlpack())
trace_vertices = F.split(
trace_vertices, F.zerocopy_to_numpy(trace_lengths).tolist(), 0)
traces = []
s = 0
for c in trace_counts:
traces.append(trace_vertices[s:s+c])
s += c
return traces
def random_walk_with_restart(
g, seeds, restart_prob, max_nodes_per_seed,
max_visit_counts=0, max_frequent_visited_nodes=0):
"""Batch-generate random walk traces on given graph with restart probability.
Parameters
----------
g : DGLGraph
The graph.
seeds : Tensor
The node ID tensor from which the random walk traces starts.
restart_prob : float
Probability to terminate the current trace and restart from the seed after each step.
max_nodes_per_seed : int
Stop generating traces for a seed if the total number of nodes
visited exceeds this number. [1]
max_visit_counts : int, optional
max_frequent_visited_nodes : int, optional
Alternatively, stop generating traces for a seed once at least
``max_frequent_visited_nodes`` nodes have each been visited at least
``max_visit_counts`` times. [1]
Returns
-------
traces : list[list[Tensor]]
traces[i][j] is the j-th trace generated for i-th seed.
Notes
-----
The traces do **not** include the seed nodes themselves.
References
----------
[1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
"""
if len(seeds) == 0:
return []
seeds = utils.toindex(seeds).todgltensor()
traces = _CAPI_DGLRandomWalkWithRestart(
g._graph._handle, seeds, restart_prob, max_nodes_per_seed,
max_visit_counts, max_frequent_visited_nodes)
return _split_traces(traces)
def bipartite_single_sided_random_walk_with_restart(
g, seeds, restart_prob, max_nodes_per_seed,
max_visit_counts=0, max_frequent_visited_nodes=0):
"""Batch-generate random walk traces on given graph with restart probability.
The graph must be a bipartite graph.
A single random walk step involves two normal steps, so that the "visited"
nodes always stay on the same side. [1]
Parameters
----------
g : DGLGraph
The graph.
seeds : Tensor
The node ID tensor from which the random walk traces starts.
restart_prob : float
Probability to terminate the current trace and restart from the seed after each step.
max_nodes_per_seed : int
Stop generating traces for a seed if the total number of nodes
visited exceeds this number. [1]
max_visit_counts : int, optional
max_frequent_visited_nodes : int, optional
Alternatively, stop generating traces for a seed once at least
``max_frequent_visited_nodes`` nodes have each been visited at least
``max_visit_counts`` times. [1]
Returns
-------
traces : list[list[Tensor]]
traces[i][j] is the j-th trace generated for i-th seed.
Notes
-----
The current implementation does not ensure that the graph is a bipartite
graph.
The traces do **not** include the seed nodes themselves.
References
----------
[1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
"""
if len(seeds) == 0:
return []
seeds = utils.toindex(seeds).todgltensor()
traces = _CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart(
g._graph._handle, seeds, restart_prob, max_nodes_per_seed,
max_visit_counts, max_frequent_visited_nodes)
return _split_traces(traces)
_init_api('dgl.randomwalk', __name__)
@@ -723,20 +723,6 @@ class GraphIndex(object):
shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
return inc, shuffle_idx
def random_walk(self, seeds, num_traces, num_hops):
"""Random walk sampling.
Returns a user Tensor of random walk traces with shape
(num_seeds, num_traces, num_hops + 1)
"""
if len(seeds) == 0:
return utils.toindex([])
seeds = seeds.todgltensor()
traces = _CAPI_DGLGraphRandomWalk(self._handle, seeds, num_traces, num_hops)
return F.zerocopy_from_dlpack(traces.to_dlpack())
def to_networkx(self):
"""Convert to networkx graph.
......
@@ -433,15 +433,4 @@ DGL_REGISTER_GLOBAL("nodeflow._CAPI_NodeFlowGetBlockAdj")
*rv = ConvertAdjToPackedFunc(res);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const IdArray seeds = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const int num_traces = args[2];
const int num_hops = args[3];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
*rv = SamplerOp::RandomWalk(ptr, seeds, num_traces, num_hops);
});
} // namespace dgl
/*!
* Copyright (c) 2018 by Contributors
* \file graph/sampler.cc
* \brief DGL sampler implementation
*/
#include <dgl/sampler.h>
#include <dmlc/omp.h>
#include <dgl/immutable_graph.h>
#include <algorithm>
#include <cstdlib>
#include <cmath>
#include <numeric>
#include <functional>
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
namespace dgl {
using Walker = std::function<dgl_id_t(
const GraphInterface *, unsigned int *, dgl_id_t)>;
namespace {
/*!
* \brief Randomly select a single direct successor of the current vertex
* \return The selected successor, or DGL_INVALID_ID if the vertex has no successors
*/
dgl_id_t WalkOneHop(
const GraphInterface *gptr,
unsigned int *random_seed,
dgl_id_t cur) {
const auto succ = gptr->SuccVec(cur);
const size_t size = succ.size();
if (size == 0)
return DGL_INVALID_ID;
return succ[rand_r(random_seed) % size];
}
/*!
* \brief Randomly walk \c hops hops from the current vertex
* \return The vertex reached, or DGL_INVALID_ID if a vertex without successors is hit
*/
template<int hops>
dgl_id_t WalkMultipleHops(
const GraphInterface *gptr,
unsigned int *random_seed,
dgl_id_t cur) {
dgl_id_t next;
for (int i = 0; i < hops; ++i) {
if ((next = WalkOneHop(gptr, random_seed, cur)) == DGL_INVALID_ID)
return DGL_INVALID_ID;
cur = next;
}
return cur;
}
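// WalkMultipleHops<1> is the ordinary single-hop walker; WalkMultipleHops<2>
// advances two hops at a time so that bipartite traces stay on one side.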
IdArray GenericRandomWalk(
const GraphInterface *gptr,
IdArray seeds,
int num_traces,
int num_hops,
Walker walker) {
const int64_t num_nodes = seeds->shape[0];
const dgl_id_t *seed_ids = static_cast<dgl_id_t *>(seeds->data);
IdArray traces = IdArray::Empty(
{num_nodes, num_traces, num_hops + 1},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
dgl_id_t *trace_data = static_cast<dgl_id_t *>(traces->data);
// FIXME: does OpenMP work with exceptions? Especially without throwing SIGABRT?
unsigned int random_seed = randseed();
dgl_id_t next;
for (int64_t i = 0; i < num_nodes; ++i) {
const dgl_id_t seed_id = seed_ids[i];
for (int j = 0; j < num_traces; ++j) {
dgl_id_t cur = seed_id;
const int kmax = num_hops + 1;
for (int k = 0; k < kmax; ++k) {
const int64_t offset = (i * num_traces + j) * kmax + k;
trace_data[offset] = cur;
if ((next = walker(gptr, &random_seed, cur)) == DGL_INVALID_ID)
LOG(FATAL) << "no successors from vertex " << cur;
cur = next;
}
}
}
return traces;
}
RandomWalkTraces GenericRandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes,
Walker walker) {
std::vector<dgl_id_t> vertices;
std::vector<size_t> trace_lengths, trace_counts, visit_counts;
const dgl_id_t *seed_ids = static_cast<dgl_id_t *>(seeds->data);
const uint64_t num_nodes = seeds->shape[0];
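// restart_prob is pre-scaled to rand_r()'s integer range; kept signed to match
// rand_r()'s int return type in the comparison below (cf. the "signed
// comparison fix" commit in the log above).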
int64_t restart_bound = static_cast<int64_t>(restart_prob * RAND_MAX);
visit_counts.resize(gptr->NumVertices());
unsigned int random_seed = randseed();
for (uint64_t i = 0; i < num_nodes; ++i) {
int stop = 0;
size_t total_trace_length = 0;
size_t num_traces = 0;
uint64_t num_frequent_visited_nodes = 0;
std::fill(visit_counts.begin(), visit_counts.end(), 0);
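// Keep drawing traces from this seed until visit_threshold_per_seed nodes
// have been visited (Algorithm 1 in [1]) or the frequent-visit early stop
// fires (Algorithm 2). Note the seed itself is never appended to `vertices`.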
while (1) {
dgl_id_t cur = seed_ids[i], next;
size_t trace_length = 0;
for (; ; ++trace_length) {
if ((trace_length > 0) &&
(++visit_counts[cur] == max_visit_counts) &&
(++num_frequent_visited_nodes == max_frequent_visited_nodes))
stop = 1;
if ((trace_length > 0) && (rand_r(&random_seed) < restart_bound))
break;
if ((next = walker(gptr, &random_seed, cur)) == DGL_INVALID_ID)
LOG(FATAL) << "no successors from vertex " << cur;
cur = next;
vertices.push_back(cur);
}
total_trace_length += trace_length;
++num_traces;
trace_lengths.push_back(trace_length);
if ((total_trace_length >= visit_threshold_per_seed) || stop)
break;
}
trace_counts.push_back(num_traces);
}
RandomWalkTraces traces;
traces.trace_counts = IdArray::Empty(
{static_cast<int64_t>(trace_counts.size())},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
traces.trace_lengths = IdArray::Empty(
{static_cast<int64_t>(trace_lengths.size())},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
traces.vertices = IdArray::Empty(
{static_cast<int64_t>(vertices.size())},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
dgl_id_t *trace_counts_data = static_cast<dgl_id_t *>(traces.trace_counts->data);
dgl_id_t *trace_lengths_data = static_cast<dgl_id_t *>(traces.trace_lengths->data);
dgl_id_t *vertices_data = static_cast<dgl_id_t *>(traces.vertices->data);
std::copy(trace_counts.begin(), trace_counts.end(), trace_counts_data);
std::copy(trace_lengths.begin(), trace_lengths.end(), trace_lengths_data);
std::copy(vertices.begin(), vertices.end(), vertices_data);
return traces;
}
}; // namespace
PackedFunc ConvertRandomWalkTracesToPackedFunc(const RandomWalkTraces &t) {
return ConvertNDArrayVectorToPackedFunc({
t.trace_counts, t.trace_lengths, t.vertices});
}
IdArray RandomWalk(
const GraphInterface *gptr,
IdArray seeds,
int num_traces,
int num_hops) {
return GenericRandomWalk(gptr, seeds, num_traces, num_hops, WalkMultipleHops<1>);
}
RandomWalkTraces RandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes) {
return GenericRandomWalkWithRestart(
gptr, seeds, restart_prob, visit_threshold_per_seed, max_visit_counts,
max_frequent_visited_nodes, WalkMultipleHops<1>);
}
RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes) {
return GenericRandomWalkWithRestart(
gptr, seeds, restart_prob, visit_threshold_per_seed, max_visit_counts,
max_frequent_visited_nodes, WalkMultipleHops<2>);
}
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const IdArray seeds = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const int num_traces = args[2];
const int num_hops = args[3];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
*rv = RandomWalk(ptr, seeds, num_traces, num_hops);
});
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const IdArray seeds = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const double restart_prob = args[2];
const uint64_t visit_threshold_per_seed = args[3];
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
const GraphInterface *gptr = static_cast<const GraphInterface *>(ghandle);
*rv = ConvertRandomWalkTracesToPackedFunc(
RandomWalkWithRestart(gptr, seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes));
});
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const IdArray seeds = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const double restart_prob = args[2];
const uint64_t visit_threshold_per_seed = args[3];
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
const GraphInterface *gptr = static_cast<const GraphInterface *>(ghandle);
*rv = ConvertRandomWalkTracesToPackedFunc(
BipartiteSingleSidedRandomWalkWithRestart(
gptr, seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes));
});
}; // namespace dgl
@@ -13,13 +13,6 @@
#include <numeric>
#include "../c_api_common.h"
#ifdef _MSC_VER
// rand in MS compiler works well in multi-threading.
int rand_r(unsigned *seed) {
return rand();
}
#endif
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
@@ -385,7 +378,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
int num_hops,
size_t num_neighbor,
const bool add_self_loop) {
-  unsigned int time_seed = time(nullptr);
+  unsigned int time_seed = randseed();
const size_t num_seeds = seeds.size();
auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
const dgl_id_t* val_list = orig_csr->edge_ids.data();
@@ -540,47 +533,6 @@ NodeFlow SamplerOp::NeighborUniformSample(const ImmutableGraph *graph,
add_self_loop);
}
IdArray SamplerOp::RandomWalk(
const GraphInterface *gptr,
IdArray seeds,
int num_traces,
int num_hops) {
const int num_nodes = seeds->shape[0];
const dgl_id_t *seed_ids = static_cast<dgl_id_t *>(seeds->data);
IdArray traces = IdArray::Empty(
{num_nodes, num_traces, num_hops + 1},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
dgl_id_t *trace_data = static_cast<dgl_id_t *>(traces->data);
// FIXME: does OpenMP work with exceptions? Especially without throwing SIGABRT?
unsigned int random_seed = time(nullptr);
for (int i = 0; i < num_nodes; ++i) {
const dgl_id_t seed_id = seed_ids[i];
for (int j = 0; j < num_traces; ++j) {
dgl_id_t cur = seed_id;
const int kmax = num_hops + 1;
for (int k = 0; k < kmax; ++k) {
const size_t offset = ((size_t)i * num_traces + j) * kmax + k;
trace_data[offset] = cur;
const auto succ = gptr->SuccVec(cur);
const size_t size = succ.size();
if (size == 0) {
LOG(FATAL) << "no successors from vertex " << cur;
return traces;
}
cur = succ[rand_r(&random_seed) % size];
}
}
}
return traces;
}
namespace {
void ConstructLayers(const int64_t *indptr,
const dgl_id_t *indices,
@@ -603,7 +555,7 @@ namespace {
size_t curr = 0;
size_t next = node_mapping->size();
-    unsigned int rand_seed = time(nullptr);
+    unsigned int rand_seed = randseed();
for (int64_t i = num_layers - 1; i >= 0; --i) {
const int64_t layer_size = layer_sizes_data[i];
std::unordered_set<dgl_id_t> candidate_set;
......
import dgl
from dgl import utils
import backend as F
import numpy as np
def test_random_walk():
edge_list = [(0, 1), (1, 2), (2, 3), (3, 4),
(4, 3), (3, 2), (2, 1), (1, 0)]
seeds = [0, 1]
n_traces = 3
n_hops = 4
g = dgl.DGLGraph(edge_list, readonly=True)
traces = dgl.contrib.sampling.random_walk(g, seeds, n_traces, n_hops)
traces = F.zerocopy_to_numpy(traces)
assert traces.shape == (len(seeds), n_traces, n_hops + 1)
for i, seed in enumerate(seeds):
assert (traces[i, :, 0] == seed).all()
trace_diff = np.diff(traces, axis=-1)
# only nodes with adjacent IDs are connected
assert (np.abs(trace_diff) == 1).all()
def test_random_walk_with_restart():
edge_list = [(0, 1), (1, 2), (2, 3), (3, 4),
(4, 3), (3, 2), (2, 1), (1, 0)]
seeds = [0, 1]
max_nodes = 10
g = dgl.DGLGraph(edge_list)
# test normal RWR
traces = dgl.contrib.sampling.random_walk_with_restart(g, seeds, 0.2, max_nodes)
assert len(traces) == len(seeds)
for traces_per_seed in traces:
total_nodes = 0
for t in traces_per_seed:
total_nodes += len(t)
trace_diff = np.diff(F.zerocopy_to_numpy(t), axis=-1)
assert (np.abs(trace_diff) == 1).all()
assert total_nodes >= max_nodes
# test RWR with early stopping
traces = dgl.contrib.sampling.random_walk_with_restart(
g, seeds, 1, 100, max_nodes, 1)
assert len(traces) == len(seeds)
for traces_per_seed in traces:
assert sum(len(t) for t in traces_per_seed) < 100
# test bipartite RWR
traces = dgl.contrib.sampling.bipartite_single_sided_random_walk_with_restart(
g, seeds, 0.2, max_nodes)
assert len(traces) == len(seeds)
for traces_per_seed in traces:
for t in traces_per_seed:
trace_diff = np.diff(F.zerocopy_to_numpy(t), axis=-1)
assert (trace_diff % 2 == 0).all()
if __name__ == '__main__':
    test_random_walk()
    test_random_walk_with_restart()
@@ -149,26 +149,6 @@ def test_layer_sampler(prefetch=False):
sub_m = sub_g.number_of_edges()
assert sum(F.shape(sub_g.block_eid(i))[0] for i in range(n_blocks)) == sub_m
def test_random_walk():
edge_list = [(0, 1), (1, 2), (2, 3), (3, 4),
(4, 3), (3, 2), (2, 1), (1, 0)]
seeds = [0, 1]
n_traces = 3
n_hops = 4
g = dgl.DGLGraph(edge_list, readonly=True)
traces = dgl.contrib.sampling.random_walk(g, seeds, n_traces, n_hops)
traces = F.zerocopy_to_numpy(traces)
assert traces.shape == (len(seeds), n_traces, n_hops + 1)
for i, seed in enumerate(seeds):
assert (traces[i, :, 0] == seeds[i]).all()
trace_diff = np.diff(traces, axis=-1)
# only nodes with adjacent IDs are connected
assert (np.abs(trace_diff) == 1).all()
if __name__ == '__main__':
test_create_full()
test_1neighbor_sampler_all()
@@ -177,4 +157,3 @@ if __name__ == '__main__':
test_10neighbor_sampler()
test_layer_sampler()
test_layer_sampler(prefetch=True)
test_random_walk()