Unverified Commit f4fe518f authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Feature] Add a HINT for the per edge type sampler of heterogeneous DistGraph...


[Feature] Add a hint to the per-edge-type sampler of heterogeneous DistGraph indicating that the etypes are already sorted. (#3260)

* pass cpp test

* distgraph use sorted edge flag.

* lint

* trigger

* update test
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-2-66.ec2.internal>
parent 8e525dad
......@@ -396,6 +396,7 @@ COOMatrix COORowWiseSampling(
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
* \param etype_sorted True if the edge types are already sorted, in which case the sampler skips sorting them.
* \return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array.
*/
......@@ -405,7 +406,8 @@ COOMatrix COORowWisePerEtypeSampling(
IdArray etypes,
int64_t num_samples,
FloatArray prob = FloatArray(),
bool replace = true);
bool replace = true,
bool etype_sorted = false);
/*!
* \brief Select K non-zero entries with the largest weights along each given row.
......
......@@ -423,6 +423,7 @@ COOMatrix CSRRowWiseSampling(
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
* \param etype_sorted True if the edge types are already sorted, in which case the sampler skips sorting them.
* \return A COOMatrix storing the picked row, col and data indices.
*/
COOMatrix CSRRowWisePerEtypeSampling(
......@@ -431,7 +432,8 @@ COOMatrix CSRRowWisePerEtypeSampling(
IdArray etypes,
int64_t num_samples,
FloatArray prob = FloatArray(),
bool replace = true);
bool replace = true,
bool etype_sorted = false);
/*!
* \brief Select K non-zero entries with the largest weights along each given row.
......
......@@ -81,8 +81,12 @@ def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field,
local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
local_ids = F.astype(local_ids, local_g.idtype)
# local_ids = self.seed_nodes
# Edges in a DistGraph partition are already sorted by edge type by the
# graph partitioning mechanism, so the sampler can skip its own sort.
sampled_graph = local_sample_etype_neighbors(
local_g, local_ids, etype_field, fan_out, edge_dir, prob, replace, _dist_training=True)
local_g, local_ids, etype_field, fan_out, edge_dir, prob, replace,
etype_sorted=True, _dist_training=True)
global_nid_mapping = local_g.ndata[NID]
src, dst = sampled_graph.edges()
global_src, global_dst = F.gather_row(global_nid_mapping, src), \
......
......@@ -14,7 +14,8 @@ __all__ = [
'select_topk']
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None,
replace=False, copy_ndata=True, copy_edata=True, _dist_training=False):
replace=False, copy_ndata=True, copy_edata=True, etype_sorted=False,
_dist_training=False):
"""Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
......@@ -74,6 +75,10 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
_dist_training : bool, optional
Internal argument. Do not use.
(Default: False)
etype_sorted: bool, optional
A hint telling whether the etypes are already sorted.
(Default: False)
Returns
......@@ -115,7 +120,7 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
prob_array = F.to_dgl_nd(F.tensor(prob, dtype=F.float32))
subgidx = _CAPI_DGLSampleNeighborsEType(g._graph, nodes, etypes, fanout,
edge_dir, prob_array, replace)
edge_dir, prob_array, replace, etype_sorted)
induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
......
......@@ -570,16 +570,16 @@ COOMatrix CSRRowWiseSampling(
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace) {
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted) {
COOMatrix ret;
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", {
if (IsNullArray(prob)) {
ret = impl::CSRRowWisePerEtypeSamplingUniform<XPU, IdType>(
mat, rows, etypes, num_samples, replace);
mat, rows, etypes, num_samples, replace, etype_sorted);
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
ret = impl::CSRRowWisePerEtypeSampling<XPU, IdType, FloatType>(
mat, rows, etypes, num_samples, prob, replace);
mat, rows, etypes, num_samples, prob, replace, etype_sorted);
});
}
});
......@@ -807,16 +807,16 @@ COOMatrix COORowWiseSampling(
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace) {
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", {
if (IsNullArray(prob)) {
ret = impl::COORowWisePerEtypeSamplingUniform<XPU, IdType>(
mat, rows, etypes, num_samples, replace);
mat, rows, etypes, num_samples, replace, etype_sorted);
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
ret = impl::COORowWisePerEtypeSampling<XPU, IdType, FloatType>(
mat, rows, etypes, num_samples, prob, replace);
mat, rows, etypes, num_samples, prob, replace, etype_sorted);
});
}
});
......
......@@ -168,7 +168,7 @@ COOMatrix CSRRowWiseSampling(
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace);
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted);
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
......@@ -176,7 +176,8 @@ COOMatrix CSRRowWiseSamplingUniform(
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples, bool replace);
CSRMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples,
bool replace, bool etype_sorted);
// FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename DType>
......@@ -264,7 +265,7 @@ COOMatrix COORowWiseSampling(
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace);
int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted);
template <DLDeviceType XPU, typename IdType>
COOMatrix COORowWiseSamplingUniform(
......@@ -272,7 +273,8 @@ COOMatrix COORowWiseSamplingUniform(
template <DLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples, bool replace);
COOMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples,
bool replace, bool etype_sorted);
// FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
......
......@@ -195,7 +195,8 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// OpenMP parallelization on rows because each row performs computation independently.
template <typename IdxType>
COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_picks, bool replace, RangePickFn<IdxType> pick_fn) {
int64_t num_picks, bool replace, bool etype_sorted,
RangePickFn<IdxType> pick_fn) {
using namespace aten;
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
......@@ -250,6 +251,7 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
for (int64_t j = 0; j < len; ++j) {
et[j] = data ? etype_data[data[off+j]] : etype_data[off+j];
}
if (!etype_sorted)  // sort by edge type only when the caller has not declared the edge types pre-sorted
std::sort(et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];});
......@@ -339,12 +341,13 @@ COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
// row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType>
COOMatrix COORowWisePerEtypePick(COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_picks, bool replace, RangePickFn<IdxType> pick_fn) {
int64_t num_picks, bool replace, bool etype_sorted,
RangePickFn<IdxType> pick_fn) {
using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePerEtypePick<IdxType>(
csr, new_rows, etypes, num_picks, replace, pick_fn);
csr, new_rows, etypes, num_picks, replace, etype_sorted, pick_fn);
return COOMatrix(mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col,
......
......@@ -136,20 +136,21 @@ template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>(
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace) {
int64_t num_samples, FloatArray prob,
bool replace, bool etype_sorted) {
CHECK(prob.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, pick_fn);
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, float>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, float>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, double>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, double>(
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
......@@ -165,15 +166,16 @@ template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>(
template <DLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, bool replace) {
int64_t num_samples,
bool replace, bool etype_sorted) {
auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, pick_fn);
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
CSRMatrix, IdArray, IdArray, int64_t, bool);
CSRMatrix, IdArray, IdArray, int64_t, bool, bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
CSRMatrix, IdArray, IdArray, int64_t, bool);
CSRMatrix, IdArray, IdArray, int64_t, bool, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
......@@ -223,20 +225,21 @@ template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>(
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, FloatArray prob, bool replace) {
int64_t num_samples, FloatArray prob,
bool replace, bool etype_sorted) {
CHECK(prob.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, pick_fn);
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, float>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool);
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, float>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool);
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, double>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool);
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, double>(
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool);
COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
......@@ -252,15 +255,15 @@ template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>(
template <DLDeviceType XPU, typename IdxType>
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
int64_t num_samples, bool replace) {
int64_t num_samples, bool replace, bool etype_sorted) {
auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, pick_fn);
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
COOMatrix, IdArray, IdArray, int64_t, bool);
COOMatrix, IdArray, IdArray, int64_t, bool, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
COOMatrix, IdArray, IdArray, int64_t, bool);
COOMatrix, IdArray, IdArray, int64_t, bool, bool);
} // namespace impl
} // namespace aten
......
......@@ -111,7 +111,8 @@ HeteroSubgraph SampleNeighborsEType(
const int64_t fanout,
EdgeDir dir,
const IdArray prob,
bool replace) {
bool replace,
bool etype_sorted) {
CHECK_EQ(1, hg->NumVertexTypes())
<< "SampleNeighborsEType only work with homogeneous graph";
......@@ -155,18 +156,18 @@ HeteroSubgraph SampleNeighborsEType(
nodes, etypes, fanout, prob, replace));
} else {
sampled_coo = aten::COORowWisePerEtypeSampling(
hg->GetCOOMatrix(etype), nodes, etypes, fanout, prob, replace);
hg->GetCOOMatrix(etype), nodes, etypes, fanout, prob, replace, etype_sorted);
}
break;
case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSRMatrix(etype), nodes, etypes, fanout, prob, replace);
hg->GetCSRMatrix(etype), nodes, etypes, fanout, prob, replace, etype_sorted);
break;
case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSCMatrix(etype), nodes, etypes, fanout, prob, replace);
hg->GetCSCMatrix(etype), nodes, etypes, fanout, prob, replace, etype_sorted);
sampled_coo = aten::COOTranspose(sampled_coo);
break;
default:
......@@ -360,6 +361,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
const std::string dir_str = args[4];
IdArray prob = args[5];
const bool replace = args[6];
const bool etype_sorted = args[7];
CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\".";
......@@ -367,7 +369,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsEType(
hg.sptr(), nodes, etypes, fanout, dir, prob, replace);
hg.sptr(), nodes, etypes, fanout, dir, prob, replace, etype_sorted);
*rv = HeteroSubgraphRef(subg);
});
......
......@@ -254,7 +254,6 @@ TEST(RowwiseTest, TestCSRSamplingUniform) {
_TestCSRSamplingUniform<int64_t, double>(false);
}
template <typename Idx, typename FloatType>
void _TestCSRPerEtypeSampling(bool has_data) {
auto mat = CSREtypes<Idx>(has_data);
......@@ -262,8 +261,8 @@ void _TestCSRPerEtypeSampling(bool has_data) {
std::vector<FloatType>({.5, .5, .5, .5, .5, .5, .5}));
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
IdArray etypes = has_data ?
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
......@@ -330,6 +329,81 @@ void _TestCSRPerEtypeSampling(bool has_data) {
}
}
/*!
 * \brief Check weighted per-edge-type CSR row-wise sampling with an explicit
 *        etype_sorted hint.
 *
 * Rows 0 and 3 are sampled with a fan-out of 2: first with replacement, then
 * without replacement (every edge weighted .5), and finally with replacement
 * again after some weights have been set to .0.
 *
 * \param has_data Forwarded to CSREtypes and to the result checkers; it also
 *        selects the etype array, the weights and the expected edge ids.
 * \param etype_sorted Hint forwarded unchanged to CSRRowWisePerEtypeSampling.
 */
template <typename Idx, typename FloatType>
void _TestCSRPerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
  auto mat = CSREtypes<Idx>(has_data);
  FloatArray prob = NDArray::FromVector(
      std::vector<FloatType>({.5, .5, .5, .5, .5, .5, .5}));
  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));

  IdArray etypes;
  if (has_data) {
    etypes = NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3}));
  } else {
    etypes = NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
  }

  // Third tuple component expected for the entries
  // (0,0) (0,1) (0,2) (0,3) (1,1) (3,2) (3,3), in that order.
  const std::vector<int> eid = has_data ?
      std::vector<int>({2, 3, 5, 6, 0, 1, 4}) :
      std::vector<int>({0, 1, 2, 3, 4, 5, 6});

  // Pass 1: sampling with replacement.
  for (int trial = 0; trial < 10; ++trial) {
    auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
    CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
  }

  // Pass 2: sampling without replacement; verify which edges were kept.
  for (int trial = 0; trial < 10; ++trial) {
    auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted);
    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
    auto eset = ToEdgeSet<Idx>(rst);

    // Exactly two of these three row-0 entries must be present.
    int counts = static_cast<int>(
        eset.count(std::make_tuple(0, 0, eid[0])) +
        eset.count(std::make_tuple(0, 1, eid[1])) +
        eset.count(std::make_tuple(0, 2, eid[2])));
    ASSERT_EQ(counts, 2);

    counts = static_cast<int>(eset.count(std::make_tuple(0, 3, eid[3])));
    ASSERT_EQ(counts, 1);

    // Row 1 is not among the requested rows, so its entry must be absent.
    counts = static_cast<int>(eset.count(std::make_tuple(1, 1, eid[4])));
    ASSERT_EQ(counts, 0);

    counts = static_cast<int>(eset.count(std::make_tuple(3, 2, eid[5])));
    ASSERT_EQ(counts, 1);

    counts = static_cast<int>(eset.count(std::make_tuple(3, 3, eid[6])));
    ASSERT_EQ(counts, 1);
  }

  // Pass 3: set some weights to .0; the entries asserted absent below are the
  // ones whose third tuple component indexes a .0 weight.
  if (has_data) {
    prob = NDArray::FromVector(
        std::vector<FloatType>({.0, .5, .0, .5, .5, .0, .5}));
  } else {
    prob = NDArray::FromVector(
        std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
  }
  for (int trial = 0; trial < 10; ++trial) {
    auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
    CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
    auto eset = ToEdgeSet<Idx>(rst);
    if (has_data) {
      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 2)));
      ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 5)));
    } else {
      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));
      ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2)));
    }
  }
}
TEST(RowwiseTest, TestCSRPerEtypeSampling) {
_TestCSRPerEtypeSampling<int32_t, float>(true);
_TestCSRPerEtypeSampling<int64_t, float>(true);
......@@ -339,6 +413,22 @@ TEST(RowwiseTest, TestCSRPerEtypeSampling) {
_TestCSRPerEtypeSampling<int64_t, float>(false);
_TestCSRPerEtypeSampling<int32_t, double>(false);
_TestCSRPerEtypeSampling<int64_t, double>(false);
_TestCSRPerEtypeSamplingSorted<int32_t, float>(true, true);
_TestCSRPerEtypeSamplingSorted<int64_t, float>(true, true);
_TestCSRPerEtypeSamplingSorted<int32_t, double>(true, true);
_TestCSRPerEtypeSamplingSorted<int64_t, double>(true, true);
_TestCSRPerEtypeSamplingSorted<int32_t, float>(false, true);
_TestCSRPerEtypeSamplingSorted<int64_t, float>(false, true);
_TestCSRPerEtypeSamplingSorted<int32_t, double>(false, true);
_TestCSRPerEtypeSamplingSorted<int64_t, double>(false, true);
_TestCSRPerEtypeSamplingSorted<int32_t, float>(true, false);
_TestCSRPerEtypeSamplingSorted<int64_t, float>(true, false);
_TestCSRPerEtypeSamplingSorted<int32_t, double>(true, false);
_TestCSRPerEtypeSamplingSorted<int64_t, double>(true, false);
_TestCSRPerEtypeSamplingSorted<int32_t, float>(false, false);
_TestCSRPerEtypeSamplingSorted<int64_t, float>(false, false);
_TestCSRPerEtypeSamplingSorted<int32_t, double>(false, false);
_TestCSRPerEtypeSamplingSorted<int64_t, double>(false, false);
}
template <typename Idx, typename FloatType>
......@@ -347,8 +437,8 @@ void _TestCSRPerEtypeSamplingUniform(bool has_data) {
FloatArray prob = aten::NullArray();
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
IdArray etypes = has_data ?
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
......@@ -398,6 +488,63 @@ void _TestCSRPerEtypeSamplingUniform(bool has_data) {
}
}
template <typename Idx, typename FloatType>
void _TestCSRPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) {
auto mat = CSREtypes<Idx>(has_data);
FloatArray prob = aten::NullArray();
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
IdArray etypes = has_data ?
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
int counts = 0;
counts += eset.count(std::make_tuple(0, 0, 2));
counts += eset.count(std::make_tuple(0, 1, 3));
counts += eset.count(std::make_tuple(0, 2, 5));
ASSERT_EQ(counts, 2);
counts = 0;
counts += eset.count(std::make_tuple(0, 3, 6));
ASSERT_EQ(counts, 1);
counts = 0;
counts += eset.count(std::make_tuple(1, 1, 0));
ASSERT_EQ(counts, 0);
counts = 0;
counts += eset.count(std::make_tuple(3, 2, 1));
ASSERT_EQ(counts, 1);
counts = 0;
counts += eset.count(std::make_tuple(3, 3, 4));
ASSERT_EQ(counts, 1);
} else {
int counts = 0;
counts += eset.count(std::make_tuple(0, 0, 0));
counts += eset.count(std::make_tuple(0, 1, 1));
counts += eset.count(std::make_tuple(0, 2, 2));
ASSERT_EQ(counts, 2);
counts = 0;
counts += eset.count(std::make_tuple(0, 3, 3));
ASSERT_EQ(counts, 1);
counts = 0;
counts += eset.count(std::make_tuple(1, 1, 4));
ASSERT_EQ(counts, 0);
counts = 0;
counts += eset.count(std::make_tuple(3, 2, 5));
ASSERT_EQ(counts, 1);
counts = 0;
counts += eset.count(std::make_tuple(3, 3, 6));
ASSERT_EQ(counts, 1);
}
}
}
TEST(RowwiseTest, TestCSRPerEtypeSamplingUniform) {
_TestCSRPerEtypeSamplingUniform<int32_t, float>(true);
_TestCSRPerEtypeSamplingUniform<int64_t, float>(true);
......@@ -407,6 +554,22 @@ TEST(RowwiseTest, TestCSRPerEtypeSamplingUniform) {
_TestCSRPerEtypeSamplingUniform<int64_t, float>(false);
_TestCSRPerEtypeSamplingUniform<int32_t, double>(false);
_TestCSRPerEtypeSamplingUniform<int64_t, double>(false);
_TestCSRPerEtypeSamplingUniformSorted<int32_t, float>(true, true);
_TestCSRPerEtypeSamplingUniformSorted<int64_t, float>(true, true);
_TestCSRPerEtypeSamplingUniformSorted<int32_t, double>(true, true);
_TestCSRPerEtypeSamplingUniformSorted<int64_t, double>(true, true);
_TestCSRPerEtypeSamplingUniformSorted<int32_t, float>(false, true);
_TestCSRPerEtypeSamplingUniformSorted<int64_t, float>(false, true);
_TestCSRPerEtypeSamplingUniformSorted<int32_t, double>(false, true);
_TestCSRPerEtypeSamplingUniformSorted<int64_t, double>(false, true);
_TestCSRPerEtypeSamplingUniformSorted<int32_t, float>(true, false);
_TestCSRPerEtypeSamplingUniformSorted<int64_t, float>(true, false);
_TestCSRPerEtypeSamplingUniformSorted<int32_t, double>(true, false);
_TestCSRPerEtypeSamplingUniformSorted<int64_t, double>(true, false);
_TestCSRPerEtypeSamplingUniformSorted<int32_t, float>(false, false);
_TestCSRPerEtypeSamplingUniformSorted<int64_t, float>(false, false);
_TestCSRPerEtypeSamplingUniformSorted<int32_t, double>(false, false);
_TestCSRPerEtypeSamplingUniformSorted<int64_t, double>(false, false);
}
......@@ -508,8 +671,8 @@ void _TestCOOerEtypeSampling(bool has_data) {
std::vector<FloatType>({.5, .5, .5, .5, .5, .5, .5}));
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
IdArray etypes = has_data ?
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
......@@ -576,6 +739,81 @@ void _TestCOOerEtypeSampling(bool has_data) {
}
}
/*!
 * \brief Check weighted per-edge-type COO row-wise sampling with an explicit
 *        etype_sorted hint.
 *
 * Rows 0 and 3 are sampled with a fan-out of 2: first with replacement, then
 * without replacement (every edge weighted .5), and finally with replacement
 * again after some weights have been set to .0.
 *
 * \param has_data Forwarded to COOEtypes and to the result checkers; it also
 *        selects the etype array, the weights and the expected edge ids.
 * \param etype_sorted Hint forwarded unchanged to COORowWisePerEtypeSampling.
 */
template <typename Idx, typename FloatType>
void _TestCOOerEtypeSamplingSorted(bool has_data, bool etype_sorted) {
  auto mat = COOEtypes<Idx>(has_data);
  FloatArray prob = NDArray::FromVector(
      std::vector<FloatType>({.5, .5, .5, .5, .5, .5, .5}));
  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));

  IdArray etypes;
  if (has_data) {
    etypes = NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3}));
  } else {
    etypes = NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
  }

  // Third tuple component expected for the entries
  // (0,0) (0,1) (0,2) (0,3) (1,1) (3,2) (3,3), in that order.
  const std::vector<int> eid = has_data ?
      std::vector<int>({2, 3, 5, 6, 0, 1, 4}) :
      std::vector<int>({0, 1, 2, 3, 4, 5, 6});

  // Pass 1: sampling with replacement.
  for (int trial = 0; trial < 10; ++trial) {
    auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
    CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
  }

  // Pass 2: sampling without replacement; verify which edges were kept.
  for (int trial = 0; trial < 10; ++trial) {
    auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted);
    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
    auto eset = ToEdgeSet<Idx>(rst);

    // Exactly two of these three row-0 entries must be present.
    int counts = static_cast<int>(
        eset.count(std::make_tuple(0, 0, eid[0])) +
        eset.count(std::make_tuple(0, 1, eid[1])) +
        eset.count(std::make_tuple(0, 2, eid[2])));
    ASSERT_EQ(counts, 2);

    counts = static_cast<int>(eset.count(std::make_tuple(0, 3, eid[3])));
    ASSERT_EQ(counts, 1);

    // Row 1 is not among the requested rows, so its entry must be absent.
    counts = static_cast<int>(eset.count(std::make_tuple(1, 1, eid[4])));
    ASSERT_EQ(counts, 0);

    counts = static_cast<int>(eset.count(std::make_tuple(3, 2, eid[5])));
    ASSERT_EQ(counts, 1);

    counts = static_cast<int>(eset.count(std::make_tuple(3, 3, eid[6])));
    ASSERT_EQ(counts, 1);
  }

  // Pass 3: set some weights to .0; the entries asserted absent below are the
  // ones whose third tuple component indexes a .0 weight.
  if (has_data) {
    prob = NDArray::FromVector(
        std::vector<FloatType>({.0, .5, .0, .5, .5, .0, .5}));
  } else {
    prob = NDArray::FromVector(
        std::vector<FloatType>({.0, .5, .0, .5, .0, .5, .5}));
  }
  for (int trial = 0; trial < 10; ++trial) {
    auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
    CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
    auto eset = ToEdgeSet<Idx>(rst);
    if (has_data) {
      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 2)));
      ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 5)));
    } else {
      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));
      ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2)));
    }
  }
}
TEST(RowwiseTest, TestCOOerEtypeSampling) {
_TestCOOerEtypeSampling<int32_t, float>(true);
_TestCOOerEtypeSampling<int64_t, float>(true);
......@@ -585,6 +823,22 @@ TEST(RowwiseTest, TestCOOerEtypeSampling) {
_TestCOOerEtypeSampling<int64_t, float>(false);
_TestCOOerEtypeSampling<int32_t, double>(false);
_TestCOOerEtypeSampling<int64_t, double>(false);
_TestCOOerEtypeSamplingSorted<int32_t, float>(true, true);
_TestCOOerEtypeSamplingSorted<int64_t, float>(true, true);
_TestCOOerEtypeSamplingSorted<int32_t, double>(true, true);
_TestCOOerEtypeSamplingSorted<int64_t, double>(true, true);
_TestCOOerEtypeSamplingSorted<int32_t, float>(false, true);
_TestCOOerEtypeSamplingSorted<int64_t, float>(false, true);
_TestCOOerEtypeSamplingSorted<int32_t, double>(false, true);
_TestCOOerEtypeSamplingSorted<int64_t, double>(false, true);
_TestCOOerEtypeSamplingSorted<int32_t, float>(true, false);
_TestCOOerEtypeSamplingSorted<int64_t, float>(true, false);
_TestCOOerEtypeSamplingSorted<int32_t, double>(true, false);
_TestCOOerEtypeSamplingSorted<int64_t, double>(true, false);
_TestCOOerEtypeSamplingSorted<int32_t, float>(false, false);
_TestCOOerEtypeSamplingSorted<int64_t, float>(false, false);
_TestCOOerEtypeSamplingSorted<int32_t, double>(false, false);
_TestCOOerEtypeSamplingSorted<int64_t, double>(false, false);
}
template <typename Idx, typename FloatType>
......@@ -593,8 +847,8 @@ void _TestCOOPerEtypeSamplingUniform(bool has_data) {
FloatArray prob = aten::NullArray();
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
IdArray etypes = has_data ?
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
NDArray::FromVector(std::vector<int32_t>({3, 1, 3, 3, 2, 3, 0})) :
NDArray::FromVector(std::vector<int32_t>({3, 3, 3, 0, 3, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
......@@ -644,6 +898,63 @@ void _TestCOOPerEtypeSamplingUniform(bool has_data) {
}
}
template <typename Idx, typename FloatType>
void _TestCOOPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) {
auto mat = COOEtypes<Idx>(has_data);
FloatArray prob = aten::NullArray();
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
IdArray etypes = has_data ?
NDArray::FromVector(std::vector<int32_t>({0, 1, 0, 0, 2, 0, 3})) :
NDArray::FromVector(std::vector<int32_t>({0, 0, 0, 3, 0, 1, 2}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, true, etype_sorted);
CheckSampledPerEtypeReplaceResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, 2, prob, false, etype_sorted);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
int counts = 0;
counts += eset.count(std::make_tuple(0, 0, 2));
counts += eset.count(std::make_tuple(0, 1, 3));
counts += eset.count(std::make_tuple(0, 2, 5));
ASSERT_EQ(counts, 2);
counts = 0;
counts += eset.count(std::make_tuple(0, 3, 6));
ASSERT_EQ(counts, 1);
counts = 0;
counts += eset.count(std::make_tuple(1, 1, 0));
ASSERT_EQ(counts, 0);
counts = 0;
counts += eset.count(std::make_tuple(3, 2, 1));
ASSERT_EQ(counts, 1);
counts = 0;
counts += eset.count(std::make_tuple(3, 3, 4));
ASSERT_EQ(counts, 1);
} else {
int counts = 0;
counts += eset.count(std::make_tuple(0, 0, 0));
counts += eset.count(std::make_tuple(0, 1, 1));
counts += eset.count(std::make_tuple(0, 2, 2));
ASSERT_EQ(counts, 2);
counts = 0;
counts += eset.count(std::make_tuple(0, 3, 3));
ASSERT_EQ(counts, 1);
counts = 0;
counts += eset.count(std::make_tuple(1, 1, 4));
ASSERT_EQ(counts, 0);
counts = 0;
counts += eset.count(std::make_tuple(3, 2, 5));
ASSERT_EQ(counts, 1);
counts = 0;
counts += eset.count(std::make_tuple(3, 3, 6));
ASSERT_EQ(counts, 1);
}
}
}
TEST(RowwiseTest, TestCOOPerEtypeSamplingUniform) {
_TestCOOPerEtypeSamplingUniform<int32_t, float>(true);
_TestCOOPerEtypeSamplingUniform<int64_t, float>(true);
......@@ -653,6 +964,22 @@ TEST(RowwiseTest, TestCOOPerEtypeSamplingUniform) {
_TestCOOPerEtypeSamplingUniform<int64_t, float>(false);
_TestCOOPerEtypeSamplingUniform<int32_t, double>(false);
_TestCOOPerEtypeSamplingUniform<int64_t, double>(false);
_TestCOOPerEtypeSamplingUniformSorted<int32_t, float>(true, true);
_TestCOOPerEtypeSamplingUniformSorted<int64_t, float>(true, true);
_TestCOOPerEtypeSamplingUniformSorted<int32_t, double>(true, true);
_TestCOOPerEtypeSamplingUniformSorted<int64_t, double>(true, true);
_TestCOOPerEtypeSamplingUniformSorted<int32_t, float>(false, true);
_TestCOOPerEtypeSamplingUniformSorted<int64_t, float>(false, true);
_TestCOOPerEtypeSamplingUniformSorted<int32_t, double>(false, true);
_TestCOOPerEtypeSamplingUniformSorted<int64_t, double>(false, true);
_TestCOOPerEtypeSamplingUniformSorted<int32_t, float>(true, false);
_TestCOOPerEtypeSamplingUniformSorted<int64_t, float>(true, false);
_TestCOOPerEtypeSamplingUniformSorted<int32_t, double>(true, false);
_TestCOOPerEtypeSamplingUniformSorted<int64_t, double>(true, false);
_TestCOOPerEtypeSamplingUniformSorted<int32_t, float>(false, false);
_TestCOOPerEtypeSamplingUniformSorted<int64_t, float>(false, false);
_TestCOOPerEtypeSamplingUniformSorted<int32_t, double>(false, false);
_TestCOOPerEtypeSamplingUniformSorted<int64_t, double>(false, false);
}
template <typename Idx, typename FloatType>
......
......@@ -335,6 +335,17 @@ def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3,
assert 'feat' in dist_graph.nodes['n1'].data
assert 'feat' not in dist_graph.nodes['n2'].data
assert 'feat' not in dist_graph.nodes['n3'].data
if dist_graph.local_partition is not None:
    # Check whether etypes are sorted in dist_graph: for every local node, the
    # edge types of its in-edges must form a non-decreasing sequence.
    local_g = dist_graph.local_partition
    local_nids = np.arange(local_g.num_nodes())
    for lnid in local_nids:
        leids = local_g.in_edges(lnid, form='eid')
        letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
        # Compare adjacent pairs directly. Checking only the first-occurrence
        # indices returned by np.unique would accept an unsorted sequence such
        # as [0, 1, 0]. Empty and single-element sequences pass trivially.
        assert np.all(letids[:-1] <= letids[1:])
if gpb is None:
gpb = dist_graph.get_partition_book()
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment