Unverified Commit eb08ef38 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Distributed] Edge-type-specific fanouts for heterogeneous graphs (#3558)

* first commit

* second commit

* spaghetti unit tests

* rewrite test
parent 156c17f3
......@@ -382,7 +382,9 @@ COOMatrix COORowWiseSampling(
* // etype = [0, 0, 0, 2, 1]
* COOMatrix coo = ...;
* IdArray rows = ... ; // [0, 3]
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, etype, 2, FloatArray(), false);
* std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, etype, num_samples,
* FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
......@@ -405,7 +407,7 @@ COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat,
IdArray rows,
IdArray etypes,
int64_t num_samples,
const std::vector<int64_t>& num_samples,
FloatArray prob = FloatArray(),
bool replace = true,
bool etype_sorted = false);
......
......@@ -409,7 +409,9 @@ COOMatrix CSRRowWiseSampling(
* // etype = [0, 0, 0, 2, 1]
* CSRMatrix csr = ...;
* IdArray rows = ... ; // [0, 3]
* COOMatrix sampled = CSRRowWisePerEtypeSampling(csr, rows, etype, 2, FloatArray(), false);
* std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = CSRRowWisePerEtypeSampling(csr, rows, etype, num_samples,
* FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
......@@ -420,7 +422,7 @@ COOMatrix CSRRowWiseSampling(
* \param mat Input CSR matrix.
* \param rows Rows to sample from.
* \param etypes Edge types of each edge.
* \param num_samples Number of samples
* \param num_samples Number of samples to choose per edge type.
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
......@@ -431,7 +433,7 @@ COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat,
IdArray rows,
IdArray etypes,
int64_t num_samples,
const std::vector<int64_t>& num_samples,
FloatArray prob = FloatArray(),
bool replace = true,
bool etype_sorted = false);
......
......@@ -132,9 +132,6 @@ class MultiLayerNeighborSampler(NeighborSamplingMixin, BlockSampler):
fanout = self.fanouts[block_id]
if isinstance(g, distributed.DistGraph):
if len(g.etypes) > 1: # heterogeneous distributed graph
# The edge type is stored in g.edata[dgl.ETYPE]
assert isinstance(fanout, int), "For distributed training, " \
"we can only sample same number of neighbors for each edge type"
frontier = distributed.sample_etype_neighbors(
g, seed_nodes, ETYPE, fanout, replace=self.replace)
else:
......
......@@ -450,8 +450,9 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.
etype_field : string
The field in g.edata storing the edge type.
fanout : int
The number of edges to be sampled for each node per edge type.
fanout : int or dict[etype, int]
The number of edges to be sampled for each node per edge type. If an integer
is given, DGL assumes that the same fanout is applied to every edge type.
If -1 is given, all of the neighbors will be selected.
edge_dir : str, optional
......@@ -479,6 +480,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
DGLGraph
A sampled subgraph containing only the sampled neighboring edges. It is on CPU.
"""
if isinstance(fanout, int):
fanout = F.full_1d(len(g.etypes), fanout, F.int64, F.cpu())
else:
fanout = F.tensor([fanout[etype] for etype in g.etypes], dtype=F.int64)
gpb = g.get_partition_book()
if isinstance(nodes, dict):
homo_nids = []
......
......@@ -36,14 +36,11 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
If a single tensor is given, the graph must only have one type of nodes.
etype_field : string
The field in g.edata storing the edge type.
fanout : int
The number of edges to be sampled for each node on each edge type.
This argument can only take a single int. DGL will sample this number of edges for
each node for every edge type.
fanout : Tensor
The number of edges to be sampled for each node per edge type. Must be a
1D tensor with the number of elements same as the number of edge types.
If -1 is given for a single edge type, all the neighboring edges with that edge
type will be selected.
If -1 is given, all of the neighbors will be selected.
edge_dir : str, optional
Determines whether to sample inbound or outbound edges.
......@@ -99,15 +96,19 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
if etype_field not in g.edata:
raise DGLError("The graph should have {} in the edge data" \
"representing the edge type.".format(etype_field))
if isinstance(fanout, int) is False:
raise DGLError("The fanout should be an integer")
if isinstance(nodes, dict) is True:
# (BarclayII) because the homogenized graph no longer contains the *name* of edge
# types, the fanout argument can no longer be a dict of etypes and ints, as opposed
# to sample_neighbors.
if not F.is_tensor(fanout):
raise DGLError("The fanout should be a tensor")
if isinstance(nodes, dict):
assert len(nodes) == 1, "The input graph should not have node types"
nodes = list(nodes.values())[0]
nodes = F.to_dgl_nd(utils.prepare_tensor(g, nodes, 'nodes'))
# treat etypes as int32, it is much cheaper than int64
# TODO(xiangsx): int8 can be a better choice.
etypes = F.to_dgl_nd(F.astype(g.edata[etype_field], ty=F.int32))
fanout = F.to_dgl_nd(fanout)
if prob is None:
prob_array = nd.array([], ctx=nd.cpu())
......
......@@ -570,7 +570,8 @@ COOMatrix CSRRowWiseSampling(
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted) {
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
bool etype_sorted) {
COOMatrix ret;
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", {
if (IsNullArray(prob)) {
......@@ -807,7 +808,8 @@ COOMatrix COORowWiseSampling(
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted) {
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
bool etype_sorted) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", {
if (IsNullArray(prob)) {
......
......@@ -168,7 +168,8 @@ COOMatrix CSRRowWiseSampling(
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted);
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
bool etype_sorted);
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
......@@ -176,7 +177,7 @@ COOMatrix CSRRowWiseSamplingUniform(
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples,
CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted);
// FloatType is the type of weight data.
......@@ -265,7 +266,7 @@ COOMatrix COORowWiseSampling(
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted);
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, bool etype_sorted);
template <DLDeviceType XPU, typename IdType>
COOMatrix COORowWiseSamplingUniform(
......@@ -273,7 +274,7 @@ COOMatrix COORowWiseSamplingUniform(
template <DLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples,
COOMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted);
// FloatType is the type of weight data.
......
......@@ -8,6 +8,7 @@
#include <dgl/array.h>
#include <dmlc/omp.h>
#include <dgl/runtime/parallel_for.h>
#include <functional>
#include <algorithm>
#include <string>
......@@ -56,13 +57,14 @@ using PickFn = std::function<void(
//
// \param off Starting offset of this row.
// \param et_offset Starting offset of this range.
// \param cur_et The edge type.
// \param et_len Length of the range.
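// \param et_idx Maps a position in the edge-type-ordered sequence of this row to the
//               local edge position within the row (the column id is then
//               indices[off + et_idx[k]]).
// \param data Pointer of the data indices.
// \param out_idx Output buffer for the picks. Each pick is an index relative to the
//                range, i.e. in [0, et_len); the caller adds et_offset before
//                looking it up in et_idx.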
template <typename IdxType>
using RangePickFn = std::function<void(
IdxType off, IdxType et_offset, IdxType et_len,
IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx, const IdxType* data,
IdxType* out_idx)>;
......@@ -196,123 +198,138 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// OpenMP parallelization on rows because each row performs computation independently.
template <typename IdxType>
COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_picks, bool replace, bool etype_sorted,
RangePickFn<IdxType> pick_fn) {
const std::vector<int64_t>& num_picks, bool replace,
bool etype_sorted, RangePickFn<IdxType> pick_fn) {
using namespace aten;
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
const IdxType* data = CSRHasData(mat)? static_cast<IdxType*>(mat.data->data) : nullptr;
const IdxType* rows_data = static_cast<IdxType*>(rows->data);
const int32_t* etype_data = static_cast<int32_t*>(etypes->data);
const IdxType* indptr = mat.indptr.Ptr<IdxType>();
const IdxType* indices = mat.indices.Ptr<IdxType>();
const IdxType* data = CSRHasData(mat)? mat.data.Ptr<IdxType>() : nullptr;
const IdxType* rows_data = rows.Ptr<IdxType>();
const int32_t* etype_data = etypes.Ptr<int32_t>();
const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx;
CHECK_EQ(etypes->dtype.bits / 8, sizeof(int32_t));
const int64_t num_etypes = num_picks.size();
CHECK_EQ(etypes->dtype.bits / 8, sizeof(int32_t)) << "etypes must be int32";
std::vector<IdArray> picked_rows(rows->shape[0]);
std::vector<IdArray> picked_cols(rows->shape[0]);
std::vector<IdArray> picked_idxs(rows->shape[0]);
#pragma omp parallel for
for (int64_t i = 0; i < num_rows; ++i) {
const IdxType rid = rows_data[i];
CHECK_LT(rid, mat.num_rows);
const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off;
// do something here
if (len == 0) {
picked_rows[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
picked_cols[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
picked_idxs[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
continue;
// Check whether every edge type has the same number of picks.
// If so, a row whose total number of neighbors does not exceed that number can
// take all of its edges directly when replace=False, skipping the per-etype logic.
bool same_num_pick = true;
int64_t num_pick_value = num_picks[0];
for (int64_t num_pick : num_picks) {
if (num_pick_value != num_pick) {
same_num_pick = false;
break;
}
}
// fast path
if (len <= num_picks && !replace) {
IdArray rows = Full(rid, len, sizeof(IdxType) * 8, ctx);
IdArray cols = Full(-1, len, sizeof(IdxType) * 8, ctx);
IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx);
IdxType* cdata = static_cast<IdxType*>(cols->data);
IdxType* idata = static_cast<IdxType*>(idx->data);
for (int64_t j = 0; j < len; ++j) {
cdata[j] = indices[off + j];
idata[j] = data ? data[off + j] : off + j;
}
picked_rows[i] = rows;
picked_cols[i] = cols;
picked_idxs[i] = idx;
} else {
// need to do per edge type sample
std::vector<IdxType> rows;
std::vector<IdxType> cols;
std::vector<IdxType> idx;
std::vector<IdxType> et(len);
std::vector<IdxType> et_idx(len);
std::iota(et_idx.begin(), et_idx.end(), 0);
for (int64_t j = 0; j < len; ++j) {
et[j] = data ? etype_data[data[off+j]] : etype_data[off+j];
runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {
for (int64_t i = b; i < e; ++i) {
const IdxType rid = rows_data[i];
CHECK_LT(rid, mat.num_rows);
const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off;
// a row without any neighbor yields empty picked arrays
if (len == 0) {
picked_rows[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
picked_cols[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
picked_idxs[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
continue;
}
if (!etype_sorted) // the edge type is sorted, not need to sort it
std::sort(et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];});
IdxType cur_et = et[et_idx[0]];
int64_t et_offset = 0;
int64_t et_len = 1;
for (int64_t j = 0; j < len; ++j) {
if ((j+1 == len) || cur_et != et[et_idx[j+1]]) {
// 1 end of the current etype
// 2 end of the row
// random pick for current etype
if (et_len <= num_picks && !replace) {
// fast path, select all
for (int64_t k = 0; k < et_len; ++k) {
rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+k]]);
if (data)
idx.push_back(data[off+et_idx[et_offset+k]]);
else
idx.push_back(off+et_idx[et_offset+k]);
// fast path
if (same_num_pick && len <= num_pick_value && !replace) {
IdArray rows = Full(rid, len, sizeof(IdxType) * 8, ctx);
IdArray cols = Full(-1, len, sizeof(IdxType) * 8, ctx);
IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx);
IdxType* cdata = cols.Ptr<IdxType>();
IdxType* idata = idx.Ptr<IdxType>();
for (int64_t j = 0; j < len; ++j) {
cdata[j] = indices[off + j];
idata[j] = data ? data[off + j] : off + j;
}
picked_rows[i] = rows;
picked_cols[i] = cols;
picked_idxs[i] = idx;
} else {
// need to do per edge type sample
std::vector<IdxType> rows;
std::vector<IdxType> cols;
std::vector<IdxType> idx;
std::vector<IdxType> et(len);
std::vector<IdxType> et_idx(len);
std::iota(et_idx.begin(), et_idx.end(), 0);
for (int64_t j = 0; j < len; ++j) {
et[j] = data ? etype_data[data[off+j]] : etype_data[off+j];
}
if (!etype_sorted) // sort by edge type only when the edges are not already sorted
std::sort(et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];});
CHECK(et[et_idx[len - 1]] < num_etypes) <<
"etype values exceed the number of fanouts";
IdxType cur_et = et[et_idx[0]];
int64_t et_offset = 0;
int64_t et_len = 1;
for (int64_t j = 0; j < len; ++j) {
if ((j+1 == len) || cur_et != et[et_idx[j+1]]) {
// We reached either (1) the end of the current etype's range or
// (2) the end of the row:
// pick edges for the current etype now.
if (et_len <= num_picks[cur_et] && !replace) {
// fast path, select all
for (int64_t k = 0; k < et_len; ++k) {
rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+k]]);
if (data)
idx.push_back(data[off+et_idx[et_offset+k]]);
else
idx.push_back(off+et_idx[et_offset+k]);
}
} else {
IdArray picked_idx = Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data);
// need to call the random pick function
pick_fn(off, et_offset, cur_et,
et_len, et_idx,
data, picked_idata);
for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
const IdxType picked = picked_idata[k];
rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+picked]]);
if (data)
idx.push_back(data[off+et_idx[et_offset+picked]]);
else
idx.push_back(off+et_idx[et_offset+picked]);
}
}
if (j+1 == len)
break;
// next etype
cur_et = et[et_idx[j+1]];
et_offset = j+1;
et_len = 1;
} else {
IdArray picked_idx = Full(-1, num_picks, sizeof(IdxType) * 8, ctx);
IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data);
// need call random pick
pick_fn(off, et_offset,
et_len, et_idx,
data, picked_idata);
for (int64_t k = 0; k < num_picks; ++k) {
const IdxType picked = picked_idata[k];
rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+picked]]);
if (data)
idx.push_back(data[off+et_idx[et_offset+picked]]);
else
idx.push_back(off+et_idx[et_offset+picked]);
}
et_len++;
}
if (j+1 == len)
break;
// next etype
cur_et = et[et_idx[j+1]];
et_offset = j+1;
et_len = 1;
} else {
et_len++;
}
}
picked_rows[i] = VecToIdArray(rows, sizeof(IdxType) * 8, ctx);
picked_cols[i] = VecToIdArray(cols, sizeof(IdxType) * 8, ctx);
picked_idxs[i] = VecToIdArray(idx, sizeof(IdxType) * 8, ctx);
} // end processing one row
CHECK_EQ(picked_rows[i]->shape[0], picked_cols[i]->shape[0]);
CHECK_EQ(picked_rows[i]->shape[0], picked_idxs[i]->shape[0]);
} // end processing all rows
picked_rows[i] = VecToIdArray(rows, sizeof(IdxType) * 8, ctx);
picked_cols[i] = VecToIdArray(cols, sizeof(IdxType) * 8, ctx);
picked_idxs[i] = VecToIdArray(idx, sizeof(IdxType) * 8, ctx);
} // end processing one row
CHECK_EQ(picked_rows[i]->shape[0], picked_cols[i]->shape[0]);
CHECK_EQ(picked_rows[i]->shape[0], picked_idxs[i]->shape[0]);
} // end processing all rows
});
IdArray picked_row = Concat(picked_rows);
IdArray picked_col = Concat(picked_cols);
......@@ -342,8 +359,8 @@ COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
// row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType>
COOMatrix COORowWisePerEtypePick(COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_picks, bool replace, bool etype_sorted,
RangePickFn<IdxType> pick_fn) {
const std::vector<int64_t>& num_picks, bool replace,
bool etype_sorted, RangePickFn<IdxType> pick_fn) {
using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
......
......@@ -46,9 +46,9 @@ inline PickFn<IdxType> GetSamplingPickFn(
template <typename IdxType, typename FloatType>
inline RangePickFn<IdxType> GetSamplingRangePickFn(
int64_t num_samples, FloatArray prob, bool replace) {
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace) {
RangePickFn<IdxType> pick_fn = [prob, num_samples, replace]
(IdxType off, IdxType et_offset, IdxType et_len,
(IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx,
const IdxType* data, IdxType* out_idx) {
const FloatType* p_data = static_cast<FloatType*>(prob->data);
......@@ -62,7 +62,7 @@ inline RangePickFn<IdxType> GetSamplingRangePickFn(
}
RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
num_samples, probs, out_idx, replace);
num_samples[cur_et], probs, out_idx, replace);
};
return pick_fn;
}
......@@ -85,13 +85,13 @@ inline PickFn<IdxType> GetSamplingUniformPickFn(
// Builds the RangePickFn used for uniform per-edge-type sampling.
//
// NOTE(review): this span is a rendered diff hunk. The `int64_t num_samples`
// lines appear to be the pre-change version and the
// `const std::vector<int64_t>& num_samples` / `cur_et` lines the post-change one.
//
// In the vector form, num_samples is captured by value in the lambda and the
// number of picks for a range is looked up as num_samples[cur_et].
// The lambda only reads num_samples, cur_et, et_len and replace and writes to
// out_idx; off, et_offset, et_idx and data are part of the RangePickFn
// signature but are not used here.
template <typename IdxType>
inline RangePickFn<IdxType> GetSamplingUniformRangePickFn(
int64_t num_samples, bool replace) {
const std::vector<int64_t>& num_samples, bool replace) {
RangePickFn<IdxType> pick_fn = [num_samples, replace]
(IdxType off, IdxType et_offset, IdxType et_len,
(IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx,
const IdxType* data, IdxType* out_idx) {
// presumably draws the requested number of indices out of et_len candidates
// into out_idx — confirm against RandomEngine::UniformChoice
RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
num_samples, et_len, out_idx, replace);
num_samples[cur_et], et_len, out_idx, replace);
};
return pick_fn;
}
......@@ -136,21 +136,21 @@ template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>(
// Weighted per-edge-type row-wise sampling on a CSR matrix.
//
// NOTE(review): this span is a rendered diff hunk. The `int64_t num_samples`
// parameter lines appear to be the pre-change signature and the
// `const std::vector<int64_t>& num_samples` lines the post-change one.
//
// Requires prob to be defined (enforced by CHECK), builds the range pick
// function from num_samples, prob and replace, and forwards everything to
// CSRRowWisePerEtypePick, whose result is returned.
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob,
bool replace, bool etype_sorted) {
const std::vector<int64_t>& num_samples,
FloatArray prob, bool replace, bool etype_sorted) {
CHECK(prob.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, float>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, float>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, double>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, double>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
......@@ -166,16 +166,16 @@ template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>(
// Uniform per-edge-type row-wise sampling on a CSR matrix.
//
// NOTE(review): this span is a rendered diff hunk. The `int64_t num_samples,`
// line appears to be the pre-change parameter and the
// `const std::vector<int64_t>& num_samples,` line the post-change one.
//
// Builds the uniform range pick function from num_samples and replace, then
// forwards to CSRRowWisePerEtypePick and returns its result.
template <DLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples,
const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) {
auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
CSRMatrix, IdArray, IdArray, int64_t, bool, bool);
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
CSRMatrix, IdArray, IdArray, int64_t, bool, bool);
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
......@@ -225,21 +225,21 @@ template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>(
// Weighted per-edge-type row-wise sampling on a COO matrix.
//
// NOTE(review): this span is a rendered diff hunk. The `int64_t num_samples`
// parameter lines appear to be the pre-change signature and the
// `const std::vector<int64_t>& num_samples` lines the post-change one.
//
// Requires prob to be defined (enforced by CHECK), builds the range pick
// function from num_samples, prob and replace, and forwards everything to
// COORowWisePerEtypePick, whose result is returned.
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob,
bool replace, bool etype_sorted) {
const std::vector<int64_t>& num_samples,
FloatArray prob, bool replace, bool etype_sorted) {
CHECK(prob.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, float>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, float>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, double>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, double>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
......@@ -255,15 +255,16 @@ template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>(
// Uniform per-edge-type row-wise sampling on a COO matrix.
//
// NOTE(review): this span is a rendered diff hunk. The line carrying
// `int64_t num_samples, bool replace, bool etype_sorted) {` appears to be the
// pre-change signature and the two lines after it the post-change one.
//
// Builds the uniform range pick function from num_samples and replace, then
// forwards to COORowWisePerEtypePick and returns its result.
template <DLDeviceType XPU, typename IdxType>
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, bool replace, bool etype_sorted) {
const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) {
auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
COOMatrix, IdArray, IdArray, int64_t, bool, bool);
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
COOMatrix, IdArray, IdArray, int64_t, bool, bool);
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
} // namespace impl
} // namespace aten
......
......@@ -7,6 +7,7 @@
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h>
#include <dgl/aten/macro.h>
#include <dgl/sampling/neighbor.h>
#include "../../../c_api_common.h"
#include "../../unit_graph.h"
......@@ -156,7 +157,7 @@ HeteroSubgraph SampleNeighborsEType(
const HeteroGraphPtr hg,
const IdArray nodes,
const IdArray etypes,
const int64_t fanout,
const std::vector<int64_t>& fanouts,
EdgeDir dir,
const IdArray prob,
bool replace,
......@@ -173,13 +174,23 @@ HeteroSubgraph SampleNeighborsEType(
dgl_type_t etype = 0;
const dgl_type_t src_vtype = 0;
const dgl_type_t dst_vtype = 0;
if (num_nodes == 0 || fanout == 0) {
bool same_fanout = true;
int64_t fanout_value = fanouts[0];
for (auto fanout : fanouts) {
if (fanout != fanout_value) {
same_fanout = false;
break;
}
}
if (num_nodes == 0 || (same_fanout && fanout_value == 0)) {
subrels[etype] = UnitGraph::Empty(1,
hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context());
induced_edges[etype] = aten::NullArray();
} else if (fanout == -1) {
} else if (same_fanout && fanout_value == -1) {
const auto &earr = (dir == EdgeDir::kOut) ?
hg->OutEdges(etype, nodes) :
hg->InEdges(etype, nodes);
......@@ -201,21 +212,21 @@ HeteroSubgraph SampleNeighborsEType(
if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling(
aten::COOTranspose(hg->GetCOOMatrix(etype)),
nodes, etypes, fanout, prob, replace));
nodes, etypes, fanouts, prob, replace));
} else {
sampled_coo = aten::COORowWisePerEtypeSampling(
hg->GetCOOMatrix(etype), nodes, etypes, fanout, prob, replace, etype_sorted);
hg->GetCOOMatrix(etype), nodes, etypes, fanouts, prob, replace, etype_sorted);
}
break;
case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSRMatrix(etype), nodes, etypes, fanout, prob, replace, etype_sorted);
hg->GetCSRMatrix(etype), nodes, etypes, fanouts, prob, replace, etype_sorted);
break;
case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSCMatrix(etype), nodes, etypes, fanout, prob, replace, etype_sorted);
hg->GetCSCMatrix(etype), nodes, etypes, fanouts, prob, replace, etype_sorted);
sampled_coo = aten::COOTranspose(sampled_coo);
break;
default:
......@@ -405,7 +416,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
HeteroGraphRef hg = args[0];
IdArray nodes = args[1];
IdArray etypes = args[2];
const int64_t fanout = args[3];
IdArray fanout = args[3];
const std::string dir_str = args[4];
IdArray prob = args[5];
const bool replace = args[6];
......@@ -414,10 +425,12 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut;
CHECK_INT64(fanout, "fanout");
std::vector<int64_t> fanout_vec = fanout.ToVector<int64_t>();
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsEType(
hg.sptr(), nodes, etypes, fanout, dir, prob, replace, etype_sorted);
hg.sptr(), nodes, etypes, fanout_vec, dir, prob, replace, etype_sorted);
*rv = HeteroSubgraphRef(subg);
});
......
......@@ -725,106 +725,75 @@ def test_sample_neighbors_biased_bipartite():
check_num(subg.edges()[1], tag)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_etype_homogeneous():
@pytest.mark.parametrize('format_', ['coo', 'csr', 'csc'])
@pytest.mark.parametrize('direction', ['in', 'out'])
@pytest.mark.parametrize('replace', [False, True])
def test_sample_neighbors_etype_homogeneous(format_, direction, replace):
num_nodes = 100
rare_cnt = 4
g = create_etype_test_graph(100, 30, rare_cnt)
h_g = dgl.to_homogeneous(g)
seed_ntype = g.get_ntype_id("u")
seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype)
def check_num(nodes, replace):
nodes = F.asnumpy(nodes)
cnt = [sum(nodes == i) for i in range(num_nodes)]
for i in range(20):
if i < rare_cnt:
if replace is False:
assert cnt[i] == 22
else:
assert cnt[i] == 30
fanouts = F.tensor([6, 5, 4, 3, 2], dtype=F.int64)
def check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction):
src, dst = subg.edges()
num_etypes = F.asnumpy(h_g.edata[dgl.ETYPE]).max()
etype_array = F.asnumpy(subg.edata[dgl.ETYPE])
src = F.asnumpy(src)
dst = F.asnumpy(dst)
fanouts = F.asnumpy(fanouts)
all_etype_array = F.asnumpy(h_g.edata[dgl.ETYPE])
all_src = F.asnumpy(all_src)
all_dst = F.asnumpy(all_dst)
src_per_etype = []
dst_per_etype = []
for etype in range(num_etypes):
src_per_etype.append(src[etype_array == etype])
dst_per_etype.append(dst[etype_array == etype])
if replace:
if direction == 'in':
in_degree_per_etype = [np.bincount(d) for d in dst_per_etype]
for in_degree, fanout in zip(in_degree_per_etype, fanouts):
assert np.all(in_degree == fanout)
else:
if replace is False:
assert cnt[i] == 12
else:
assert cnt[i] == 20
# graph with coo format
coo_g = h_g.formats('coo')
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(coo_g, seeds, dgl.ETYPE, 10, replace=False)
check_num(subg.edges()[1], False)
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(coo_g, seeds, dgl.ETYPE, 10, replace=True)
check_num(subg.edges()[1], True)
# graph with csr format
csr_g = h_g.formats('csr')
csr_g = csr_g.formats(['csr','csc','coo'])
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(csr_g, seeds, dgl.ETYPE, 10, replace=False)
check_num(subg.edges()[1], False)
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(csr_g, seeds, dgl.ETYPE, 10, replace=True)
check_num(subg.edges()[1], True)
# graph with csc format
csc_g = h_g.formats('csc')
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(csc_g, seeds, dgl.ETYPE, 10, replace=False)
check_num(subg.edges()[1], False)
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(csc_g, seeds, dgl.ETYPE, 10, replace=True)
check_num(subg.edges()[1], True)
def check_num2(nodes, replace):
nodes = F.asnumpy(nodes)
cnt = [sum(nodes == i) for i in range(num_nodes)]
for i in range(20):
if replace is False:
assert cnt[i] == 7
out_degree_per_etype = [np.bincount(s) for s in src_per_etype]
for out_degree, fanout in zip(out_degree_per_etype, fanouts):
assert np.all(out_degree == fanout)
else:
if direction == 'in':
for v in set(dst):
u = src[dst == v]
et = etype_array[dst == v]
all_u = all_src[all_dst == v]
all_et = all_etype_array[all_dst == v]
for etype in set(et):
u_etype = set(u[et == etype])
all_u_etype = set(all_u[all_et == etype])
assert (len(u_etype) == fanouts[etype]) or (u_etype == all_u_etype)
else:
assert cnt[i] == 10
# edge dir out
# graph with coo format
coo_g = h_g.formats('coo')
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(
coo_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=False)
check_num2(subg.edges()[0], False)
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(
coo_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=True)
check_num2(subg.edges()[0], True)
# graph with csr format
csr_g = h_g.formats('csr')
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(
csr_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=False)
check_num2(subg.edges()[0], False)
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(
csr_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=True)
check_num2(subg.edges()[0], True)
# graph with csc format
csc_g = h_g.formats('csc')
csc_g = csc_g.formats(['csc','csr','coo'])
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(
csc_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=False)
check_num2(subg.edges()[0], False)
for u in set(src):
v = dst[src == u]
et = etype_array[src == u]
all_v = all_dst[all_src == u]
all_et = all_etype_array[all_src == u]
for etype in set(et):
v_etype = set(v[et == etype])
all_v_etype = set(all_v[all_et == etype])
assert (len(v_etype) == fanouts[etype]) or (v_etype == all_v_etype)
all_src, all_dst = h_g.edges()
h_g = h_g.formats(format_)
if (direction, format_) in [('in', 'csr'), ('out', 'csc')]:
h_g = h_g.formats(['csc', 'csr', 'coo'])
for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors(
csc_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=True)
check_num2(subg.edges()[0], True)
h_g, seeds, dgl.ETYPE, fanouts, replace=replace, edge_dir=direction)
check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction)
@pytest.mark.parametrize('dtype', ['int32', 'int64'])
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
......@@ -920,7 +889,9 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):
if __name__ == '__main__':
test_sample_neighbors_etype_homogeneous()
from itertools import product
for args in product(['coo', 'csr', 'csc'], ['in', 'out'], [False, True]):
test_sample_neighbors_etype_homogeneous(*args)
test_random_walk()
test_pack_traces()
test_pinsage_sampling()
......
......@@ -264,11 +264,11 @@ void _TestCSRPerEtypeSampling(bool has_data) {
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -316,7 +316,7 @@ void _TestCSRPerEtypeSampling(bool has_data) {
NDArray::FromVector(
std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -339,11 +339,11 @@ void _TestCSRPerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -391,7 +391,7 @@ void _TestCSRPerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(
std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -440,12 +440,12 @@ void _TestCSRPerEtypeSamplingUniform(bool has_data) {
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -497,12 +497,12 @@ void _TestCSRPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted);
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -674,11 +674,11 @@ void _TestCOOerEtypeSampling(bool has_data) {
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -726,7 +726,7 @@ void _TestCOOerEtypeSampling(bool has_data) {
NDArray::FromVector(
std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -749,11 +749,11 @@ void _TestCOOerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -801,7 +801,7 @@ void _TestCOOerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(
std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -850,12 +850,12 @@ void _TestCOOPerEtypeSamplingUniform(bool has_data) {
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -907,12 +907,12 @@ void _TestCOOPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) {
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted);
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -1134,4 +1134,4 @@ TEST(RowwiseTest, TestCSRSamplingBiased) {
_TestCSRSamplingBiased<int32_t, double>(false);
_TestCSRSamplingBiased<int64_t, double>(true);
_TestCSRSamplingBiased<int64_t, double>(false);
}
\ No newline at end of file
}
......@@ -278,7 +278,10 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
# Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])
sampler = dgl.dataloading.MultiLayerNeighborSampler([
# test dict for hetero
{etype: 5 for etype in dist_graph.etypes} if len(dist_graph.etypes) > 1 else 5,
10]) # test int for hetero
# We need to test creating DistDataLoader multiple times.
for i in range(2):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment