Unverified Commit eb08ef38 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Distributed] Edge-type-specific fanouts for heterogeneous graphs (#3558)

* first commit

* second commit

* spaghetti unit tests

* rewrite test
parent 156c17f3
...@@ -382,7 +382,9 @@ COOMatrix COORowWiseSampling( ...@@ -382,7 +382,9 @@ COOMatrix COORowWiseSampling(
* // etype = [0, 0, 0, 2, 1] * // etype = [0, 0, 0, 2, 1]
* COOMatrix coo = ...; * COOMatrix coo = ...;
* IdArray rows = ... ; // [0, 3] * IdArray rows = ... ; // [0, 3]
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, etype, 2, FloatArray(), false); * std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, etype, num_samples,
* FloatArray(), false);
* // possible sampled coo matrix: * // possible sampled coo matrix:
* // sampled.num_rows = 4 * // sampled.num_rows = 4
* // sampled.num_cols = 4 * // sampled.num_cols = 4
...@@ -405,7 +407,7 @@ COOMatrix COORowWisePerEtypeSampling( ...@@ -405,7 +407,7 @@ COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, COOMatrix mat,
IdArray rows, IdArray rows,
IdArray etypes, IdArray etypes,
int64_t num_samples, const std::vector<int64_t>& num_samples,
FloatArray prob = FloatArray(), FloatArray prob = FloatArray(),
bool replace = true, bool replace = true,
bool etype_sorted = false); bool etype_sorted = false);
......
...@@ -409,7 +409,9 @@ COOMatrix CSRRowWiseSampling( ...@@ -409,7 +409,9 @@ COOMatrix CSRRowWiseSampling(
* // etype = [0, 0, 0, 2, 1] * // etype = [0, 0, 0, 2, 1]
* CSRMatrix csr = ...; * CSRMatrix csr = ...;
* IdArray rows = ... ; // [0, 3] * IdArray rows = ... ; // [0, 3]
* COOMatrix sampled = CSRRowWisePerEtypeSampling(csr, rows, etype, 2, FloatArray(), false); * std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = CSRRowWisePerEtypeSampling(csr, rows, etype, num_samples,
* FloatArray(), false);
* // possible sampled coo matrix: * // possible sampled coo matrix:
* // sampled.num_rows = 4 * // sampled.num_rows = 4
* // sampled.num_cols = 4 * // sampled.num_cols = 4
...@@ -420,7 +422,7 @@ COOMatrix CSRRowWiseSampling( ...@@ -420,7 +422,7 @@ COOMatrix CSRRowWiseSampling(
* \param mat Input CSR matrix. * \param mat Input CSR matrix.
* \param rows Rows to sample from. * \param rows Rows to sample from.
* \param etypes Edge types of each edge. * \param etypes Edge types of each edge.
* \param num_samples Number of samples * \param num_samples Number of samples to choose per edge type.
* \param prob Unnormalized probability array. Should be of the same length as the data array. * \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform. * If an empty array is provided, assume uniform.
* \param replace True if sample with replacement * \param replace True if sample with replacement
...@@ -431,7 +433,7 @@ COOMatrix CSRRowWisePerEtypeSampling( ...@@ -431,7 +433,7 @@ COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, CSRMatrix mat,
IdArray rows, IdArray rows,
IdArray etypes, IdArray etypes,
int64_t num_samples, const std::vector<int64_t>& num_samples,
FloatArray prob = FloatArray(), FloatArray prob = FloatArray(),
bool replace = true, bool replace = true,
bool etype_sorted = false); bool etype_sorted = false);
......
...@@ -132,9 +132,6 @@ class MultiLayerNeighborSampler(NeighborSamplingMixin, BlockSampler): ...@@ -132,9 +132,6 @@ class MultiLayerNeighborSampler(NeighborSamplingMixin, BlockSampler):
fanout = self.fanouts[block_id] fanout = self.fanouts[block_id]
if isinstance(g, distributed.DistGraph): if isinstance(g, distributed.DistGraph):
if len(g.etypes) > 1: # heterogeneous distributed graph if len(g.etypes) > 1: # heterogeneous distributed graph
# The edge type is stored in g.edata[dgl.ETYPE]
assert isinstance(fanout, int), "For distributed training, " \
"we can only sample same number of neighbors for each edge type"
frontier = distributed.sample_etype_neighbors( frontier = distributed.sample_etype_neighbors(
g, seed_nodes, ETYPE, fanout, replace=self.replace) g, seed_nodes, ETYPE, fanout, replace=self.replace)
else: else:
......
...@@ -450,8 +450,9 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -450,8 +450,9 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
one key-value pair to make this API consistent with dgl.sampling.sample_neighbors. one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.
etype_field : string etype_field : string
The field in g.edata storing the edge type. The field in g.edata storing the edge type.
fanout : int fanout : int or dict[etype, int]
The number of edges to be sampled for each node per edge type. The number of edges to be sampled for each node per edge type. If an integer
is given, DGL assumes that the same fanout is applied to every edge type.
If -1 is given, all of the neighbors will be selected. If -1 is given, all of the neighbors will be selected.
edge_dir : str, optional edge_dir : str, optional
...@@ -479,6 +480,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -479,6 +480,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
DGLGraph DGLGraph
A sampled subgraph containing only the sampled neighboring edges. It is on CPU. A sampled subgraph containing only the sampled neighboring edges. It is on CPU.
""" """
if isinstance(fanout, int):
fanout = F.full_1d(len(g.etypes), fanout, F.int64, F.cpu())
else:
fanout = F.tensor([fanout[etype] for etype in g.etypes], dtype=F.int64)
gpb = g.get_partition_book() gpb = g.get_partition_book()
if isinstance(nodes, dict): if isinstance(nodes, dict):
homo_nids = [] homo_nids = []
......
...@@ -36,14 +36,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -36,14 +36,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
If a single tensor is given, the graph must only have one type of nodes. If a single tensor is given, the graph must only have one type of nodes.
etype_field : string etype_field : string
The field in g.edata storing the edge type. The field in g.edata storing the edge type.
fanout : int fanout : Tensor
The number of edges to be sampled for each node on each edge type. The number of edges to be sampled for each node per edge type. Must be a
1D tensor with the number of elements same as the number of edge types.
This argument can only take a single int. DGL will sample this number of edges for
each node for every edge type.
If -1 is given for a single edge type, all the neighboring edges with that edge If -1 is given, all of the neighbors will be selected.
type will be selected.
edge_dir : str, optional edge_dir : str, optional
Determines whether to sample inbound or outbound edges. Determines whether to sample inbound or outbound edges.
...@@ -99,15 +96,19 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -99,15 +96,19 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
if etype_field not in g.edata: if etype_field not in g.edata:
raise DGLError("The graph should have {} in the edge data" \ raise DGLError("The graph should have {} in the edge data" \
"representing the edge type.".format(etype_field)) "representing the edge type.".format(etype_field))
if isinstance(fanout, int) is False: # (BarclayII) because the homogenized graph no longer contains the *name* of edge
raise DGLError("The fanout should be an integer") # types, the fanout argument can no longer be a dict of etypes and ints, as opposed
if isinstance(nodes, dict) is True: # to sample_neighbors.
if not F.is_tensor(fanout):
raise DGLError("The fanout should be a tensor")
if isinstance(nodes, dict):
assert len(nodes) == 1, "The input graph should not have node types" assert len(nodes) == 1, "The input graph should not have node types"
nodes = list(nodes.values())[0] nodes = list(nodes.values())[0]
nodes = F.to_dgl_nd(utils.prepare_tensor(g, nodes, 'nodes')) nodes = F.to_dgl_nd(utils.prepare_tensor(g, nodes, 'nodes'))
# treat etypes as int32, it is much cheaper than int64 # treat etypes as int32, it is much cheaper than int64
# TODO(xiangsx): int8 can be a better choice. # TODO(xiangsx): int8 can be a better choice.
etypes = F.to_dgl_nd(F.astype(g.edata[etype_field], ty=F.int32)) etypes = F.to_dgl_nd(F.astype(g.edata[etype_field], ty=F.int32))
fanout = F.to_dgl_nd(fanout)
if prob is None: if prob is None:
prob_array = nd.array([], ctx=nd.cpu()) prob_array = nd.array([], ctx=nd.cpu())
......
...@@ -570,7 +570,8 @@ COOMatrix CSRRowWiseSampling( ...@@ -570,7 +570,8 @@ COOMatrix CSRRowWiseSampling(
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes, CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted) { const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
bool etype_sorted) {
COOMatrix ret; COOMatrix ret;
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", { ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", {
if (IsNullArray(prob)) { if (IsNullArray(prob)) {
...@@ -807,7 +808,8 @@ COOMatrix COORowWiseSampling( ...@@ -807,7 +808,8 @@ COOMatrix COORowWiseSampling(
COOMatrix COORowWisePerEtypeSampling( COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted) { const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
bool etype_sorted) {
COOMatrix ret; COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", { ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", {
if (IsNullArray(prob)) { if (IsNullArray(prob)) {
......
...@@ -168,7 +168,8 @@ COOMatrix CSRRowWiseSampling( ...@@ -168,7 +168,8 @@ COOMatrix CSRRowWiseSampling(
template <DLDeviceType XPU, typename IdType, typename FloatType> template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes, CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted); const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
bool etype_sorted);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform( COOMatrix CSRRowWiseSamplingUniform(
...@@ -176,7 +177,7 @@ COOMatrix CSRRowWiseSamplingUniform( ...@@ -176,7 +177,7 @@ COOMatrix CSRRowWiseSamplingUniform(
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform( COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples, CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted); bool replace, bool etype_sorted);
// FloatType is the type of weight data. // FloatType is the type of weight data.
...@@ -265,7 +266,7 @@ COOMatrix COORowWiseSampling( ...@@ -265,7 +266,7 @@ COOMatrix COORowWiseSampling(
template <DLDeviceType XPU, typename IdType, typename FloatType> template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling( COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted); const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, bool etype_sorted);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
COOMatrix COORowWiseSamplingUniform( COOMatrix COORowWiseSamplingUniform(
...@@ -273,7 +274,7 @@ COOMatrix COORowWiseSamplingUniform( ...@@ -273,7 +274,7 @@ COOMatrix COORowWiseSamplingUniform(
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform( COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples, COOMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted); bool replace, bool etype_sorted);
// FloatType is the type of weight data. // FloatType is the type of weight data.
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <dgl/runtime/parallel_for.h>
#include <functional> #include <functional>
#include <algorithm> #include <algorithm>
#include <string> #include <string>
...@@ -56,13 +57,14 @@ using PickFn = std::function<void( ...@@ -56,13 +57,14 @@ using PickFn = std::function<void(
// //
// \param off Starting offset of this row. // \param off Starting offset of this row.
// \param et_offset Starting offset of this range. // \param et_offset Starting offset of this range.
// \param cur_et The edge type.
// \param et_len Length of the range. // \param et_len Length of the range.
// \param et_idx A map from local idx to column id. // \param et_idx A map from local idx to column id.
// \param data Pointer of the data indices. // \param data Pointer of the data indices.
// \param out_idx Picked indices in [et_offset, et_offset + et_len). // \param out_idx Picked indices in [et_offset, et_offset + et_len).
template <typename IdxType> template <typename IdxType>
using RangePickFn = std::function<void( using RangePickFn = std::function<void(
IdxType off, IdxType et_offset, IdxType et_len, IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx, const IdxType* data, const std::vector<IdxType> &et_idx, const IdxType* data,
IdxType* out_idx)>; IdxType* out_idx)>;
...@@ -196,123 +198,138 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -196,123 +198,138 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// OpenMP parallelization on rows because each row performs computation independently. // OpenMP parallelization on rows because each row performs computation independently.
template <typename IdxType> template <typename IdxType>
COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes, COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_picks, bool replace, bool etype_sorted, const std::vector<int64_t>& num_picks, bool replace,
RangePickFn<IdxType> pick_fn) { bool etype_sorted, RangePickFn<IdxType> pick_fn) {
using namespace aten; using namespace aten;
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data); const IdxType* indptr = mat.indptr.Ptr<IdxType>();
const IdxType* indices = static_cast<IdxType*>(mat.indices->data); const IdxType* indices = mat.indices.Ptr<IdxType>();
const IdxType* data = CSRHasData(mat)? static_cast<IdxType*>(mat.data->data) : nullptr; const IdxType* data = CSRHasData(mat)? mat.data.Ptr<IdxType>() : nullptr;
const IdxType* rows_data = static_cast<IdxType*>(rows->data); const IdxType* rows_data = rows.Ptr<IdxType>();
const int32_t* etype_data = static_cast<int32_t*>(etypes->data); const int32_t* etype_data = etypes.Ptr<int32_t>();
const int64_t num_rows = rows->shape[0]; const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx; const auto& ctx = mat.indptr->ctx;
CHECK_EQ(etypes->dtype.bits / 8, sizeof(int32_t)); const int64_t num_etypes = num_picks.size();
CHECK_EQ(etypes->dtype.bits / 8, sizeof(int32_t)) << "etypes must be int32";
std::vector<IdArray> picked_rows(rows->shape[0]); std::vector<IdArray> picked_rows(rows->shape[0]);
std::vector<IdArray> picked_cols(rows->shape[0]); std::vector<IdArray> picked_cols(rows->shape[0]);
std::vector<IdArray> picked_idxs(rows->shape[0]); std::vector<IdArray> picked_idxs(rows->shape[0]);
#pragma omp parallel for // Check if the number of picks have the same value.
for (int64_t i = 0; i < num_rows; ++i) { // If so, we can potentially speed up if we have a node with total number of neighbors
const IdxType rid = rows_data[i]; // less than the given number of picks with replace=False.
CHECK_LT(rid, mat.num_rows); bool same_num_pick = true;
const IdxType off = indptr[rid]; int64_t num_pick_value = num_picks[0];
const IdxType len = indptr[rid + 1] - off; for (int64_t num_pick : num_picks) {
if (num_pick_value != num_pick) {
// do something here same_num_pick = false;
if (len == 0) { break;
picked_rows[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
picked_cols[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
picked_idxs[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
continue;
} }
}
// fast path runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {
if (len <= num_picks && !replace) { for (int64_t i = b; i < e; ++i) {
IdArray rows = Full(rid, len, sizeof(IdxType) * 8, ctx); const IdxType rid = rows_data[i];
IdArray cols = Full(-1, len, sizeof(IdxType) * 8, ctx); CHECK_LT(rid, mat.num_rows);
IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx); const IdxType off = indptr[rid];
IdxType* cdata = static_cast<IdxType*>(cols->data); const IdxType len = indptr[rid + 1] - off;
IdxType* idata = static_cast<IdxType*>(idx->data);
for (int64_t j = 0; j < len; ++j) { // do something here
cdata[j] = indices[off + j]; if (len == 0) {
idata[j] = data ? data[off + j] : off + j; picked_rows[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
} picked_cols[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
picked_rows[i] = rows; picked_idxs[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
picked_cols[i] = cols; continue;
picked_idxs[i] = idx;
} else {
// need to do per edge type sample
std::vector<IdxType> rows;
std::vector<IdxType> cols;
std::vector<IdxType> idx;
std::vector<IdxType> et(len);
std::vector<IdxType> et_idx(len);
std::iota(et_idx.begin(), et_idx.end(), 0);
for (int64_t j = 0; j < len; ++j) {
et[j] = data ? etype_data[data[off+j]] : etype_data[off+j];
} }
if (!etype_sorted) // the edge type is sorted, not need to sort it
std::sort(et_idx.begin(), et_idx.end(), // fast path
[&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];}); if (same_num_pick && len <= num_pick_value && !replace) {
IdArray rows = Full(rid, len, sizeof(IdxType) * 8, ctx);
IdxType cur_et = et[et_idx[0]]; IdArray cols = Full(-1, len, sizeof(IdxType) * 8, ctx);
int64_t et_offset = 0; IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx);
int64_t et_len = 1; IdxType* cdata = cols.Ptr<IdxType>();
for (int64_t j = 0; j < len; ++j) { IdxType* idata = idx.Ptr<IdxType>();
if ((j+1 == len) || cur_et != et[et_idx[j+1]]) { for (int64_t j = 0; j < len; ++j) {
// 1 end of the current etype cdata[j] = indices[off + j];
// 2 end of the row idata[j] = data ? data[off + j] : off + j;
// random pick for current etype }
if (et_len <= num_picks && !replace) { picked_rows[i] = rows;
// fast path, select all picked_cols[i] = cols;
for (int64_t k = 0; k < et_len; ++k) { picked_idxs[i] = idx;
rows.push_back(rid); } else {
cols.push_back(indices[off+et_idx[et_offset+k]]); // need to do per edge type sample
if (data) std::vector<IdxType> rows;
idx.push_back(data[off+et_idx[et_offset+k]]); std::vector<IdxType> cols;
else std::vector<IdxType> idx;
idx.push_back(off+et_idx[et_offset+k]);
std::vector<IdxType> et(len);
std::vector<IdxType> et_idx(len);
std::iota(et_idx.begin(), et_idx.end(), 0);
for (int64_t j = 0; j < len; ++j) {
et[j] = data ? etype_data[data[off+j]] : etype_data[off+j];
}
if (!etype_sorted) // the edge type is sorted, not need to sort it
std::sort(et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];});
CHECK(et[et_idx[len - 1]] < num_etypes) <<
"etype values exceed the number of fanouts";
IdxType cur_et = et[et_idx[0]];
int64_t et_offset = 0;
int64_t et_len = 1;
for (int64_t j = 0; j < len; ++j) {
if ((j+1 == len) || cur_et != et[et_idx[j+1]]) {
// 1 end of the current etype
// 2 end of the row
// random pick for current etype
if (et_len <= num_picks[cur_et] && !replace) {
// fast path, select all
for (int64_t k = 0; k < et_len; ++k) {
rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+k]]);
if (data)
idx.push_back(data[off+et_idx[et_offset+k]]);
else
idx.push_back(off+et_idx[et_offset+k]);
}
} else {
IdArray picked_idx = Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data);
// need call random pick
pick_fn(off, et_offset, cur_et,
et_len, et_idx,
data, picked_idata);
for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
const IdxType picked = picked_idata[k];
rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+picked]]);
if (data)
idx.push_back(data[off+et_idx[et_offset+picked]]);
else
idx.push_back(off+et_idx[et_offset+picked]);
}
} }
if (j+1 == len)
break;
// next etype
cur_et = et[et_idx[j+1]];
et_offset = j+1;
et_len = 1;
} else { } else {
IdArray picked_idx = Full(-1, num_picks, sizeof(IdxType) * 8, ctx); et_len++;
IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data);
// need call random pick
pick_fn(off, et_offset,
et_len, et_idx,
data, picked_idata);
for (int64_t k = 0; k < num_picks; ++k) {
const IdxType picked = picked_idata[k];
rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+picked]]);
if (data)
idx.push_back(data[off+et_idx[et_offset+picked]]);
else
idx.push_back(off+et_idx[et_offset+picked]);
}
} }
if (j+1 == len)
break;
// next etype
cur_et = et[et_idx[j+1]];
et_offset = j+1;
et_len = 1;
} else {
et_len++;
} }
}
picked_rows[i] = VecToIdArray(rows, sizeof(IdxType) * 8, ctx);
picked_cols[i] = VecToIdArray(cols, sizeof(IdxType) * 8, ctx);
picked_idxs[i] = VecToIdArray(idx, sizeof(IdxType) * 8, ctx);
} // end processing one row
CHECK_EQ(picked_rows[i]->shape[0], picked_cols[i]->shape[0]); picked_rows[i] = VecToIdArray(rows, sizeof(IdxType) * 8, ctx);
CHECK_EQ(picked_rows[i]->shape[0], picked_idxs[i]->shape[0]); picked_cols[i] = VecToIdArray(cols, sizeof(IdxType) * 8, ctx);
} // end processing all rows picked_idxs[i] = VecToIdArray(idx, sizeof(IdxType) * 8, ctx);
} // end processing one row
CHECK_EQ(picked_rows[i]->shape[0], picked_cols[i]->shape[0]);
CHECK_EQ(picked_rows[i]->shape[0], picked_idxs[i]->shape[0]);
} // end processing all rows
});
IdArray picked_row = Concat(picked_rows); IdArray picked_row = Concat(picked_rows);
IdArray picked_col = Concat(picked_cols); IdArray picked_col = Concat(picked_cols);
...@@ -342,8 +359,8 @@ COOMatrix COORowWisePick(COOMatrix mat, IdArray rows, ...@@ -342,8 +359,8 @@ COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
// row-wise pick on the CSR matrix and rectifies the returned results. // row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType> template <typename IdxType>
COOMatrix COORowWisePerEtypePick(COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix COORowWisePerEtypePick(COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_picks, bool replace, bool etype_sorted, const std::vector<int64_t>& num_picks, bool replace,
RangePickFn<IdxType> pick_fn) { bool etype_sorted, RangePickFn<IdxType> pick_fn) {
using namespace aten; using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows)); const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx); const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
......
...@@ -46,9 +46,9 @@ inline PickFn<IdxType> GetSamplingPickFn( ...@@ -46,9 +46,9 @@ inline PickFn<IdxType> GetSamplingPickFn(
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
inline RangePickFn<IdxType> GetSamplingRangePickFn( inline RangePickFn<IdxType> GetSamplingRangePickFn(
int64_t num_samples, FloatArray prob, bool replace) { const std::vector<int64_t>& num_samples, FloatArray prob, bool replace) {
RangePickFn<IdxType> pick_fn = [prob, num_samples, replace] RangePickFn<IdxType> pick_fn = [prob, num_samples, replace]
(IdxType off, IdxType et_offset, IdxType et_len, (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx, const std::vector<IdxType> &et_idx,
const IdxType* data, IdxType* out_idx) { const IdxType* data, IdxType* out_idx) {
const FloatType* p_data = static_cast<FloatType*>(prob->data); const FloatType* p_data = static_cast<FloatType*>(prob->data);
...@@ -62,7 +62,7 @@ inline RangePickFn<IdxType> GetSamplingRangePickFn( ...@@ -62,7 +62,7 @@ inline RangePickFn<IdxType> GetSamplingRangePickFn(
} }
RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>( RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
num_samples, probs, out_idx, replace); num_samples[cur_et], probs, out_idx, replace);
}; };
return pick_fn; return pick_fn;
} }
...@@ -85,13 +85,13 @@ inline PickFn<IdxType> GetSamplingUniformPickFn( ...@@ -85,13 +85,13 @@ inline PickFn<IdxType> GetSamplingUniformPickFn(
template <typename IdxType> template <typename IdxType>
inline RangePickFn<IdxType> GetSamplingUniformRangePickFn( inline RangePickFn<IdxType> GetSamplingUniformRangePickFn(
int64_t num_samples, bool replace) { const std::vector<int64_t>& num_samples, bool replace) {
RangePickFn<IdxType> pick_fn = [num_samples, replace] RangePickFn<IdxType> pick_fn = [num_samples, replace]
(IdxType off, IdxType et_offset, IdxType et_len, (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx, const std::vector<IdxType> &et_idx,
const IdxType* data, IdxType* out_idx) { const IdxType* data, IdxType* out_idx) {
RandomEngine::ThreadLocal()->UniformChoice<IdxType>( RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
num_samples, et_len, out_idx, replace); num_samples[cur_et], et_len, out_idx, replace);
}; };
return pick_fn; return pick_fn;
} }
...@@ -136,21 +136,21 @@ template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>( ...@@ -136,21 +136,21 @@ template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>(
template <DLDeviceType XPU, typename IdxType, typename FloatType> template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes, COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) { FloatArray prob, bool replace, bool etype_sorted) {
CHECK(prob.defined()); CHECK(prob.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace); auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
} }
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, float>( template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, float>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, float>( template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, float>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, double>( template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, double>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, double>( template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, double>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType> template <DLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows, COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
...@@ -166,16 +166,16 @@ template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>( ...@@ -166,16 +166,16 @@ template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>(
template <DLDeviceType XPU, typename IdxType> template <DLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes, COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) { bool replace, bool etype_sorted) {
auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace); auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
} }
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int32_t>( template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
CSRMatrix, IdArray, IdArray, int64_t, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int64_t>( template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
CSRMatrix, IdArray, IdArray, int64_t, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType> template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased( COOMatrix CSRRowWiseSamplingBiased(
...@@ -225,21 +225,21 @@ template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>( ...@@ -225,21 +225,21 @@ template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>(
template <DLDeviceType XPU, typename IdxType, typename FloatType> template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) { FloatArray prob, bool replace, bool etype_sorted) {
CHECK(prob.defined()); CHECK(prob.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace); auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
} }
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, float>( template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, float>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, float>( template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, float>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, double>( template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, double>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, double>( template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, double>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType> template <DLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows, COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
...@@ -255,15 +255,16 @@ template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>( ...@@ -255,15 +255,16 @@ template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>(
template <DLDeviceType XPU, typename IdxType> template <DLDeviceType XPU, typename IdxType>
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, bool replace, bool etype_sorted) { const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) {
auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace); auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
} }
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int32_t>( template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
COOMatrix, IdArray, IdArray, int64_t, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int64_t>( template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
COOMatrix, IdArray, IdArray, int64_t, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/aten/macro.h>
#include <dgl/sampling/neighbor.h> #include <dgl/sampling/neighbor.h>
#include "../../../c_api_common.h" #include "../../../c_api_common.h"
#include "../../unit_graph.h" #include "../../unit_graph.h"
...@@ -156,7 +157,7 @@ HeteroSubgraph SampleNeighborsEType( ...@@ -156,7 +157,7 @@ HeteroSubgraph SampleNeighborsEType(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray nodes, const IdArray nodes,
const IdArray etypes, const IdArray etypes,
const int64_t fanout, const std::vector<int64_t>& fanouts,
EdgeDir dir, EdgeDir dir,
const IdArray prob, const IdArray prob,
bool replace, bool replace,
...@@ -173,13 +174,23 @@ HeteroSubgraph SampleNeighborsEType( ...@@ -173,13 +174,23 @@ HeteroSubgraph SampleNeighborsEType(
dgl_type_t etype = 0; dgl_type_t etype = 0;
const dgl_type_t src_vtype = 0; const dgl_type_t src_vtype = 0;
const dgl_type_t dst_vtype = 0; const dgl_type_t dst_vtype = 0;
if (num_nodes == 0 || fanout == 0) {
bool same_fanout = true;
int64_t fanout_value = fanouts[0];
for (auto fanout : fanouts) {
if (fanout != fanout_value) {
same_fanout = false;
break;
}
}
if (num_nodes == 0 || (same_fanout && fanout_value == 0)) {
subrels[etype] = UnitGraph::Empty(1, subrels[etype] = UnitGraph::Empty(1,
hg->NumVertices(src_vtype), hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype), hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context()); hg->DataType(), hg->Context());
induced_edges[etype] = aten::NullArray(); induced_edges[etype] = aten::NullArray();
} else if (fanout == -1) { } else if (same_fanout && fanout_value == -1) {
const auto &earr = (dir == EdgeDir::kOut) ? const auto &earr = (dir == EdgeDir::kOut) ?
hg->OutEdges(etype, nodes) : hg->OutEdges(etype, nodes) :
hg->InEdges(etype, nodes); hg->InEdges(etype, nodes);
...@@ -201,21 +212,21 @@ HeteroSubgraph SampleNeighborsEType( ...@@ -201,21 +212,21 @@ HeteroSubgraph SampleNeighborsEType(
if (dir == EdgeDir::kIn) { if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling( sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling(
aten::COOTranspose(hg->GetCOOMatrix(etype)), aten::COOTranspose(hg->GetCOOMatrix(etype)),
nodes, etypes, fanout, prob, replace)); nodes, etypes, fanouts, prob, replace));
} else { } else {
sampled_coo = aten::COORowWisePerEtypeSampling( sampled_coo = aten::COORowWisePerEtypeSampling(
hg->GetCOOMatrix(etype), nodes, etypes, fanout, prob, replace, etype_sorted); hg->GetCOOMatrix(etype), nodes, etypes, fanouts, prob, replace, etype_sorted);
} }
break; break;
case SparseFormat::kCSR: case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix."; CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling( sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSRMatrix(etype), nodes, etypes, fanout, prob, replace, etype_sorted); hg->GetCSRMatrix(etype), nodes, etypes, fanouts, prob, replace, etype_sorted);
break; break;
case SparseFormat::kCSC: case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix."; CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling( sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSCMatrix(etype), nodes, etypes, fanout, prob, replace, etype_sorted); hg->GetCSCMatrix(etype), nodes, etypes, fanouts, prob, replace, etype_sorted);
sampled_coo = aten::COOTranspose(sampled_coo); sampled_coo = aten::COOTranspose(sampled_coo);
break; break;
default: default:
...@@ -405,7 +416,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType") ...@@ -405,7 +416,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
IdArray nodes = args[1]; IdArray nodes = args[1];
IdArray etypes = args[2]; IdArray etypes = args[2];
const int64_t fanout = args[3]; IdArray fanout = args[3];
const std::string dir_str = args[4]; const std::string dir_str = args[4];
IdArray prob = args[5]; IdArray prob = args[5];
const bool replace = args[6]; const bool replace = args[6];
...@@ -414,10 +425,12 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType") ...@@ -414,10 +425,12 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
CHECK(dir_str == "in" || dir_str == "out") CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\"."; << "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut; EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut;
CHECK_INT64(fanout, "fanout");
std::vector<int64_t> fanout_vec = fanout.ToVector<int64_t>();
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph); std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsEType( *subg = sampling::SampleNeighborsEType(
hg.sptr(), nodes, etypes, fanout, dir, prob, replace, etype_sorted); hg.sptr(), nodes, etypes, fanout_vec, dir, prob, replace, etype_sorted);
*rv = HeteroSubgraphRef(subg); *rv = HeteroSubgraphRef(subg);
}); });
......
...@@ -725,106 +725,75 @@ def test_sample_neighbors_biased_bipartite(): ...@@ -725,106 +725,75 @@ def test_sample_neighbors_biased_bipartite():
check_num(subg.edges()[1], tag) check_num(subg.edges()[1], tag)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_etype_homogeneous(): @pytest.mark.parametrize('format_', ['coo', 'csr', 'csc'])
@pytest.mark.parametrize('direction', ['in', 'out'])
@pytest.mark.parametrize('replace', [False, True])
def test_sample_neighbors_etype_homogeneous(format_, direction, replace):
num_nodes = 100 num_nodes = 100
rare_cnt = 4 rare_cnt = 4
g = create_etype_test_graph(100, 30, rare_cnt) g = create_etype_test_graph(100, 30, rare_cnt)
h_g = dgl.to_homogeneous(g) h_g = dgl.to_homogeneous(g)
seed_ntype = g.get_ntype_id("u") seed_ntype = g.get_ntype_id("u")
seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype) seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype)
fanouts = F.tensor([6, 5, 4, 3, 2], dtype=F.int64)
def check_num(nodes, replace):
nodes = F.asnumpy(nodes) def check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction):
cnt = [sum(nodes == i) for i in range(num_nodes)] src, dst = subg.edges()
num_etypes = F.asnumpy(h_g.edata[dgl.ETYPE]).max()
for i in range(20): etype_array = F.asnumpy(subg.edata[dgl.ETYPE])
if i < rare_cnt: src = F.asnumpy(src)
if replace is False: dst = F.asnumpy(dst)
assert cnt[i] == 22 fanouts = F.asnumpy(fanouts)
else:
assert cnt[i] == 30 all_etype_array = F.asnumpy(h_g.edata[dgl.ETYPE])
all_src = F.asnumpy(all_src)
all_dst = F.asnumpy(all_dst)
src_per_etype = []
dst_per_etype = []
for etype in range(num_etypes):
src_per_etype.append(src[etype_array == etype])
dst_per_etype.append(dst[etype_array == etype])
if replace:
if direction == 'in':
in_degree_per_etype = [np.bincount(d) for d in dst_per_etype]
for in_degree, fanout in zip(in_degree_per_etype, fanouts):
assert np.all(in_degree == fanout)
else: else:
if replace is False: out_degree_per_etype = [np.bincount(s) for s in src_per_etype]
assert cnt[i] == 12 for out_degree, fanout in zip(out_degree_per_etype, fanouts):
else: assert np.all(out_degree == fanout)
assert cnt[i] == 20 else:
if direction == 'in':
# graph with coo format for v in set(dst):
coo_g = h_g.formats('coo') u = src[dst == v]
for _ in range(5): et = etype_array[dst == v]
subg = dgl.sampling.sample_etype_neighbors(coo_g, seeds, dgl.ETYPE, 10, replace=False) all_u = all_src[all_dst == v]
check_num(subg.edges()[1], False) all_et = all_etype_array[all_dst == v]
for etype in set(et):
for _ in range(5): u_etype = set(u[et == etype])
subg = dgl.sampling.sample_etype_neighbors(coo_g, seeds, dgl.ETYPE, 10, replace=True) all_u_etype = set(all_u[all_et == etype])
check_num(subg.edges()[1], True) assert (len(u_etype) == fanouts[etype]) or (u_etype == all_u_etype)
# graph with csr format
csr_g = h_g.formats('csr')
csr_g = csr_g.formats(['csr','csc','coo'])
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(csr_g, seeds, dgl.ETYPE, 10, replace=False)
check_num(subg.edges()[1], False)
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(csr_g, seeds, dgl.ETYPE, 10, replace=True)
check_num(subg.edges()[1], True)
# graph with csc format
csc_g = h_g.formats('csc')
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(csc_g, seeds, dgl.ETYPE, 10, replace=False)
check_num(subg.edges()[1], False)
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(csc_g, seeds, dgl.ETYPE, 10, replace=True)
check_num(subg.edges()[1], True)
def check_num2(nodes, replace):
nodes = F.asnumpy(nodes)
cnt = [sum(nodes == i) for i in range(num_nodes)]
for i in range(20):
if replace is False:
assert cnt[i] == 7
else: else:
assert cnt[i] == 10 for u in set(src):
v = dst[src == u]
# edge dir out et = etype_array[src == u]
# graph with coo format all_v = all_dst[all_src == u]
coo_g = h_g.formats('coo') all_et = all_etype_array[all_src == u]
for _ in range(5): for etype in set(et):
subg = dgl.sampling.sample_etype_neighbors( v_etype = set(v[et == etype])
coo_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=False) all_v_etype = set(all_v[all_et == etype])
check_num2(subg.edges()[0], False) assert (len(v_etype) == fanouts[etype]) or (v_etype == all_v_etype)
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors( all_src, all_dst = h_g.edges()
coo_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=True) h_g = h_g.formats(format_)
check_num2(subg.edges()[0], True) if (direction, format_) in [('in', 'csr'), ('out', 'csc')]:
# graph with csr format h_g = h_g.formats(['csc', 'csr', 'coo'])
csr_g = h_g.formats('csr')
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(
csr_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=False)
check_num2(subg.edges()[0], False)
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(
csr_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=True)
check_num2(subg.edges()[0], True)
# graph with csc format
csc_g = h_g.formats('csc')
csc_g = csc_g.formats(['csc','csr','coo'])
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(
csc_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=False)
check_num2(subg.edges()[0], False)
for _ in range(5): for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors( subg = dgl.sampling.sample_etype_neighbors(
csc_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=True) h_g, seeds, dgl.ETYPE, fanouts, replace=replace, edge_dir=direction)
check_num2(subg.edges()[0], True) check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction)
@pytest.mark.parametrize('dtype', ['int32', 'int64']) @pytest.mark.parametrize('dtype', ['int32', 'int64'])
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
...@@ -920,7 +889,9 @@ def test_sample_neighbors_exclude_edges_homoG(dtype): ...@@ -920,7 +889,9 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):
if __name__ == '__main__': if __name__ == '__main__':
test_sample_neighbors_etype_homogeneous() from itertools import product
for args in product(['coo', 'csr', 'csc'], ['in', 'out'], [False, True]):
test_sample_neighbors_etype_homogeneous(*args)
test_random_walk() test_random_walk()
test_pack_traces() test_pack_traces()
test_pinsage_sampling() test_pinsage_sampling()
......
...@@ -264,11 +264,11 @@ void _TestCSRPerEtypeSampling(bool has_data) { ...@@ -264,11 +264,11 @@ void _TestCSRPerEtypeSampling(bool has_data) {
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) : NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2})); NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -316,7 +316,7 @@ void _TestCSRPerEtypeSampling(bool has_data) { ...@@ -316,7 +316,7 @@ void _TestCSRPerEtypeSampling(bool has_data) {
NDArray::FromVector( NDArray::FromVector(
std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5})); std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -339,11 +339,11 @@ void _TestCSRPerEtypeSamplingSorted(bool has_data, bool etype_sorted) { ...@@ -339,11 +339,11 @@ void _TestCSRPerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) : NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2})); NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -391,7 +391,7 @@ void _TestCSRPerEtypeSamplingSorted(bool has_data, bool etype_sorted) { ...@@ -391,7 +391,7 @@ void _TestCSRPerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector( NDArray::FromVector(
std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5})); std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -440,12 +440,12 @@ void _TestCSRPerEtypeSamplingUniform(bool has_data) { ...@@ -440,12 +440,12 @@ void _TestCSRPerEtypeSamplingUniform(bool has_data) {
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) : NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2})); NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -497,12 +497,12 @@ void _TestCSRPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) { ...@@ -497,12 +497,12 @@ void _TestCSRPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) : NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2})); NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted); auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -674,11 +674,11 @@ void _TestCOOerEtypeSampling(bool has_data) { ...@@ -674,11 +674,11 @@ void _TestCOOerEtypeSampling(bool has_data) {
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) : NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2})); NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -726,7 +726,7 @@ void _TestCOOerEtypeSampling(bool has_data) { ...@@ -726,7 +726,7 @@ void _TestCOOerEtypeSampling(bool has_data) {
NDArray::FromVector( NDArray::FromVector(
std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5})); std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -749,11 +749,11 @@ void _TestCOOerEtypeSamplingSorted(bool has_data, bool etype_sorted) { ...@@ -749,11 +749,11 @@ void _TestCOOerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) : NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2})); NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -801,7 +801,7 @@ void _TestCOOerEtypeSamplingSorted(bool has_data, bool etype_sorted) { ...@@ -801,7 +801,7 @@ void _TestCOOerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector( NDArray::FromVector(
std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5})); std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -850,12 +850,12 @@ void _TestCOOPerEtypeSamplingUniform(bool has_data) { ...@@ -850,12 +850,12 @@ void _TestCOOPerEtypeSamplingUniform(bool has_data) {
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) : NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2})); NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -907,12 +907,12 @@ void _TestCOOPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) { ...@@ -907,12 +907,12 @@ void _TestCOOPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) : NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2})); NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted); auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -1134,4 +1134,4 @@ TEST(RowwiseTest, TestCSRSamplingBiased) { ...@@ -1134,4 +1134,4 @@ TEST(RowwiseTest, TestCSRSamplingBiased) {
_TestCSRSamplingBiased<int32_t, double>(false); _TestCSRSamplingBiased<int32_t, double>(false);
_TestCSRSamplingBiased<int64_t, double>(true); _TestCSRSamplingBiased<int64_t, double>(true);
_TestCSRSamplingBiased<int64_t, double>(false); _TestCSRSamplingBiased<int64_t, double>(false);
} }
\ No newline at end of file
...@@ -278,7 +278,10 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_ ...@@ -278,7 +278,10 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i) part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
# Create sampler # Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10]) sampler = dgl.dataloading.MultiLayerNeighborSampler([
# test dict for hetero
{etype: 5 for etype in dist_graph.etypes} if len(dist_graph.etypes) > 1 else 5,
10]) # test int for hetero
# We need to test creating DistDataLoader multiple times. # We need to test creating DistDataLoader multiple times.
for i in range(2): for i in range(2):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment