"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "0e188b61a69cb998f65a1fbd0ec240c4186b177f"
Commit e3921d5d authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by Minjie Wang
Browse files

[Sampler] Metapath sampler for metapath2vec (#861)

* metapath sampler

* lint & fixes

* lint x2

* lint x3

* fix windows

* remove max_cycle argument

* add todo note
parent 1db697ec
import os import os
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import tqdm
class AminerDataset: class PBar(object):
def __enter__(self):
    """Enter the context with no progress bar created yet.

    Returns
    -------
    self
        The same object, so it can be bound by ``with ... as``.
    """
    # ``None`` marks "no bar yet"; ``__call__`` replaces it on first use.
    self.t = None
    return self
def __call__(self, blockno, readsize, totalsize):
    """Report-hook callback in the form ``urllib.request.urlretrieve`` expects.

    Parameters
    ----------
    blockno : int
        Number of blocks transferred so far.
    readsize : int
        Size of one block in bytes.
    totalsize : int
        Total size of the download in bytes; non-positive when unknown.
    """
    if self.t is None:
        # The total is only known once the hook fires, so the bar is created
        # here. tqdm expresses an unknown total as ``None``.
        self.t = tqdm.tqdm(total=totalsize if totalsize > 0 else None)
    # urlretrieve also invokes the hook once with blockno == 0 before any data
    # is read, so advance the bar to the cumulative byte count instead of
    # adding one block per call (which would overshoot by a block).
    done = blockno * readsize
    if totalsize > 0:
        # The last block is usually partial; never report more than the total.
        done = min(done, totalsize)
    self.t.update(done - self.t.n)
def __exit__(self, exc_type, exc_value, traceback):
    """Close the progress bar, if one was ever created.

    Returns ``None``, so any exception raised inside the ``with`` body
    propagates unchanged.
    """
    # The hook may never have fired (e.g. the download failed right away);
    # calling ``close`` on ``None`` would raise AttributeError and hide the
    # original exception.
    if self.t is not None:
        self.t.close()
class AminerDataset(object):
""" """
Download Aminer Dataset from Amazon S3 bucket. Download Aminer Dataset from Amazon S3 bucket.
""" """
...@@ -11,28 +26,21 @@ class AminerDataset: ...@@ -11,28 +26,21 @@ class AminerDataset:
self.url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/aminer.zip' self.url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/aminer.zip'
if not os.path.exists(os.path.join(path, 'aminer')): if not os.path.exists(os.path.join(path, 'aminer.txt')):
print('File not found. Downloading from', self.url) print('File not found. Downloading from', self.url)
self._download_and_extract(path, 'aminer.zip') self._download_and_extract(path, 'aminer.zip')
self.fn = os.path.join(path, 'aminer.txt')
def _download_and_extract(self, path, filename): def _download_and_extract(self, path, filename):
import shutil, zipfile, zlib import shutil, zipfile, zlib
from tqdm import tqdm from tqdm import tqdm
import requests import urllib.request
fn = os.path.join(path, filename) fn = os.path.join(path, filename)
with PBar() as pb:
if os.path.exists(path): urllib.request.urlretrieve(self.url, fn, pb)
shutil.rmtree(path, ignore_errors=True)
os.makedirs(path)
f_remote = requests.get(self.url, stream=True)
assert f_remote.status_code == 200, 'fail to open {}'.format(self.url)
with open(fn, 'wb') as writer:
for chunk in tqdm(f_remote.iter_content(chunk_size=1024*1024*3)):
writer.write(chunk)
print('Download finished. Unzipping the file...') print('Download finished. Unzipping the file...')
with zipfile.ZipFile(fn) as zf: with zipfile.ZipFile(fn) as zf:
zf.extractall(path) zf.extractall(path)
print('Unzip finished.') print('Unzip finished.')
self.fn = fn
import numpy as np import numpy as np
import torch
import torchvision
from torch.autograd import Variable
import random import random
import time import time
import tqdm
import dgl
import sys
import os
Metapath = "Conference-Paper-Author-Paper-Conference"
num_walks_per_node = 1000 num_walks_per_node = 1000
walk_length = 100 walk_length = 100
path = sys.argv[1]
#construct mapping from text, could be changed to DGL later def construct_graph():
def construct_id_dict(): paper_ids = []
id_to_paper = {} paper_names = []
id_to_author = {} author_ids = []
id_to_conf = {} author_names = []
f_3 = open(".../id_author.txt", encoding="ISO-8859-1") conf_ids = []
f_4 = open(".../id_conf.txt", encoding="ISO-8859-1") conf_names = []
f_5 = open(".../paper.txt", encoding="ISO-8859-1") f_3 = open(os.path.join(path, "id_author.txt"), encoding="ISO-8859-1")
f_4 = open(os.path.join(path, "id_conf.txt"), encoding="ISO-8859-1")
f_5 = open(os.path.join(path, "paper.txt"), encoding="ISO-8859-1")
while True: while True:
z = f_3.readline() z = f_3.readline()
if not z: if not z:
break break
z = z.split('\t') z = z.strip().split()
identity = int(z[0]) identity = int(z[0])
id_to_author[identity] = z[1].strip("\n") author_ids.append(identity)
author_names.append(z[1])
while True: while True:
w = f_4.readline() w = f_4.readline()
if not w: if not w:
break; break;
w = w.split('\t') w = w.strip().split()
identity = int(w[0]) identity = int(w[0])
id_to_conf[identity] = w[1].strip("\n") conf_ids.append(identity)
conf_names.append(w[1])
while True: while True:
v = f_5.readline() v = f_5.readline()
if not v: if not v:
break; break;
v = v.split(' ') v = v.strip().split()
identity = int(v[0]) identity = int(v[0])
paper_name = "" paper_name = 'p' + ''.join(v[1:])
for s in range(5, len(v)): paper_ids.append(identity)
paper_name += v[s] paper_names.append(paper_name)
paper_name = 'p' + paper_name
id_to_paper[identity] = paper_name.strip('\n')
f_3.close() f_3.close()
f_4.close() f_4.close()
f_5.close() f_5.close()
return id_to_paper, id_to_author, id_to_conf
#construct mapping from text, could be changed to DGL later author_ids_invmap = {x: i for i, x in enumerate(author_ids)}
def construct_types_mappings(): conf_ids_invmap = {x: i for i, x in enumerate(conf_ids)}
paper_to_author = {} paper_ids_invmap = {x: i for i, x in enumerate(paper_ids)}
author_to_paper = {}
paper_to_conf = {} paper_author_src = []
conf_to_paper = {} paper_author_dst = []
f_1 = open(".../paper_author.txt", "r") paper_conf_src = []
f_2 = open(".../paper_conf.txt", "r") paper_conf_dst = []
f_1 = open(os.path.join(path, "paper_author.txt"), "r")
f_2 = open(os.path.join(path, "paper_conf.txt"), "r")
for x in f_1: for x in f_1:
x = x.split('\t') x = x.split('\t')
x[0] = int(x[0]) x[0] = int(x[0])
x[1] = int(x[1].strip('\n')) x[1] = int(x[1].strip('\n'))
if x[0] in paper_to_author: paper_author_src.append(paper_ids_invmap[x[0]])
paper_to_author[x[0]].append(x[1]) paper_author_dst.append(author_ids_invmap[x[1]])
else:
paper_to_author[x[0]] = []
paper_to_author[x[0]].append(x[1])
if x[1] in author_to_paper:
author_to_paper[x[1]].append(x[0])
else:
author_to_paper[x[1]] = []
author_to_paper[x[1]].append(x[0])
for y in f_2: for y in f_2:
y = y.split('\t') y = y.split('\t')
y[0] = int(y[0]) y[0] = int(y[0])
y[1] = int(y[1].strip('\n')) y[1] = int(y[1].strip('\n'))
if y[0] in paper_to_conf: paper_conf_src.append(paper_ids_invmap[y[0]])
paper_to_conf[y[0]].append(y[1]) paper_conf_dst.append(conf_ids_invmap[y[1]])
else:
paper_to_conf[y[0]] = []
paper_to_conf[y[0]].append(y[1])
if y[1] in conf_to_paper:
conf_to_paper[y[1]].append(y[0])
else:
conf_to_paper[y[1]] = []
conf_to_paper[y[1]].append(y[0])
f_1.close() f_1.close()
f_2.close() f_2.close()
return paper_to_author, author_to_paper, paper_to_conf, conf_to_paper
pa = dgl.bipartite((paper_author_src, paper_author_dst), 'paper', 'pa', 'author')
ap = dgl.bipartite((paper_author_dst, paper_author_src), 'author', 'ap', 'paper')
pc = dgl.bipartite((paper_conf_src, paper_conf_dst), 'paper', 'pc', 'conf')
cp = dgl.bipartite((paper_conf_dst, paper_conf_src), 'conf', 'cp', 'paper')
hg = dgl.hetero_from_relations([pa, ap, pc, cp])
return hg, author_names, conf_names, paper_names
#"conference - paper - Author - paper - conference" metapath sampling #"conference - paper - Author - paper - conference" metapath sampling
def generate_metapath(): def generate_metapath():
output_path = open(".../output_path.txt", "w") output_path = open(os.path.join(path, "output_path.txt"), "w")
id_to_paper, id_to_author, id_to_conf = construct_id_dict()
paper_to_author, author_to_paper, paper_to_conf, conf_to_paper = construct_types_mappings()
count = 0 count = 0
#loop all conferences
for conf_id in conf_to_paper.keys(): hg, author_names, conf_names, paper_names = construct_graph()
start_time = time.time()
print("sampling" + str(count)) for conf_idx in tqdm.trange(hg.number_of_nodes('conf')):
conf = id_to_conf[conf_id] traces = dgl.contrib.sampling.metapath_random_walk(
conf0 = conf hg, ['cp', 'pa', 'ap', 'pc'] * walk_length, [conf_idx], num_walks_per_node)
#for each conference, simulate num_walks_per_node walks traces = traces[0]
for i in range(num_walks_per_node): for trace in traces:
outline = conf0 tr = np.insert(trace.numpy(), 0, conf_idx)
# each walk with length walk_length outline = ' '.join(
for j in range(walk_length): (conf_names if i % 4 == 0 else author_names)[tr[i]]
# C - P for i in range(0, len(tr), 2)) # skip paper
paper_list_1 = conf_to_paper[conf_id] print(outline, file=output_path)
# check whether the paper nodes link to any author node
connections_1 = False
available_paper_1 = []
for k in range(len(paper_list_1)):
if paper_list_1[k] in paper_to_author:
available_paper_1.append(paper_list_1[k])
num_p_1 = len(available_paper_1)
if num_p_1 != 0:
connections_1 = True
paper_1_index = random.randrange(num_p_1)
#paper_id_1 = paper_list_1[paper_1_index]
paper_id_1 = available_paper_1[paper_1_index]
paper_1 = id_to_paper[paper_id_1]
outline += " " + paper_1
else:
break
# C - P - A
author_list = paper_to_author[paper_id_1]
num_a = len(author_list)
# No need to check
author_index = random.randrange(num_a)
author_id = author_list[author_index]
author = id_to_author[author_id]
outline += " " + author
# C - P - A - P
paper_list_2 = author_to_paper[author_id]
#check whether paper node links to any conference node
connections_2 = False
available_paper_2 = []
for m in range(len(paper_list_2)):
if paper_list_2[m] in paper_to_conf:
available_paper_2.append(paper_list_2[m])
num_p_2 = len(available_paper_2)
if num_p_2 != 0:
connections_2 = True
paper_2_index = random.randrange(num_p_2)
paper_id_2 = available_paper_2[paper_2_index]
paper_2 = id_to_paper[paper_id_2]
outline += " " + paper_2
else:
break
# C - P - A - P - C
conf_list = paper_to_conf[paper_id_2]
num_c = len(conf_list)
conf_index = random.randrange(num_c)
conf_id = conf_list[conf_index]
conf = id_to_conf[conf_id]
outline += " " + conf
if connections_1 and connections_2:
output_path.write(outline + "\n")
else:
break
# Note that the original mapping text has type indicator in front of each node just like "cVLDB"
# So the sampling sequence looks like "cconference ppaper aauthor ppaper cconference"
count += 1
print("--- %s seconds ---" % (time.time() - start_time))
output_path.close() output_path.close()
......
...@@ -17,15 +17,6 @@ namespace dgl { ...@@ -17,15 +17,6 @@ namespace dgl {
class ImmutableGraph; class ImmutableGraph;
struct RandomWalkTraces {
/*! \brief number of traces generated for each seed */
IdArray trace_counts;
/*! \brief length of each trace, concatenated */
IdArray trace_lengths;
/*! \brief the vertices, concatenated */
IdArray vertices;
};
class SamplerOp { class SamplerOp {
public: public:
/*! /*!
...@@ -65,76 +56,6 @@ class SamplerOp { ...@@ -65,76 +56,6 @@ class SamplerOp {
IdArray layer_sizes); IdArray layer_sizes);
}; };
/*!
* \brief Batch-generate random walk traces
* \param seeds The array of starting vertex IDs
* \param num_traces The number of traces to generate for each seed
* \param num_hops The number of hops for each trace
* \return a flat ID array with shape (num_seeds, num_traces, num_hops + 1)
*/
IdArray RandomWalk(const GraphInterface *gptr,
IdArray seeds,
int num_traces,
int num_hops);
/*!
* \brief Batch-generate random walk traces with restart
*
* Stop generating traces if max_frequrent_visited_nodes nodes are visited more than
* max_visit_counts times.
*
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed if no less
* than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
* times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces instance.
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTraces RandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
/*
* \brief Batch-generate random walk traces with restart on a bipartite graph, walking two
* hops at a time.
*
* Since it is walking on a bipartite graph, the vertices of a trace will always stay on the
* same side.
*
* Stop generating traces if max_frequrent_visited_nodes nodes are visited more than
* max_visit_counts times.
*
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed if no less
* than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
* times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces instance.
*
* \note Doesn't verify whether the graph is indeed a bipartite graph
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
} // namespace dgl } // namespace dgl
#endif // DGL_SAMPLER_H_ #endif // DGL_SAMPLER_H_
import numpy as np
from ... import utils from ... import utils
from ... import backend as F from ... import backend as F
from ..._ffi.function import _init_api from ..._ffi.function import _init_api
from ..._ffi.object import register_object, ObjectBase
from ... import ndarray
__all__ = ['random_walk', __all__ = ['random_walk',
'random_walk_with_restart', 'random_walk_with_restart',
'bipartite_single_sided_random_walk_with_restart', 'bipartite_single_sided_random_walk_with_restart',
'metapath_random_walk',
] ]
@register_object('sampler.RandomWalkTraces')
class RandomWalkTraces(ObjectBase):
pass
def random_walk(g, seeds, num_traces, num_hops): def random_walk(g, seeds, num_traces, num_hops):
"""Batch-generate random walk traces on given graph with the same length. """Batch-generate random walk traces on given graph with the same length.
...@@ -46,28 +53,25 @@ def _split_traces(traces): ...@@ -46,28 +53,25 @@ def _split_traces(traces):
Parameters Parameters
---------- ----------
traces : PackedFunc object of RandomWalkTraces structure traces : RandomWalkTraces
Returns Returns
------- -------
traces : list[list[Tensor]] traces : list[list[Tensor]]
traces[i][j] is the j-th trace generated for i-th seed. traces[i][j] is the j-th trace generated for i-th seed.
""" """
trace_counts = F.zerocopy_to_numpy( trace_counts = traces.trace_counts.asnumpy().tolist()
F.zerocopy_from_dlpack(traces(0).to_dlpack())).tolist() trace_vertices = F.zerocopy_from_dgl_ndarray(traces.vertices)
trace_lengths = F.zerocopy_from_dlpack(traces(1).to_dlpack())
trace_vertices = F.zerocopy_from_dlpack(traces(2).to_dlpack())
trace_vertices = F.split( trace_vertices = F.split(
trace_vertices, F.zerocopy_to_numpy(trace_lengths).tolist(), 0) trace_vertices, traces.trace_lengths.asnumpy().tolist(), 0)
traces = [] results = []
s = 0 s = 0
for c in trace_counts: for c in trace_counts:
traces.append(trace_vertices[s:s+c]) results.append(trace_vertices[s:s+c])
s += c s += c
return traces return results
def random_walk_with_restart( def random_walk_with_restart(
...@@ -165,4 +169,49 @@ def bipartite_single_sided_random_walk_with_restart( ...@@ -165,4 +169,49 @@ def bipartite_single_sided_random_walk_with_restart(
int(max_visit_counts), int(max_frequent_visited_nodes)) int(max_visit_counts), int(max_frequent_visited_nodes))
return _split_traces(traces) return _split_traces(traces)
_init_api('dgl.randomwalk', __name__)
def metapath_random_walk(hg, etypes, seeds, num_traces):
    """Generate random walk traces from an array of seed nodes (or starting nodes),
    based on the given metapath.

    For a single seed node, ``num_traces`` traces would be generated.  A trace would

    1. Start from the given seed and set ``t`` to 0.
    2. Pick and traverse along edge type ``etypes[t]`` from the current node.
    3. If no edge can be found, or ``t`` reached the end of ``etypes``, halt.
       Otherwise, increment ``t`` and go to step 2.

    The metapath is traversed once per trace.  To walk along a metapath several
    times, repeat it in ``etypes`` (e.g. ``['ab', 'ba'] * 4``).

    Parameters
    ----------
    hg : DGLHeteroGraph
        The heterogeneous graph.
    etypes : list[str or tuple of str]
        Metapath, specified as a list of edge types.
        The beginning and ending node type must be the same, and the destination
        node type of every edge type must be the source node type of the next one.
    seeds : Tensor
        The seed nodes.  Node type is the same as the beginning node type of metapath.
    num_traces : int
        The number of traces

    Returns
    -------
    traces : list[list[Tensor]]
        traces[i][j] is the j-th trace generated for i-th seed.
        traces[i][j][k] would have node type the same as the destination node type of edge
        type ``etypes[k]``.  A trace is shorter than ``len(etypes)`` if the walk
        reached a node without any outgoing edge of the required type.

    Raises
    ------
    ValueError
        If the metapath is empty, does not start and end at the same node type,
        or contains two consecutive edge types that do not connect.

    Notes
    -----
    The traces do **not** include the seed nodes themselves.
    """
    if len(etypes) == 0:
        raise ValueError('empty metapath')

    # Canonical form is (source node type, edge type, destination node type).
    canonical = [hg.to_canonical_etype(et) for et in etypes]
    if canonical[0][0] != canonical[-1][2]:
        raise ValueError('beginning and ending node type mismatch')
    # The backend feeds the node reached by one edge type straight into the
    # next one, so a disconnected metapath would look up vertex IDs under the
    # wrong node type.  Reject it up front.
    for prev, curr in zip(canonical, canonical[1:]):
        if prev[2] != curr[0]:
            raise ValueError(
                'metapath is not connected: destination node type of edge type {} '
                'does not match source node type of edge type {}'.format(prev, curr))

    if len(seeds) == 0:
        return []

    etype_array = ndarray.array(np.array([hg.get_etype_id(et) for et in etypes], dtype='int64'))
    seed_array = utils.toindex(seeds).todgltensor()
    traces = _CAPI_DGLMetapathRandomWalk(hg._graph, etype_array, seed_array, num_traces)
    return _split_traces(traces)
_init_api('dgl.sampler.randomwalk', __name__)
...@@ -76,8 +76,8 @@ class HeteroGraphIndex(ObjectBase): ...@@ -76,8 +76,8 @@ class HeteroGraphIndex(ObjectBase):
Returns Returns
------- -------
HeteroGraphIndex FlattenedHeteroGraph
The unitgraph graph. A flattened heterograph object
""" """
return _CAPI_DGLHeteroGetFlattenedGraph(self, etypes) return _CAPI_DGLHeteroGetFlattenedGraph(self, etypes)
...@@ -1006,4 +1006,8 @@ def create_heterograph_from_relations(metagraph, rel_graphs): ...@@ -1006,4 +1006,8 @@ def create_heterograph_from_relations(metagraph, rel_graphs):
""" """
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs) return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
@register_object("graph.FlattenedHeteroGraph")
class FlattenedHeteroGraph(ObjectBase):
"""FlattenedHeteroGraph object class in C++ backend."""
_init_api("dgl.heterograph_index") _init_api("dgl.heterograph_index")
/*!
* Copyright (c) 2019 by Contributors
* \file graph/sampler/metapath.cc
* \brief Metapath sampling
*/
#include <dgl/array.h>
#include <dgl/random.h>
#include <dgl/packed_func_ext.h>
#include <memory>
#include <vector>
#include "../../c_api_common.h"
#include "randomwalk.h"
using namespace dgl::runtime;
using namespace dgl::aten;
namespace dgl {
namespace sampling {
namespace {
/*!
 * \brief Random walk based on the given metapath.
 *
 * For every seed, \c num_traces traces are generated.  Each trace starts at the
 * seed and follows the edge types in \c etypes in order, picking one successor
 * at random per step.  A trace stops early when the current vertex has no
 * successor along the required edge type, so traces may be shorter than the
 * metapath.
 *
 * \param hg The heterograph
 * \param etypes The metapath as an array of edge type IDs
 * \param seeds The array of starting vertices for random walks
 * \param num_traces Number of traces to generate for each starting vertex
 * \return The generated traces: \c trace_counts has one entry per seed,
 *         \c trace_lengths one entry per trace, and \c vertices holds the visited
 *         vertices of all traces concatenated.  Seeds themselves are not stored.
 * \note The metapath should have the same starting and ending node type.
 * \note NOTE(review): the arrays are reinterpreted as 64-bit IDs without a dtype
 *       check; the Python wrapper builds \c etypes as int64 -- confirm for seeds.
 */
RandomWalkTracesPtr MetapathRandomWalk(
    const HeteroGraphPtr hg,
    const IdArray etypes,
    const IdArray seeds,
    int num_traces) {
  const uint64_t num_etypes = etypes->shape[0];
  const uint64_t num_seeds = seeds->shape[0];
  const dgl_type_t *etype_data = static_cast<dgl_type_t *>(etypes->data);
  const dgl_id_t *seed_data = static_cast<dgl_id_t *>(seeds->data);

  std::vector<dgl_id_t> vertices;
  std::vector<size_t> trace_lengths, trace_counts;
  // Exactly one count is recorded per seed.
  trace_counts.reserve(num_seeds);

  // TODO(quan): use omp to parallelize this loop
  for (uint64_t seed_id = 0; seed_id < num_seeds; ++seed_id) {
    int curr_num_traces = 0;

    for (; curr_num_traces < num_traces; ++curr_num_traces) {
      dgl_id_t curr = seed_data[seed_id];
      size_t trace_length = 0;

      for (size_t i = 0; i < num_etypes; ++i) {
        const auto &succ = hg->SuccVec(etype_data[i], curr);
        // Dead end: no outgoing edge of this type, so the trace ends here.
        if (succ.size() == 0)
          break;
        curr = succ[RandomEngine::ThreadLocal()->RandInt(succ.size())];
        vertices.push_back(curr);
        ++trace_length;
      }

      trace_lengths.push_back(trace_length);
    }

    trace_counts.push_back(curr_num_traces);
  }

  // Own the result through a shared_ptr from the start so it cannot leak if
  // one of the array conversions below throws.
  auto traces = std::make_shared<RandomWalkTraces>();
  traces->vertices = VecToIdArray(vertices);
  traces->trace_lengths = VecToIdArray(trace_lengths);
  traces->trace_counts = VecToIdArray(trace_counts);
  return traces;
}
}; // namespace
// C API entry point for metapath random walk.
// Arguments, in order: heterograph reference, edge type ID array (the metapath),
// seed vertex ID array, number of traces per seed.
// Returns the traces wrapped in a RandomWalkTracesRef.
DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
// Each argument is converted to its declared type on assignment.
const HeteroGraphRef hg = args[0];
const IdArray etypes = args[1];
const IdArray seeds = args[2];
int num_traces = args[3];
const auto tl = MetapathRandomWalk(hg.sptr(), etypes, seeds, num_traces);
// Wrap the shared pointer in an object reference so it can be returned.
*rv = RandomWalkTracesRef(tl);
});
}; // namespace sampling
}; // namespace dgl
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
* \brief DGL sampler implementation * \brief DGL sampler implementation
*/ */
#include <dgl/sampler.h>
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
...@@ -14,12 +13,16 @@ ...@@ -14,12 +13,16 @@
#include <cmath> #include <cmath>
#include <numeric> #include <numeric>
#include <functional> #include <functional>
#include "../c_api_common.h" #include <vector>
#include "randomwalk.h"
#include "../../c_api_common.h"
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace sampling {
using Walker = std::function<dgl_id_t(const GraphInterface *, dgl_id_t)>; using Walker = std::function<dgl_id_t(const GraphInterface *, dgl_id_t)>;
namespace { namespace {
...@@ -92,7 +95,7 @@ IdArray GenericRandomWalk( ...@@ -92,7 +95,7 @@ IdArray GenericRandomWalk(
return traces; return traces;
} }
RandomWalkTraces GenericRandomWalkWithRestart( RandomWalkTracesPtr GenericRandomWalkWithRestart(
const GraphInterface *gptr, const GraphInterface *gptr,
IdArray seeds, IdArray seeds,
double restart_prob, double restart_prob,
...@@ -144,38 +147,33 @@ RandomWalkTraces GenericRandomWalkWithRestart( ...@@ -144,38 +147,33 @@ RandomWalkTraces GenericRandomWalkWithRestart(
trace_counts.push_back(num_traces); trace_counts.push_back(num_traces);
} }
RandomWalkTraces traces; RandomWalkTraces *traces = new RandomWalkTraces;
traces.trace_counts = IdArray::Empty( traces->trace_counts = IdArray::Empty(
{static_cast<int64_t>(trace_counts.size())}, {static_cast<int64_t>(trace_counts.size())},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}); DLContext{kDLCPU, 0});
traces.trace_lengths = IdArray::Empty( traces->trace_lengths = IdArray::Empty(
{static_cast<int64_t>(trace_lengths.size())}, {static_cast<int64_t>(trace_lengths.size())},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}); DLContext{kDLCPU, 0});
traces.vertices = IdArray::Empty( traces->vertices = IdArray::Empty(
{static_cast<int64_t>(vertices.size())}, {static_cast<int64_t>(vertices.size())},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}); DLContext{kDLCPU, 0});
dgl_id_t *trace_counts_data = static_cast<dgl_id_t *>(traces.trace_counts->data); dgl_id_t *trace_counts_data = static_cast<dgl_id_t *>(traces->trace_counts->data);
dgl_id_t *trace_lengths_data = static_cast<dgl_id_t *>(traces.trace_lengths->data); dgl_id_t *trace_lengths_data = static_cast<dgl_id_t *>(traces->trace_lengths->data);
dgl_id_t *vertices_data = static_cast<dgl_id_t *>(traces.vertices->data); dgl_id_t *vertices_data = static_cast<dgl_id_t *>(traces->vertices->data);
std::copy(trace_counts.begin(), trace_counts.end(), trace_counts_data); std::copy(trace_counts.begin(), trace_counts.end(), trace_counts_data);
std::copy(trace_lengths.begin(), trace_lengths.end(), trace_lengths_data); std::copy(trace_lengths.begin(), trace_lengths.end(), trace_lengths_data);
std::copy(vertices.begin(), vertices.end(), vertices_data); std::copy(vertices.begin(), vertices.end(), vertices_data);
return traces; return RandomWalkTracesPtr(traces);
} }
}; // namespace }; // namespace
PackedFunc ConvertRandomWalkTracesToPackedFunc(const RandomWalkTraces &t) {
return ConvertNDArrayVectorToPackedFunc({
t.trace_counts, t.trace_lengths, t.vertices});
}
IdArray RandomWalk( IdArray RandomWalk(
const GraphInterface *gptr, const GraphInterface *gptr,
IdArray seeds, IdArray seeds,
...@@ -184,7 +182,7 @@ IdArray RandomWalk( ...@@ -184,7 +182,7 @@ IdArray RandomWalk(
return GenericRandomWalk(gptr, seeds, num_traces, num_hops, WalkMultipleHops<1>); return GenericRandomWalk(gptr, seeds, num_traces, num_hops, WalkMultipleHops<1>);
} }
RandomWalkTraces RandomWalkWithRestart( RandomWalkTracesPtr RandomWalkWithRestart(
const GraphInterface *gptr, const GraphInterface *gptr,
IdArray seeds, IdArray seeds,
double restart_prob, double restart_prob,
...@@ -196,7 +194,7 @@ RandomWalkTraces RandomWalkWithRestart( ...@@ -196,7 +194,7 @@ RandomWalkTraces RandomWalkWithRestart(
max_frequent_visited_nodes, WalkMultipleHops<1>); max_frequent_visited_nodes, WalkMultipleHops<1>);
} }
RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart( RandomWalkTracesPtr BipartiteSingleSidedRandomWalkWithRestart(
const GraphInterface *gptr, const GraphInterface *gptr,
IdArray seeds, IdArray seeds,
double restart_prob, double restart_prob,
...@@ -208,7 +206,7 @@ RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart( ...@@ -208,7 +206,7 @@ RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
max_frequent_visited_nodes, WalkMultipleHops<2>); max_frequent_visited_nodes, WalkMultipleHops<2>);
} }
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk") DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray seeds = args[1]; const IdArray seeds = args[1];
...@@ -218,7 +216,7 @@ DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk") ...@@ -218,7 +216,7 @@ DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalk")
*rv = RandomWalk(g.sptr().get(), seeds, num_traces, num_hops); *rv = RandomWalk(g.sptr().get(), seeds, num_traces, num_hops);
}); });
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart") DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray seeds = args[1]; const IdArray seeds = args[1];
...@@ -227,12 +225,12 @@ DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart") ...@@ -227,12 +225,12 @@ DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLRandomWalkWithRestart")
const uint64_t max_visit_counts = args[4]; const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5]; const uint64_t max_frequent_visited_nodes = args[5];
*rv = ConvertRandomWalkTracesToPackedFunc( *rv = RandomWalkTracesRef(
RandomWalkWithRestart(g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed, RandomWalkWithRestart(g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes)); max_visit_counts, max_frequent_visited_nodes));
}); });
DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart") DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray seeds = args[1]; const IdArray seeds = args[1];
...@@ -241,10 +239,12 @@ DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithResta ...@@ -241,10 +239,12 @@ DGL_REGISTER_GLOBAL("randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkWithResta
const uint64_t max_visit_counts = args[4]; const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5]; const uint64_t max_frequent_visited_nodes = args[5];
*rv = ConvertRandomWalkTracesToPackedFunc( *rv = RandomWalkTracesRef(
BipartiteSingleSidedRandomWalkWithRestart( BipartiteSingleSidedRandomWalkWithRestart(
g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed, g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes)); max_visit_counts, max_frequent_visited_nodes));
}); });
}; // namespace sampling
}; // namespace dgl }; // namespace dgl
/*!
* Copyright (c) 2018 by Contributors
* \file graph/sampler/randomwalk.h
* \brief DGL random walk sampler header.
*/
#ifndef DGL_GRAPH_SAMPLER_RANDOMWALK_H_
#define DGL_GRAPH_SAMPLER_RANDOMWALK_H_
#include <dgl/runtime/object.h>
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <memory>
namespace dgl {
namespace sampling {
/*!
* \brief Structure of multiple random walk traces
*
* The traces of all seeds are stored flattened in three arrays; \c trace_counts
* and \c trace_lengths describe how to split \c vertices back into traces.
*/
struct RandomWalkTraces : public runtime::Object {
/*! \brief number of traces generated for each seed (one entry per seed) */
IdArray trace_counts;
/*! \brief length of each trace, concatenated (one entry per trace) */
IdArray trace_lengths;
/*! \brief the vertices of all traces, concatenated */
IdArray vertices;
/*! \brief Expose the three arrays as attributes named after the fields */
void VisitAttrs(runtime::AttrVisitor *v) final {
v->Visit("vertices", &vertices);
v->Visit("trace_lengths", &trace_lengths);
v->Visit("trace_counts", &trace_counts);
}
// Type key; the Python class is registered under the same string.
static constexpr const char *_type_key = "sampler.RandomWalkTraces";
DGL_DECLARE_OBJECT_TYPE_INFO(RandomWalkTraces, runtime::Object);
};
/*! \brief Shared pointer to a RandomWalkTraces object */
typedef std::shared_ptr<RandomWalkTraces> RandomWalkTracesPtr;
/*! \brief Object reference type for RandomWalkTraces */
DGL_DEFINE_OBJECT_REF(RandomWalkTracesRef, RandomWalkTraces);
/*!
* \brief Batch-generate random walk traces
* \param gptr The graph to walk on
* \param seeds The array of starting vertex IDs
* \param num_traces The number of traces to generate for each seed
* \param num_hops The number of hops for each trace
* \return a flat ID array with shape (num_seeds, num_traces, num_hops + 1)
*/
IdArray RandomWalk(const GraphInterface *gptr,
IdArray seeds,
int num_traces,
int num_hops);
/*!
* \brief Batch-generate random walk traces with restart
*
* Stop generating traces if max_frequent_visited_nodes nodes are visited more than
* max_visit_counts times.
*
* \param gptr The graph to walk on
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed if no less
* than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
* times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces pointer.
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTracesPtr RandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
/*!
* \brief Batch-generate random walk traces with restart on a bipartite graph, walking two
* hops at a time.
*
* Since it is walking on a bipartite graph, the vertices of a trace will always stay on the
* same side.
*
* Stop generating traces if max_frequent_visited_nodes nodes are visited more than
* max_visit_counts times.
*
* \param gptr The graph to walk on
* \param seeds The array of starting vertex IDs
* \param restart_prob The restart probability
* \param visit_threshold_per_seed Stop generating more traces once the number of nodes
* visited for a seed exceeds this number. (Algorithm 1 in [1])
* \param max_visit_counts Alternatively, stop generating traces for a seed if no less
* than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
* times. (Algorithm 2 in [1])
* \param max_frequent_visited_nodes See \c max_visit_counts
* \return A RandomWalkTraces pointer.
*
* \note Doesn't verify whether the graph is indeed a bipartite graph
*
* \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
*/
RandomWalkTracesPtr BipartiteSingleSidedRandomWalkWithRestart(
const GraphInterface *gptr,
IdArray seeds,
double restart_prob,
uint64_t visit_threshold_per_seed,
uint64_t max_visit_counts,
uint64_t max_frequent_visited_nodes);
}; // namespace sampling
}; // namespace dgl
#endif // DGL_GRAPH_SAMPLER_RANDOMWALK_H_
...@@ -58,5 +58,21 @@ def test_random_walk_with_restart(): ...@@ -58,5 +58,21 @@ def test_random_walk_with_restart():
trace_diff = np.diff(F.zerocopy_to_numpy(t), axis=-1) trace_diff = np.diff(F.zerocopy_to_numpy(t), axis=-1)
assert (trace_diff % 2 == 0).all() assert (trace_diff % 2 == 0).all()
def test_metapath_random_walk():
    """Check that metapath random walk traces follow real edges.

    The graph has two relations: ``ab`` maps a_i to b_i, and ``ba`` gives every
    b node two successors.  Every node has an outgoing edge of the required
    type, so each trace must cover the full metapath of 8 steps.
    """
    g1 = dgl.bipartite(([0, 1, 2, 3], [0, 1, 2, 3]), 'a', 'ab', 'b')
    g2 = dgl.bipartite(([0, 0, 1, 1, 2, 2, 3, 3], [1, 3, 2, 0, 3, 1, 0, 2]), 'b', 'ba', 'a')
    G = dgl.hetero_from_relations([g1, g2])
    seeds = [0, 1]
    traces = dgl.contrib.sampling.metapath_random_walk(G, ['ab', 'ba'] * 4, seeds, 3)
    # zip() below would silently drop seeds if fewer results came back.
    assert len(traces) == len(seeds)
    for seed, traces_per_seed in zip(seeds, traces):
        assert len(traces_per_seed) == 3
        for trace in traces_per_seed:
            assert len(trace) == 8
            # Traces exclude the seed, so prepend it before checking edges.
            trace = np.insert(F.asnumpy(trace), 0, seed)
            for i in range(4):
                assert g1.has_edge_between(trace[2 * i], trace[2 * i + 1])
                assert g2.has_edge_between(trace[2 * i + 1], trace[2 * i + 2])
if __name__ == '__main__': if __name__ == '__main__':
test_random_walk() test_random_walk()
test_metapath_random_walk()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment