Unverified Commit 7e6a6b4a authored by Dewvin, committed by GitHub

[Feature] Add CUDA Weighted Randomwalk Sampling (#4243)



* [Feature] Add CUDA Weighted Randomwalk Sampling


* fix empty prob array && enable non-uniform for restart && enable unit tests

* update doc and guide for randomwalk and pinsage

* update comments
Co-authored-by: zhenliangqiu <ubuntu@ip-172-31-24-245.ap-southeast-1.compute.internal>
Co-authored-by: xiny <xiny@nvidia.com>
parent 7cd531c4
@@ -99,6 +99,38 @@ especially for multi-GPU training.
Refer to our `GraphSAGE example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/multi_gpu_node_classification.py>`_ for more details.
UVA and GPU support for PinSAGESampler/RandomWalkNeighborSampler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
PinSAGESampler and RandomWalkNeighborSampler support UVA and GPU sampling.
You can enable them as follows:

* Pin the graph (for UVA sampling) or put the graph onto GPU (for GPU sampling).
* Put the ``train_nid`` onto GPU.

.. code:: python

    g = dgl.heterograph({
        ('item', 'bought-by', 'user'): ([0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3]),
        ('user', 'bought', 'item'): ([0, 1, 0, 1, 2, 3, 2, 3], [0, 0, 1, 1, 2, 2, 3, 3])})

    # UVA setup
    # g.create_formats_()
    # g.pin_memory_()

    # GPU setup
    device = torch.device('cuda:0')
    g = g.to(device)

    sampler1 = dgl.sampling.PinSAGESampler(g, 'item', 'user', 4, 0.5, 3, 2)
    sampler2 = dgl.sampling.RandomWalkNeighborSampler(g, 4, 0.5, 3, 2, ['bought-by', 'bought'])

    train_nid = torch.tensor([0, 2], dtype=g.idtype, device=device)

    sampler1(train_nid)
    sampler2(train_nid)
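Both samplers return a homogeneous frontier graph on the same device as the
input graph and seeds. A minimal sketch of inspecting the result, assuming the
samplers' default edge-weight feature name ``'weights'``:

.. code:: python

    frontier = sampler1(train_nid)
    # The frontier stays on the GPU, so it can feed GPU training directly.
    print(frontier.device)
    # The number of random-walk visits backing each sampled edge.
    print(frontier.edata['weights'])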
Using GPU-based neighbor sampling with DGL functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -109,6 +141,8 @@ operating on GPU:
* Only has support for uniform sampling; non-uniform sampling can only run on CPU.
* :func:`dgl.sampling.random_walk`

Subgraph extraction ops:
* :func:`dgl.node_subgraph`
...
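With :func:`dgl.sampling.random_walk` on this list, weighted random walks can now run
entirely on GPU. A minimal sketch, assuming a homogeneous graph and a hypothetical
edge-weight feature named ``'p'``:

.. code:: python

    import dgl
    import torch

    device = torch.device('cuda:0')
    g = dgl.graph(([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])).to(device)
    # Hypothetical positive per-edge transition weights.
    g.edata['p'] = torch.rand(g.num_edges(), device=device)

    nodes = torch.tensor([0, 1, 2], device=device)
    traces, types = dgl.sampling.random_walk(g, nodes, length=4, prob='p')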
@@ -38,6 +38,9 @@ class RandomWalkNeighborSampler(object):
    This is a generalization of PinSAGE sampler which only works on bidirectional bipartite
    graphs.

    UVA and GPU sampling is supported for this sampler.
    Refer to :ref:`guide-minibatch-gpu-sampling` for more details.

    Parameters
    ----------
    G : DGLGraph
@@ -104,13 +107,14 @@ class RandomWalkNeighborSampler(object):
            A tensor of given node IDs of node type ``ntype`` to generate neighbors from. The
            node type ``ntype`` is the beginning and ending node type of the given metapath.

            It must be on the same device as the graph and have the same dtype
            as the ID type of the graph.

        Returns
        -------
        g : DGLGraph
            A homogeneous graph constructed by selecting neighbors for each given node according
            to the algorithm above.
        """
        seed_nodes = utils.prepare_tensor(self.G, seed_nodes, 'seed_nodes')
        self.restart_prob = F.copy_to(self.restart_prob, F.context(seed_nodes))
@@ -147,6 +151,9 @@ class PinSAGESampler(RandomWalkNeighborSampler):
    The edges of the returned homogeneous graph will connect to the given nodes from their most
    commonly visited nodes, with a feature indicating the number of visits.

    UVA and GPU sampling is supported for this sampler.
    Refer to :ref:`guide-minibatch-gpu-sampling` for more details.

    Parameters
    ----------
    G : DGLGraph
...
@@ -30,6 +30,8 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
    If a random walk stops in advance, DGL pads the trace with -1 to have the same
    length.

    This function supports graphs on GPU.

    Parameters
    ----------
    g : DGLGraph
@@ -37,7 +39,7 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
    nodes : Tensor
        Node ID tensor from which the random walk traces start.
        The tensor must be on the same device as the graph and have the same dtype as the ID type
        of the graph.
    metapath : list[str or tuple of str], optional
        Metapath, specified as a list of edge types.
@@ -60,12 +62,14 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
        must be positive for the outbound edges of all nodes (although they don't have
        to sum up to one). The result will be undefined otherwise.

        The feature tensor must be on the same device as the graph.

        If omitted, DGL assumes that the neighbors are picked uniformly.
    restart_prob : float or Tensor, optional
        Probability to terminate the current trace before each transition.
        If a tensor is given, :attr:`restart_prob` should be on the same device as the graph
        and have the same length as :attr:`metapath` or :attr:`length`.
    return_eids : bool, optional
        If True, additionally return the edge IDs traversed.
...
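Putting the device requirements above together, a minimal, self-contained sketch of a
weighted walk with stepwise restart on GPU (the edge-weight name ``'p'`` is illustrative):

.. code:: python

    import dgl
    import torch

    device = torch.device('cuda:0')
    g = dgl.graph(([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])).to(device)
    g.edata['p'] = torch.rand(g.num_edges(), device=device)
    nodes = torch.tensor([0, 1, 2], device=device)

    # restart_prob must match the walk length and live on the graph's device.
    restart = torch.full((4,), 0.2, device=device)
    traces, eids, types = dgl.sampling.random_walk(
        g, nodes, length=4, prob='p', restart_prob=restart, return_eids=True)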
/*!
 * Copyright (c) 2021-2022 by Contributors
 * \file graph/sampling/randomwalk_gpu.cu
 * \brief CUDA random walk sampling
 */
#include <dgl/array.h>
@@ -13,6 +13,7 @@
#include <utility>
#include <tuple>
#include "../../../array/cuda/dgl_cub.cuh"
#include "../../../runtime/cuda/cuda_common.h"
#include "frequency_hashmap.cuh"
@@ -89,6 +90,81 @@ __global__ void _RandomWalkKernel(
  }
}
template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkBiasedKernel(
    const uint64_t rand_seed,
    const IdType *seed_data,
    const int64_t num_seeds,
    const IdType *metapath_data,
    const uint64_t max_num_steps,
    const GraphKernelData<IdType> *graphs,
    const FloatType **probs,
    const FloatType **prob_sums,
    const FloatType *restart_prob_data,
    const int64_t restart_prob_size,
    const int64_t max_nodes,
    IdType *out_traces_data,
    IdType *out_eids_data) {
  assert(BLOCK_SIZE == blockDim.x);
  int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;
  int64_t last_idx = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);
  int64_t trace_length = (max_num_steps + 1);
  curandState rng;
  // reference:
  // https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes
  curand_init(rand_seed + idx, 0, 0, &rng);

  while (idx < last_idx) {
    IdType curr = seed_data[idx];
    assert(curr < max_nodes);
    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
    *(traces_data_ptr++) = curr;

    int64_t step_idx;
    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
      IdType metapath_id = metapath_data[step_idx];
      const GraphKernelData<IdType> &graph = graphs[metapath_id];
      const int64_t in_row_start = graph.in_ptr[curr];
      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
      if (deg == 0) {  // dead end: stop here and pad the rest of the trace below
        break;
      }
      // randomly select the next step by weight
      const FloatType *prob_sum = prob_sums[metapath_id];
      const FloatType *prob = probs[metapath_id];
      int64_t num;
      if (prob == nullptr) {
        // uniform choice among the neighbors
        num = curand(&rng) % deg;
      } else {
        // inverse transform sampling over the unnormalized CDF of the neighbor
        // weights: scan until the running sum passes the uniform draw
        auto rnd_sum_w = prob_sum[curr] * curand_uniform(&rng);
        FloatType sum_w{0.};
        for (num = 0; num < deg; ++num) {
          sum_w += prob[in_row_start + num];
          if (sum_w >= rnd_sum_w) break;
        }
        // guard against floating-point round-off scanning past the last neighbor
        if (num == deg) --num;
      }
      IdType pick = graph.in_cols[in_row_start + num];
      IdType eid = (graph.data ? graph.data[in_row_start + num] : in_row_start + num);
      *traces_data_ptr = pick;
      *eids_data_ptr = eid;
      // stepwise restart: if triggered, the just-written step is discarded
      // (overwritten by the -1 padding below)
      if ((restart_prob_size > 1) && (curand_uniform(&rng) < restart_prob_data[step_idx])) {
        break;
      } else if ((restart_prob_size == 1) && (curand_uniform(&rng) < restart_prob_data[0])) {
        break;
      }
      ++traces_data_ptr; ++eids_data_ptr;
      curr = pick;
    }
    // pad early-stopped traces with -1
    for (; step_idx < max_num_steps; ++step_idx) {
      *(traces_data_ptr++) = -1;
      *(eids_data_ptr++) = -1;
    }
    idx += BLOCK_SIZE;
  }
}
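The weighted pick in the kernel above is inverse transform sampling over an
unnormalized CDF: draw ``r`` uniformly from ``[0, total_weight)`` and scan the
neighbor weights until the running sum passes ``r``. A minimal NumPy sketch of
the same selection rule (illustrative only, not the kernel itself):

.. code:: python

    import numpy as np

    def pick_neighbor(weights, total, rng):
        # Draw a point on the unnormalized CDF, then scan to locate it.
        r = total * rng.random()
        running = 0.0
        for num, w in enumerate(weights):
            running += w
            if running >= r:
                break
        return num

    rng = np.random.default_rng(0)
    weights = np.array([0.5, 1.5, 2.0])  # one node's out-edge weights
    picks = [pick_neighbor(weights, weights.sum(), rng) for _ in range(10000)]
    # Neighbor 2 should be chosen roughly half the time (2.0 / 4.0).
    print(np.bincount(picks) / len(picks))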
}  // namespace

// random walk for uniform choice
@@ -167,6 +243,143 @@ std::pair<IdArray, IdArray> RandomWalkUniform(
  return std::make_pair(traces, eids);
}
/**
 * \brief Random walk for biased choice. We use inverse transform sampling to
 * choose the next step.
 */
template <DLDeviceType XPU, typename FloatType, typename IdType>
std::pair<IdArray, IdArray> RandomWalkBiased(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob) {
  const int64_t max_num_steps = metapath->shape[0];
  const IdType *metapath_data = static_cast<IdType *>(metapath->data);
  const int64_t begin_ntype = hg->meta_graph()->FindEdge(metapath_data[0]).first;
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
  int64_t num_etypes = hg->NumEdgeTypes();

  auto ctx = seeds->ctx;
  const IdType *seed_data = static_cast<const IdType*>(seeds->data);
  CHECK(seeds->ndim == 1) << "seeds shape is not one dimension.";
  const int64_t num_seeds = seeds->shape[0];

  int64_t trace_length = max_num_steps + 1;
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);
  IdType *traces_data = traces.Ptr<IdType>();
  IdType *eids_data = eids.Ptr<IdType>();

  cudaStream_t stream = 0;
  auto device = DeviceAPI::Get(ctx);

  // host-side arrays of per-etype device pointers for weights and weight sums
  assert(num_etypes == static_cast<int64_t>(prob.size()));
  std::unique_ptr<FloatType *[]> probs(new FloatType *[prob.size()]);
  std::unique_ptr<FloatType *[]> prob_sums(new FloatType *[prob.size()]);
  std::vector<FloatArray> prob_sums_arr;
  prob_sums_arr.reserve(prob.size());

  // graphs
  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
    h_graphs[etype].in_ptr = static_cast<const IdType*>(csr.indptr->data);
    h_graphs[etype].in_cols = static_cast<const IdType*>(csr.indices->data);
    h_graphs[etype].data = (CSRHasData(csr) ? static_cast<const IdType*>(csr.data->data) : nullptr);
    int64_t num_segments = csr.indptr->shape[0] - 1;
    // empty prob arrays mean uniform choice and are handled in the kernel
    if (IsNullArray(prob[etype])) {
      probs[etype] = nullptr;
      prob_sums[etype] = nullptr;
      continue;
    }
    probs[etype] = prob[etype].Ptr<FloatType>();
    prob_sums_arr.push_back(FloatArray::Empty({num_segments}, prob[etype]->dtype, ctx));
    prob_sums[etype] = prob_sums_arr.back().Ptr<FloatType>();

    // calculate the sum of the neighbor weights of each node with a
    // segmented reduction over the CSR offsets
    const IdType *d_offsets = static_cast<const IdType*>(csr.indptr->data);
    size_t temp_storage_size = 0;
    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_size,
        probs[etype],
        prob_sums[etype],
        num_segments,
        d_offsets,
        d_offsets + 1));
    void *temp_storage = device->AllocWorkspace(ctx, temp_storage_size);
    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(temp_storage, temp_storage_size,
        probs[etype],
        prob_sums[etype],
        num_segments,
        d_offsets,
        d_offsets + 1));
    device->FreeWorkspace(ctx, temp_storage);
  }

  // copy graph metadata pointers to GPU
  auto d_graphs = static_cast<GraphKernelData<IdType>*>(
      device->AllocWorkspace(ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
  device->CopyDataFromTo(h_graphs.data(), 0, d_graphs, 0,
                         (num_etypes) * sizeof(GraphKernelData<IdType>),
                         DGLContext{kDLCPU, 0},
                         ctx,
                         hg->GetCSRMatrix(0).indptr->dtype,
                         stream);
  // copy probs pointers to GPU
  const FloatType **probs_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
  device->CopyDataFromTo(probs.get(), 0, probs_dev, 0,
                         (num_etypes) * sizeof(FloatType *),
                         DGLContext{kDLCPU, 0},
                         ctx,
                         prob[0]->dtype,
                         stream);
  // copy prob_sums pointers to GPU
  const FloatType **prob_sums_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
  device->CopyDataFromTo(prob_sums.get(), 0, prob_sums_dev, 0,
                         (num_etypes) * sizeof(FloatType *),
                         DGLContext{kDLCPU, 0},
                         ctx,
                         prob[0]->dtype,
                         stream);
  // copy metapath to GPU
  auto d_metapath = metapath.CopyTo(ctx);
  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);

  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 4;
  dim3 block(BLOCK_SIZE);
  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);

  CHECK(restart_prob->ctx.device_type == kDLGPU) << "restart_prob should be on GPU.";
  CHECK(restart_prob->ndim == 1) << "restart_prob dimension should be 1.";
  const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
  const int64_t restart_prob_size = restart_prob->shape[0];

  CUDA_KERNEL_CALL(
      (_RandomWalkBiasedKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>),
      grid, block, 0, stream,
      random_seed,
      seed_data,
      num_seeds,
      d_metapath_data,
      max_num_steps,
      d_graphs,
      probs_dev,
      prob_sums_dev,
      restart_prob_data,
      restart_prob_size,
      max_nodes,
      traces_data,
      eids_data);

  device->FreeWorkspace(ctx, d_graphs);
  device->FreeWorkspace(ctx, probs_dev);
  device->FreeWorkspace(ctx, prob_sums_dev);
  return std::make_pair(traces, eids);
}
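The loop above precomputes each node's total outgoing weight with
``cub::DeviceSegmentedReduce::Sum``, using the CSR ``indptr`` as segment
offsets. A minimal torch sketch of the quantity being computed (shapes are
illustrative):

.. code:: python

    import torch

    # CSR offsets: node v owns weight entries indptr[v] .. indptr[v+1]-1.
    indptr = torch.tensor([0, 2, 5, 6])
    prob = torch.tensor([0.5, 0.5, 1.0, 2.0, 3.0, 4.0])

    deg = indptr[1:] - indptr[:-1]
    seg_ids = torch.repeat_interleave(torch.arange(len(deg)), deg)
    # One sum per segment: the total outgoing weight of each node.
    prob_sums = torch.zeros(len(deg)).index_add_(0, seg_ids, prob)
    print(prob_sums)  # tensor([1., 6., 4.])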
template<DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalk(
    const HeteroGraphPtr hg,
@@ -174,16 +387,26 @@ std::pair<IdArray, IdArray> RandomWalk(
    const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
  // fall back to the uniform kernel only if no edge type carries weights
  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }
  auto restart_prob = NDArray::Empty(
      {0}, DLDataType{kDLFloat, 32, 1}, DGLContext{XPU, 0});
  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      CHECK(prob[0]->ctx.device_type == kDLGPU) << "prob should be on GPU.";
      ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob);
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
}
template<DLDeviceType XPU, typename IdType>
@@ -194,12 +417,14 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
    const std::vector<FloatArray> &prob,
    double restart_prob) {
  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }
  auto device_ctx = seeds->ctx;
  auto restart_prob_array = NDArray::Empty(
      {1}, DLDataType{kDLFloat, 64, 1}, device_ctx);
@@ -214,7 +439,17 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
                         restart_prob_array->dtype, stream);
  device->StreamSync(device_ctx, stream);

  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      CHECK(prob[0]->ctx.device_type == kDLGPU) << "prob should be on GPU.";
      ret = RandomWalkBiased<XPU, FloatType, IdType>(
          hg, seeds, metapath, prob, restart_prob_array);
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob_array);
  }
}
template<DLDeviceType XPU, typename IdType>
@@ -225,14 +460,24 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob) {
  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }
  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      CHECK(prob[0]->ctx.device_type == kDLGPU) << "prob should be on GPU.";
      ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob);
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
}
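All three entry points now route non-uniform walks to ``RandomWalkBiased``; the
stepwise variant checks a separate restart probability at every step. A minimal
sketch of the resulting trace-length behavior (pure Python, ignoring dead ends):

.. code:: python

    import random

    def num_transitions(restart_prob, max_num_steps):
        # The walk terminates before committing transition k with
        # probability restart_prob[k]; otherwise all steps are taken.
        for step in range(max_num_steps):
            if random.random() < restart_prob[step]:
                return step
        return max_num_steps

    random.seed(0)
    lengths = [num_transitions([0.0, 0.5, 0.9], 3) for _ in range(10000)]
    # With restart_prob[0] == 0.0, every trace makes at least one transition.
    print(min(lengths))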
template<DLDeviceType XPU, typename IdxType>
...
@@ -24,7 +24,6 @@ def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None):
        u, v = g.find_edges(trace_eids[i, j], etype=metapath[j])
        assert (u == traces[i, j]) and (v == traces[i, j + 1])

def test_non_uniform_random_walk():
    g2 = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])
@@ -61,7 +60,7 @@ def test_non_uniform_random_walk():
    check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
    traces, eids, ntypes = dgl.sampling.random_walk(
        g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath, prob='p',
        restart_prob=F.zeros((6,), F.float32, F.ctx()), return_eids=True)
    check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
    traces, eids, ntypes = dgl.sampling.random_walk(
        g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath + ['follow'], prob='p',
...