"vscode:/vscode.git/clone" did not exist on "55d6453fce312e3858155bf604e291c150f707a6"
Commit e3921d5d authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by Minjie Wang
Browse files

[Sampler] Metapath sampler for metapath2vec (#861)

* metapath sampler

* lint & fixes

* lint x2

* lint x3

* fix windows

* remove max_cycle argument

* add todo note
parent 1db697ec
import os
import torch as th
import torch.nn as nn
import tqdm
class AminerDataset:
class PBar(object):
    """Progress-bar report hook for ``urllib.request.urlretrieve``.

    Use it as a context manager and pass the instance as the report hook.
    The underlying ``tqdm`` bar is created lazily on the first callback
    (the total size is not known earlier) and closed on exit.
    """

    def __enter__(self):
        # No bar yet: it is created on the first callback.
        self.t = None
        return self

    def __call__(self, blockno, readsize, totalsize):
        """Move the bar to the number of bytes transferred so far.

        Parameters
        ----------
        blockno : int
            Count of blocks transferred so far (0 on the initial callback).
        readsize : int
            Block size in bytes.
        totalsize : int
            Total size in bytes; non-positive when the size is unknown.
        """
        if self.t is None:
            # An unknown size must not be used as the bar's total.
            self.t = tqdm.tqdm(total=totalsize if totalsize > 0 else None)
        done = blockno * readsize
        if totalsize > 0:
            # The last block is usually partial; do not overshoot the total.
            done = min(done, totalsize)
        # Advance by the difference so the initial blockno == 0 callback
        # does not count a whole block that was never transferred.
        self.t.update(done - self.t.n)

    def __exit__(self, exc_type, exc_value, traceback):
        # The hook may never have fired (e.g. the download failed early).
        if self.t is not None:
            self.t.close()
class AminerDataset(object):
"""
Download Aminer Dataset from Amazon S3 bucket.
"""
......@@ -11,28 +26,21 @@ class AminerDataset:
self.url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/aminer.zip'
if not os.path.exists(os.path.join(path, 'aminer')):
if not os.path.exists(os.path.join(path, 'aminer.txt')):
print('File not found. Downloading from', self.url)
self._download_and_extract(path, 'aminer.zip')
self.fn = os.path.join(path, 'aminer.txt')
def _download_and_extract(self, path, filename):
import shutil, zipfile, zlib
from tqdm import tqdm
import requests
import urllib.request
fn = os.path.join(path, filename)
if os.path.exists(path):
shutil.rmtree(path, ignore_errors=True)
os.makedirs(path)
f_remote = requests.get(self.url, stream=True)
assert f_remote.status_code == 200, 'fail to open {}'.format(self.url)
with open(fn, 'wb') as writer:
for chunk in tqdm(f_remote.iter_content(chunk_size=1024*1024*3)):
writer.write(chunk)
with PBar() as pb:
urllib.request.urlretrieve(self.url, fn, pb)
print('Download finished. Unzipping the file...')
with zipfile.ZipFile(fn) as zf:
zf.extractall(path)
print('Unzip finished.')
self.fn = fn
import numpy as np
import torch
import torchvision
from torch.autograd import Variable
import random
import time
import tqdm
import dgl
import sys
import os
Metapath = "Conference-Paper-Author-Paper-Conference"
num_walks_per_node = 1000
walk_length = 100
path = sys.argv[1]
#construct mapping from text, could be changed to DGL later
def construct_id_dict():
id_to_paper = {}
id_to_author = {}
id_to_conf = {}
f_3 = open(".../id_author.txt", encoding="ISO-8859-1")
f_4 = open(".../id_conf.txt", encoding="ISO-8859-1")
f_5 = open(".../paper.txt", encoding="ISO-8859-1")
def construct_graph():
paper_ids = []
paper_names = []
author_ids = []
author_names = []
conf_ids = []
conf_names = []
f_3 = open(os.path.join(path, "id_author.txt"), encoding="ISO-8859-1")
f_4 = open(os.path.join(path, "id_conf.txt"), encoding="ISO-8859-1")
f_5 = open(os.path.join(path, "paper.txt"), encoding="ISO-8859-1")
while True:
z = f_3.readline()
if not z:
break
z = z.split('\t')
z = z.strip().split()
identity = int(z[0])
id_to_author[identity] = z[1].strip("\n")
author_ids.append(identity)
author_names.append(z[1])
while True:
w = f_4.readline()
if not w:
break;
w = w.split('\t')
w = w.strip().split()
identity = int(w[0])
id_to_conf[identity] = w[1].strip("\n")
conf_ids.append(identity)
conf_names.append(w[1])
while True:
v = f_5.readline()
if not v:
break;
v = v.split(' ')
v = v.strip().split()
identity = int(v[0])
paper_name = ""
for s in range(5, len(v)):
paper_name += v[s]
paper_name = 'p' + paper_name
id_to_paper[identity] = paper_name.strip('\n')
paper_name = 'p' + ''.join(v[1:])
paper_ids.append(identity)
paper_names.append(paper_name)
f_3.close()
f_4.close()
f_5.close()
return id_to_paper, id_to_author, id_to_conf
#construct mapping from text, could be changed to DGL later
def construct_types_mappings():
paper_to_author = {}
author_to_paper = {}
paper_to_conf = {}
conf_to_paper = {}
f_1 = open(".../paper_author.txt", "r")
f_2 = open(".../paper_conf.txt", "r")
author_ids_invmap = {x: i for i, x in enumerate(author_ids)}
conf_ids_invmap = {x: i for i, x in enumerate(conf_ids)}
paper_ids_invmap = {x: i for i, x in enumerate(paper_ids)}
paper_author_src = []
paper_author_dst = []
paper_conf_src = []
paper_conf_dst = []
f_1 = open(os.path.join(path, "paper_author.txt"), "r")
f_2 = open(os.path.join(path, "paper_conf.txt"), "r")
for x in f_1:
x = x.split('\t')
x[0] = int(x[0])
x[1] = int(x[1].strip('\n'))
if x[0] in paper_to_author:
paper_to_author[x[0]].append(x[1])
else:
paper_to_author[x[0]] = []
paper_to_author[x[0]].append(x[1])
if x[1] in author_to_paper:
author_to_paper[x[1]].append(x[0])
else:
author_to_paper[x[1]] = []
author_to_paper[x[1]].append(x[0])
paper_author_src.append(paper_ids_invmap[x[0]])
paper_author_dst.append(author_ids_invmap[x[1]])
for y in f_2:
y = y.split('\t')
y[0] = int(y[0])
y[1] = int(y[1].strip('\n'))
if y[0] in paper_to_conf:
paper_to_conf[y[0]].append(y[1])
else:
paper_to_conf[y[0]] = []
paper_to_conf[y[0]].append(y[1])
if y[1] in conf_to_paper:
conf_to_paper[y[1]].append(y[0])
else:
conf_to_paper[y[1]] = []
conf_to_paper[y[1]].append(y[0])
paper_conf_src.append(paper_ids_invmap[y[0]])
paper_conf_dst.append(conf_ids_invmap[y[1]])
f_1.close()
f_2.close()
return paper_to_author, author_to_paper, paper_to_conf, conf_to_paper
pa = dgl.bipartite((paper_author_src, paper_author_dst), 'paper', 'pa', 'author')
ap = dgl.bipartite((paper_author_dst, paper_author_src), 'author', 'ap', 'paper')
pc = dgl.bipartite((paper_conf_src, paper_conf_dst), 'paper', 'pc', 'conf')
cp = dgl.bipartite((paper_conf_dst, paper_conf_src), 'conf', 'cp', 'paper')
hg = dgl.hetero_from_relations([pa, ap, pc, cp])
return hg, author_names, conf_names, paper_names
#"conference - paper - Author - paper - conference" metapath sampling
def generate_metapath():
output_path = open(".../output_path.txt", "w")
id_to_paper, id_to_author, id_to_conf = construct_id_dict()
paper_to_author, author_to_paper, paper_to_conf, conf_to_paper = construct_types_mappings()
output_path = open(os.path.join(path, "output_path.txt"), "w")
count = 0
#loop all conferences
for conf_id in conf_to_paper.keys():
start_time = time.time()
print("sampling" + str(count))
conf = id_to_conf[conf_id]
conf0 = conf
#for each conference, simulate num_walks_per_node walks
for i in range(num_walks_per_node):
outline = conf0
# each walk with length walk_length
for j in range(walk_length):
# C - P
paper_list_1 = conf_to_paper[conf_id]
# check whether the paper nodes link to any author node
connections_1 = False
available_paper_1 = []
for k in range(len(paper_list_1)):
if paper_list_1[k] in paper_to_author:
available_paper_1.append(paper_list_1[k])
num_p_1 = len(available_paper_1)
if num_p_1 != 0:
connections_1 = True
paper_1_index = random.randrange(num_p_1)
#paper_id_1 = paper_list_1[paper_1_index]
paper_id_1 = available_paper_1[paper_1_index]
paper_1 = id_to_paper[paper_id_1]
outline += " " + paper_1
else:
break
# C - P - A
author_list = paper_to_author[paper_id_1]
num_a = len(author_list)
# No need to check
author_index = random.randrange(num_a)
author_id = author_list[author_index]
author = id_to_author[author_id]
outline += " " + author
# C - P - A - P
paper_list_2 = author_to_paper[author_id]
#check whether paper node links to any conference node
connections_2 = False
available_paper_2 = []
for m in range(len(paper_list_2)):
if paper_list_2[m] in paper_to_conf:
available_paper_2.append(paper_list_2[m])
num_p_2 = len(available_paper_2)
if num_p_2 != 0:
connections_2 = True
paper_2_index = random.randrange(num_p_2)
paper_id_2 = available_paper_2[paper_2_index]
paper_2 = id_to_paper[paper_id_2]
outline += " " + paper_2
else:
break
# C - P - A - P - C
conf_list = paper_to_conf[paper_id_2]
num_c = len(conf_list)
conf_index = random.randrange(num_c)
conf_id = conf_list[conf_index]
conf = id_to_conf[conf_id]
outline += " " + conf
if connections_1 and connections_2:
output_path.write(outline + "\n")
else:
break
# Note that the original mapping text has type indicator in front of each node just like "cVLDB"
# So the sampling sequence looks like "cconference ppaper aauthor ppaper cconference"
count += 1
print("--- %s seconds ---" % (time.time() - start_time))
hg, author_names, conf_names, paper_names = construct_graph()
for conf_idx in tqdm.trange(hg.number_of_nodes('conf')):
traces = dgl.contrib.sampling.metapath_random_walk(
hg, ['cp', 'pa', 'ap', 'pc'] * walk_length, [conf_idx], num_walks_per_node)
traces = traces[0]
for trace in traces:
tr = np.insert(trace.numpy(), 0, conf_idx)
outline = ' '.join(
(conf_names if i % 4 == 0 else author_names)[tr[i]]
for i in range(0, len(tr), 2)) # skip paper
print(outline, file=output_path)
output_path.close()
......
......@@ -17,15 +17,6 @@ namespace dgl {
class ImmutableGraph;
struct RandomWalkTraces {
/*! \brief number of traces generated for each seed */
IdArray trace_counts;
/*! \brief length of each trace, concatenated */
IdArray trace_lengths;
/*! \brief the vertices, concatenated */
IdArray vertices;
};
class SamplerOp {
public:
/*!
......@@ -65,76 +56,6 @@ class SamplerOp {
IdArray layer_sizes);
};
/*!
* \brief Batch-generate random walk traces
* \param seeds The array of starting vertex IDs
* \param num_traces The number of traces to generate for each seed
* \param num_hops The number of hops for each trace
* \return a flat ID array with shape (num_seeds, num_traces, num_hops + 1)
*/
IdArray RandomWalk(const GraphInterface *gptr,
IdArray seeds,
int num_traces,
int num_hops);
/*!
* \brief Batch-generate random walk traces with restart
*
* Stop generating traces if max_frequrent_visited_nodes nodes are visited more than
* max_visit_counts times.
*
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed if no less
* than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
* times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces instance.
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTraces RandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
/*
* \brief Batch-generate random walk traces with restart on a bipartite graph, walking two
* hops at a time.
*
* Since it is walking on a bipartite graph, the vertices of a trace will always stay on the
* same side.
*
* Stop generating traces if max_frequrent_visited_nodes nodes are visited more than
* max_visit_counts times.
*
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed if no less
* than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
* times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces instance.
*
* \note Doesn't verify whether the graph is indeed a bipartite graph
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
} // namespace dgl
#endif // DGL_SAMPLER_H_
import numpy as np
from ... import utils
from ... import backend as F
from ..._ffi.function import _init_api
from ..._ffi.object import register_object, ObjectBase
from ... import ndarray
__all__ = ['random_walk',
'random_walk_with_restart',
'bipartite_single_sided_random_walk_with_restart',
'metapath_random_walk',
]
@register_object('sampler.RandomWalkTraces')
class RandomWalkTraces(ObjectBase):
    """Python-side class for objects registered as ``sampler.RandomWalkTraces``."""
def random_walk(g, seeds, num_traces, num_hops):
"""Batch-generate random walk traces on given graph with the same length.
......@@ -46,28 +53,25 @@ def _split_traces(traces):
Parameters
----------
traces : PackedFunc object of RandomWalkTraces structure
traces : RandomWalkTraces
Returns
-------
traces : list[list[Tensor]]
traces[i][j] is the j-th trace generated for i-th seed.
"""
trace_counts = F.zerocopy_to_numpy(
F.zerocopy_from_dlpack(traces(0).to_dlpack())).tolist()
trace_lengths = F.zerocopy_from_dlpack(traces(1).to_dlpack())
trace_vertices = F.zerocopy_from_dlpack(traces(2).to_dlpack())
trace_counts = traces.trace_counts.asnumpy().tolist()
trace_vertices = F.zerocopy_from_dgl_ndarray(traces.vertices)
trace_vertices = F.split(
trace_vertices, F.zerocopy_to_numpy(trace_lengths).tolist(), 0)
trace_vertices, traces.trace_lengths.asnumpy().tolist(), 0)
traces = []
results = []
s = 0
for c in trace_counts:
traces.append(trace_vertices[s:s+c])
results.append(trace_vertices[s:s+c])
s += c
return traces
return results
def random_walk_with_restart(
......@@ -165,4 +169,49 @@ def bipartite_single_sided_random_walk_with_restart(
int(max_visit_counts), int(max_frequent_visited_nodes))
return _split_traces(traces)
_init_api('dgl.randomwalk', __name__)
def metapath_random_walk(hg, etypes, seeds, num_traces):
    """Generate random walk traces from an array of seed nodes (or starting nodes),
    based on the given metapath.

    For a single seed node, ``num_traces`` traces would be generated.  A trace would

    1. Start from the given seed and set ``t`` to 0.
    2. Pick and traverse along edge type ``etypes[t]`` from the current node.
    3. If no edge can be found, or ``etypes`` is exhausted, halt.  Otherwise,
       increment ``t`` and go to step 2.

    The metapath is traversed once; to walk it several times, repeat it in
    ``etypes`` (e.g. ``['ab', 'ba'] * 4``).

    Parameters
    ----------
    hg : DGLHeteroGraph
        The heterogeneous graph.
    etypes : list[str or tuple of str]
        Metapath, specified as a list of edge types.
        The beginning and ending node type must be the same.
    seeds : Tensor
        The seed nodes.  Node type is the same as the beginning node type of metapath.
    num_traces : int
        The number of traces to generate for each seed.

    Returns
    -------
    traces : list[list[Tensor]]
        traces[i][j] is the j-th trace generated for i-th seed.
        traces[i][j][k] would have node type the same as the destination node type of edge
        type ``etypes[k]``.  A trace holds at most ``len(etypes)`` nodes.

    Raises
    ------
    ValueError
        If the metapath is empty, or its beginning and ending node types differ.

    Notes
    -----
    The traces do **not** include the seed nodes themselves.
    """
    if len(etypes) == 0:
        raise ValueError('empty metapath')
    if hg.to_canonical_etype(etypes[0])[0] != hg.to_canonical_etype(etypes[-1])[2]:
        raise ValueError('beginning and ending node type mismatch')

    if len(seeds) == 0:
        return []

    etype_ids = np.array([hg.get_etype_id(et) for et in etypes], dtype='int64')
    etype_array = ndarray.array(etype_ids)
    seed_array = utils.toindex(seeds).todgltensor()
    # Coerce to a plain int, as the other random walk wrappers in this module
    # do for their scalar arguments.
    traces = _CAPI_DGLMetapathRandomWalk(
        hg._graph, etype_array, seed_array, int(num_traces))
    return _split_traces(traces)
_init_api('dgl.sampler.randomwalk', __name__)
......@@ -76,8 +76,8 @@ class HeteroGraphIndex(ObjectBase):
Returns
-------
HeteroGraphIndex
The unitgraph graph.
FlattenedHeteroGraph
A flattened heterograph object
"""
return _CAPI_DGLHeteroGetFlattenedGraph(self, etypes)
......@@ -1006,4 +1006,8 @@ def create_heterograph_from_relations(metagraph, rel_graphs):
"""
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
@register_object("graph.FlattenedHeteroGraph")
class FlattenedHeteroGraph(ObjectBase):
    """FlattenedHeteroGraph object class in C++ backend.

    Registered under the type key ``graph.FlattenedHeteroGraph``.
    """
_init_api("dgl.heterograph_index")
/*!
* Copyright (c) 2019 by Contributors
* \file graph/sampler/metapath.cc
* \brief Metapath sampling
*/
#include <dgl/array.h>
#include <dgl/random.h>
#include <dgl/packed_func_ext.h>
#include "../../c_api_common.h"
#include "randomwalk.h"
using namespace dgl::runtime;
using namespace dgl::aten;
namespace dgl {
namespace sampling {
namespace {
/*!
 * \brief Random walk based on the given metapath.
 *
 * For every seed, \a num_traces traces are generated.  Each trace follows the
 * edge types in \a etypes in order, one hop per edge type, and stops early as
 * soon as the current vertex has no successor under the next edge type.  The
 * seed itself is not stored in the trace.
 *
 * \param hg The heterograph
 * \param etypes The metapath as an array of edge type IDs
 * \param seeds The array of starting vertices for random walks
 * \param num_traces Number of traces to generate for each starting vertex
 * \return The generated traces: the vertices of all traces concatenated, the
 *         length of each trace, and the number of traces for each seed.
 * \note The metapath should have the same starting and ending node type.
 */
RandomWalkTracesPtr MetapathRandomWalk(
    const HeteroGraphPtr hg,
    const IdArray etypes,
    const IdArray seeds,
    int num_traces) {
  const uint64_t num_etypes = etypes->shape[0];
  const uint64_t num_seeds = seeds->shape[0];
  const dgl_type_t *etype_data = static_cast<dgl_type_t *>(etypes->data);
  const dgl_id_t *seed_data = static_cast<dgl_id_t *>(seeds->data);

  std::vector<dgl_id_t> vertices;
  std::vector<size_t> trace_lengths, trace_counts;
  // Exactly one count is recorded per seed.
  trace_counts.reserve(num_seeds);

  // TODO(quan): use omp to parallelize this loop
  for (uint64_t seed_id = 0; seed_id < num_seeds; ++seed_id) {
    int curr_num_traces = 0;
    for (; curr_num_traces < num_traces; ++curr_num_traces) {
      dgl_id_t curr = seed_data[seed_id];
      size_t trace_length = 0;

      for (size_t i = 0; i < num_etypes; ++i) {
        const auto &succ = hg->SuccVec(etype_data[i], curr);
        // Dead end: the trace is kept, just shorter than the metapath.
        if (succ.size() == 0)
          break;
        curr = succ[RandomEngine::ThreadLocal()->RandInt(succ.size())];
        vertices.push_back(curr);
        ++trace_length;
      }

      trace_lengths.push_back(trace_length);
    }

    trace_counts.push_back(curr_num_traces);
  }

  // Own the result through a shared_ptr from the start, so that an exception
  // thrown while building the arrays cannot leak the allocation.
  auto traces = std::make_shared<RandomWalkTraces>();
  traces->vertices = VecToIdArray(vertices);
  traces->trace_lengths = VecToIdArray(trace_lengths);
  traces->trace_counts = VecToIdArray(trace_counts);
  return traces;
}
}; // namespace
// Registers MetapathRandomWalk under the global name
// "sampler.randomwalk._CAPI_DGLMetapathRandomWalk".
// Arguments: (0) the heterograph, (1) the edge type ID array of the metapath,
// (2) the seed vertex array, (3) the number of traces per seed.
// The result is returned as a RandomWalkTracesRef.
DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef hg = args[0];
const IdArray etypes = args[1];
const IdArray seeds = args[2];
int num_traces = args[3];
const auto tl = MetapathRandomWalk(hg.sptr(), etypes, seeds, num_traces);
// Wrap the shared pointer in an object reference for the return value.
*rv = RandomWalkTracesRef(tl);
});
}; // namespace sampling
}; // namespace dgl
......@@ -4,7 +4,6 @@
* \brief DGL sampler implementation
*/
#include <dgl/sampler.h>
#include <dmlc/omp.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
......@@ -14,12 +13,16 @@
#include <cmath>
#include <numeric>
#include <functional>
#include "../c_api_common.h"
#include <vector>
#include "randomwalk.h"
#include "../../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {
namespace sampling {
using Walker = std::function<dgl_id_t(const GraphInterface *, dgl_id_t)>;
namespace {
......@@ -92,7 +95,7 @@ IdArray GenericRandomWalk(
return traces;
}
RandomWalkTraces GenericRandomWalkWithRestart(
RandomWalkTracesPtr GenericRandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
......@@ -144,38 +147,33 @@ RandomWalkTraces GenericRandomWalkWithRestart(
trace_counts.push_back(num_traces);
}
RandomWalkTraces traces;
traces.trace_counts = IdArray::Empty(
RandomWalkTraces *traces = new RandomWalkTraces;
traces->trace_counts = IdArray::Empty(
{static_cast<int64_t>(trace_counts.size())},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
traces.trace_lengths = IdArray::Empty(
traces->trace_lengths = IdArray::Empty(
{static_cast<int64_t>(trace_lengths.size())},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
traces.vertices = IdArray::Empty(
traces->vertices = IdArray::Empty(
{static_cast<int64_t>(vertices.size())},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
dgl_id_t *trace_counts_data = static_cast<dgl_id_t *>(traces.trace_counts->data);
dgl_id_t *trace_lengths_data = static_cast<dgl_id_t *>(traces.trace_lengths->data);
dgl_id_t *vertices_data = static_cast<dgl_id_t *>(traces.vertices->data);
dgl_id_t *trace_counts_data = static_cast<dgl_id_t *>(traces->trace_counts->data);
dgl_id_t *trace_lengths_data = static_cast<dgl_id_t *>(traces->trace_lengths->data);
dgl_id_t *vertices_data = static_cast<dgl_id_t *>(traces->vertices->data);
std::copy(trace_counts.begin(), trace_counts.end(), trace_counts_data);
std::copy(trace_lengths.begin(), trace_lengths.end(), trace_lengths_data);
std::copy(vertices.begin(), vertices.end(), vertices_data);
return traces;
return RandomWalkTracesPtr(traces);
}
}; // namespace
PackedFunc ConvertRandomWalkTracesToPackedFunc(const RandomWalkTraces &t) {
return ConvertNDArrayVectorToPackedFunc({
t.trace_counts, t.trace_lengths, t.vertices});
}
IdArray RandomWalk(
const GraphInterface *gptr,
IdArray seeds,
......@@ -184,7 +182,7 @@ IdArray RandomWalk(
return GenericRandomWalk(gptr, seeds, num_traces, num_hops, WalkMultipleHops<1>);
}
RandomWalkTraces RandomWalkWithRestart(
RandomWalkTracesPtr RandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
......@@ -196,7 +194,7 @@ RandomWalkTraces RandomWalkWithRestart(
max_frequent_visited_nodes, WalkMultipleHops<1>);
}
RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
RandomWalkTracesPtr BipartiteSingleSidedRandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
......@@ -208,7 +206,7 @@ RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
max_frequent_visited_nodes, WalkMultipleHops<2>);
}
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk")
DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray seeds = args[1];
......@@ -218,7 +216,7 @@ DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk")
*rv = RandomWalk(g.sptr().get(), seeds, num_traces, num_hops);
});
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart")
DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray seeds = args[1];
......@@ -227,12 +225,12 @@ DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart")
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
*rv = ConvertRandomWalkTracesToPackedFunc(
*rv = RandomWalkTracesRef(
RandomWalkWithRestart(g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes));
});
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart")
DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray seeds = args[1];
......@@ -241,10 +239,12 @@ DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithResta
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
*rv = ConvertRandomWalkTracesToPackedFunc(
*rv = RandomWalkTracesRef(
BipartiteSingleSidedRandomWalkWithRestart(
g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes));
});
}; // namespace sampling
}; // namespace dgl
/*!
* Copyright (c) 2018 by Contributors
* \file graph/sampler/randomwalk.h
* \brief DGL random walk sampler header.
*/
#ifndef DGL_GRAPH_SAMPLER_RANDOMWALK_H_
#define DGL_GRAPH_SAMPLER_RANDOMWALK_H_
#include <dgl/runtime/object.h>
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <memory>
namespace dgl {
namespace sampling {
/*!
 * \brief Structure of multiple random walk traces
 *
 * The traces are stored flattened: \c vertices holds the vertices of every
 * trace back to back, \c trace_lengths holds the length of each trace, and
 * \c trace_counts holds the number of traces generated for each seed.
 */
struct RandomWalkTraces : public runtime::Object {
/*! \brief number of traces generated for each seed */
IdArray trace_counts;
/*! \brief length of each trace, concatenated */
IdArray trace_lengths;
/*! \brief the vertices, concatenated */
IdArray vertices;
/*! \brief Pass each array to the attribute visitor under its field name */
void VisitAttrs(runtime::AttrVisitor *v) final {
v->Visit("vertices", &vertices);
v->Visit("trace_lengths", &trace_lengths);
v->Visit("trace_counts", &trace_counts);
}
/*! \brief Type key of this object type */
static constexpr const char *_type_key = "sampler.RandomWalkTraces";
DGL_DECLARE_OBJECT_TYPE_INFO(RandomWalkTraces, runtime::Object);
};
typedef std::shared_ptr<RandomWalkTraces> RandomWalkTracesPtr;
DGL_DEFINE_OBJECT_REF(RandomWalkTracesRef, RandomWalkTraces);
/*!
* \brief Batch-generate random walk traces
* \param gptr The graph to walk on
* \param seeds The array of starting vertex IDs
* \param num_traces The number of traces to generate for each seed
* \param num_hops The number of hops for each trace
* \return a flat ID array with shape (num_seeds, num_traces, num_hops + 1)
*/
IdArray RandomWalk(const GraphInterface *gptr,
IdArray seeds,
int num_traces,
int num_hops);
/*!
* \brief Batch-generate random walk traces with restart
*
* Stop generating traces if max_frequent_visited_nodes nodes are visited more than
* max_visit_counts times.
*
* \param gptr The graph to walk on
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed if no less
* than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
* times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces pointer.
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTracesPtr RandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
/*!
* \brief Batch-generate random walk traces with restart on a bipartite graph, walking two
* hops at a time.
*
* Since it is walking on a bipartite graph, the vertices of a trace will always stay on the
* same side.
*
* Stop generating traces if max_frequent_visited_nodes nodes are visited more than
* max_visit_counts times.
*
* \param gptr The graph to walk on
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed if no less
* than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
* times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces pointer.
*
* \note Doesn't verify whether the graph is indeed a bipartite graph
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTracesPtr BipartiteSingleSidedRandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
}; // namespace sampling
}; // namespace dgl
#endif // DGL_GRAPH_SAMPLER_RANDOMWALK_H_
......@@ -58,5 +58,21 @@ def test_random_walk_with_restart():
trace_diff = np.diff(F.zerocopy_to_numpy(t), axis=-1)
assert (trace_diff % 2 == 0).all()
def test_metapath_random_walk():
    """Check that metapath random walks only follow existing edges.

    The graph has two relations, 'ab' (a -> b) and 'ba' (b -> a).  Every node
    has at least one outgoing edge in its relation, so a walk over
    ``['ab', 'ba'] * 4`` never hits a dead end and every trace has 8 nodes.
    """
    g1 = dgl.bipartite(([0, 1, 2, 3], [0, 1, 2, 3]), 'a', 'ab', 'b')
    g2 = dgl.bipartite(([0, 0, 1, 1, 2, 2, 3, 3], [1, 3, 2, 0, 3, 1, 0, 2]), 'b', 'ba', 'a')
    G = dgl.hetero_from_relations([g1, g2])
    seeds = [0, 1]
    traces = dgl.contrib.sampling.metapath_random_walk(G, ['ab', 'ba'] * 4, seeds, 3)

    # zip() stops at the shorter input, so without this check the loop below
    # would pass vacuously if the sampler returned too few per-seed lists.
    assert len(traces) == len(seeds)
    for seed, traces_per_seed in zip(seeds, traces):
        assert len(traces_per_seed) == 3
        for trace in traces_per_seed:
            assert len(trace) == 8
            # The returned trace excludes the seed; prepend it so that
            # consecutive entries are the endpoints of each hop.
            trace = np.insert(F.asnumpy(trace), 0, seed)
            for i in range(4):
                assert g1.has_edge_between(trace[2 * i], trace[2 * i + 1])
                assert g2.has_edge_between(trace[2 * i + 1], trace[2 * i + 2])
if __name__ == '__main__':
test_random_walk()
test_metapath_random_walk()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment