Unverified Commit 72781efb authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Sampling] Enable sampling with edge masks on homogeneous graph (#4748)

* sample neighbors with masks

* oops

* refactor again

* remove

* remove debug code

* rename macro

* address comments

* address comment

* address comments

* rename a lot of stuff

* oops
parent 0aa8310a
......@@ -396,8 +396,9 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArr
* \param mat Input coo matrix.
* \param rows Rows to sample from.
* \param num_samples Number of samples
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param prob_or_mask Unnormalized probability array or mask array.
* Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
* \return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array.
......@@ -406,7 +407,7 @@ COOMatrix COORowWiseSampling(
COOMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = FloatArray(),
NDArray prob_or_mask = NDArray(),
bool replace = true);
/*!
......
......@@ -423,8 +423,9 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
* \param mat Input CSR matrix.
* \param rows Rows to sample from.
* \param num_samples Number of samples
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param prob_or_mask Unnormalized probability array or mask array.
* Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
* \return A COOMatrix storing the picked row, col and data indices.
*/
......@@ -432,7 +433,7 @@ COOMatrix CSRRowWiseSampling(
CSRMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = FloatArray(),
NDArray prob_or_mask = NDArray(),
bool replace = true);
/*!
......
......@@ -171,6 +171,33 @@
} \
} while (0)
/*
 * Dispatch according to data type (int8, uint8, float32 or float64):
 *
 * ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(array->dtype, DType, "array", {
 *   // Now DType is the type corresponding to data type in array.
 *   // For instance, one can do this for a CPU array:
 *   DType *data = static_cast<DType *>(array->data);
 * });
 *
 * val_name is a human-readable description of the value being dispatched on;
 * it is only used in the fatal error message for unsupported dtypes.
 */
#define ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(val, DType, val_name, ...) do { \
  if ((val).code == kDGLInt && (val).bits == 8) { \
    typedef int8_t DType; \
    {__VA_ARGS__} \
  } else if ((val).code == kDGLUInt && (val).bits == 8) { \
    typedef uint8_t DType; \
    {__VA_ARGS__} \
  } else if ((val).code == kDGLFloat && (val).bits == 32) { \
    typedef float DType; \
    {__VA_ARGS__} \
  } else if ((val).code == kDGLFloat && (val).bits == 64) { \
    typedef double DType; \
    {__VA_ARGS__} \
  } else { \
    LOG(FATAL) << (val_name) << " can only be int8, uint8, float32 or float64"; \
  } \
} while (0)
/*
* Dispatch data type only based on bit-width (8-bit, 16-bit, 32-bit, 64-bit):
*
......
......@@ -42,6 +42,7 @@ struct DGLDataTypeTraits {
static constexpr DGLDataType dtype{code, bits, 1}; \
}
GEN_DGLDATATYPETRAITS_FOR(int8_t, kDGLInt, 8);
GEN_DGLDATATYPETRAITS_FOR(uint8_t, kDGLUInt, 8);
GEN_DGLDATATYPETRAITS_FOR(int16_t, kDGLInt, 16);
GEN_DGLDATATYPETRAITS_FOR(int32_t, kDGLInt, 32);
GEN_DGLDATATYPETRAITS_FOR(int64_t, kDGLInt, 64);
......
......@@ -24,13 +24,6 @@ int64_t divup(int64_t x, int64_t y) {
namespace dgl {
namespace runtime {
namespace {
// Decide how many OpenMP threads to use for the range [begin, end) given a
// grain size.  Returns 1 when already nested inside a parallel region, when
// the range is too small to be worth splitting, or when the build has no
// OpenMP support at all (omp_* symbols are only referenced under _OPENMP so
// the translation unit still compiles/links without OpenMP).
size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) {
#ifdef _OPENMP
  if (omp_in_parallel() || end - begin <= grain_size || end - begin == 1)
    return 1;
  return std::min(static_cast<int64_t>(omp_get_max_threads()), divup(end - begin, grain_size));
#else
  return 1;
#endif
}
struct DefaultGrainSizeT {
size_t grain_size;
......@@ -50,6 +43,17 @@ struct DefaultGrainSizeT {
};
} // namespace
// Pick an OpenMP thread count for processing [begin, end) with the given
// grain size.  Degenerates to a single thread for nested parallel regions,
// tiny ranges, or builds compiled without OpenMP.
inline size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) {
#ifdef _OPENMP
  const size_t span = end - begin;
  if (omp_in_parallel() || span <= grain_size || span == 1)
    return 1;
  const int64_t max_threads = omp_get_max_threads();
  return std::min(max_threads, divup(span, grain_size));
#else
  return 1;
#endif
}
static DefaultGrainSizeT default_grain_size;
/*!
......
......@@ -545,19 +545,22 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
}
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace) {
COOMatrix ret;
if (IsNullArray(prob)) {
if (IsNullArray(prob_or_mask)) {
ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSamplingUniform", {
ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
});
} else {
// prob is pinned and rows on GPU is valid
CHECK_VALID_CONTEXT(prob, rows);
// prob_or_mask is pinned and rows on GPU is valid
CHECK_VALID_CONTEXT(prob_or_mask, rows);
ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
CHECK(!(prob_or_mask->dtype.bits == 8 && XPU == kDGLCUDA)) <<
"GPU sampling with masks is currently not supported yet.";
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
prob_or_mask->dtype, FloatType, "probability or mask", {
ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob, replace);
mat, rows, num_samples, prob_or_mask, replace);
});
});
}
......@@ -804,15 +807,16 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) {
}
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseSampling", {
if (IsNullArray(prob)) {
if (IsNullArray(prob_or_mask)) {
ret = impl::COORowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
ret = impl::COORowWiseSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob, replace);
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
prob_or_mask->dtype, DType, "probability or mask", {
ret = impl::COORowWiseSampling<XPU, IdType, DType>(
mat, rows, num_samples, prob_or_mask, replace);
});
}
});
......
......@@ -160,9 +160,9 @@ template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename FloatType>
......@@ -269,9 +269,9 @@ template <DGLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename FloatType>
......
......@@ -34,15 +34,38 @@ namespace impl {
// \param rowid The row to pick from.
// \param off Starting offset of this row.
// \param len NNZ of the row.
// \param num_picks Number of picks on the row.
// \param col Pointer of the column indices.
// \param data Pointer of the data indices.
// \param out_idx Picked indices in [off, off + len).
template <typename IdxType>
using PickFn = std::function<void(
IdxType rowid, IdxType off, IdxType len,
IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx)>;
// User-defined function for determining the number of elements to pick from one row.
//
// The column indices of the given row are stored in
// [col + off, col + off + len)
//
// Similarly, the data indices are stored in
// [data + off, data + off + len)
// Data index pointer could be NULL, which means data[i] == i
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// \param rowid The row to pick from.
// \param off Starting offset of this row.
// \param len NNZ of the row.
// \param col Pointer of the column indices.
// \param data Pointer of the data indices.
template <typename IdxType>
using NumPicksFn = std::function<IdxType(
IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data)>;
// User-defined function for picking elements from a range within a row.
//
// The column indices of each element is in
......@@ -72,7 +95,8 @@ using RangePickFn = std::function<void(
// OpenMP parallelization on rows because each row performs computation independently.
template <typename IdxType>
COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn) {
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
NumPicksFn<IdxType> num_picks_fn) {
using namespace aten;
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
......@@ -80,6 +104,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
const IdxType* rows_data = static_cast<IdxType*>(rows->data);
const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx;
const auto& idtype = mat.indptr->dtype;
// To leverage OMP parallelization, we create two arrays to store
// picked src and dst indices. Each array is of length num_rows * num_picks.
......@@ -96,22 +121,16 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
//
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more
// significant. (minjie)
IdArray picked_row = NDArray::Empty({num_rows * num_picks},
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx);
IdArray picked_col = NDArray::Empty({num_rows * num_picks},
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx);
IdArray picked_idx = NDArray::Empty({num_rows * num_picks},
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx);
IdxType* picked_rdata = static_cast<IdxType*>(picked_row->data);
IdxType* picked_cdata = static_cast<IdxType*>(picked_col->data);
IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data);
const int num_threads = omp_get_max_threads();
std::vector<int64_t> global_prefix(num_threads+1, 0);
// Do not use omp_get_max_threads() since that doesn't work for compiling without OpenMP.
const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
std::vector<int64_t> global_prefix(num_threads + 1, 0);
// TODO(BarclayII) Using OMP parallel directly instead of using runtime::parallel_for
// does not handle exceptions well (directly aborts when an exception pops up).
// It runs faster though because there is less scheduling. Need to handle
// exceptions better.
IdArray picked_row, picked_col, picked_idx;
#pragma omp parallel num_threads(num_threads)
{
const int thread_id = omp_get_thread_num();
......@@ -131,13 +150,8 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// build prefix-sum
const int64_t local_i = i-start_i;
const IdxType rid = rows_data[i];
IdxType len;
if (replace) {
len = indptr[rid+1] == indptr[rid] ? 0 : num_picks;
} else {
len = std::min(
static_cast<IdxType>(num_picks), indptr[rid + 1] - indptr[rid]);
}
IdxType len = num_picks_fn(
rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
local_prefix[local_i + 1] = local_prefix[local_i] + len;
}
global_prefix[thread_id + 1] = local_prefix[num_local];
......@@ -146,11 +160,18 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
#pragma omp master
{
for (int t = 0; t < num_threads; ++t) {
global_prefix[t+1] += global_prefix[t];
global_prefix[t + 1] += global_prefix[t];
}
picked_row = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
}
#pragma omp barrier
IdxType* picked_rdata = picked_row.Ptr<IdxType>();
IdxType* picked_cdata = picked_col.Ptr<IdxType>();
IdxType* picked_idata = picked_idx.Ptr<IdxType>();
const IdxType thread_offset = global_prefix[thread_id];
for (int64_t i = start_i; i < end_i; ++i) {
......@@ -163,35 +184,26 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
const int64_t local_i = i - start_i;
const int64_t row_offset = thread_offset + local_prefix[local_i];
if (len <= num_picks && !replace) {
// nnz <= num_picks and w/o replacement, take all nnz
for (int64_t j = 0; j < len; ++j) {
picked_rdata[row_offset + j] = rid;
picked_cdata[row_offset + j] = indices[off + j];
picked_idata[row_offset + j] = data? data[off + j] : off + j;
}
} else {
pick_fn(rid, off, len,
indices, data,
picked_idata + row_offset);
for (int64_t j = 0; j < num_picks; ++j) {
const IdxType picked = picked_idata[row_offset + j];
picked_rdata[row_offset + j] = rid;
picked_cdata[row_offset + j] = indices[picked];
picked_idata[row_offset + j] = data? data[picked] : picked;
}
const int64_t num_picks = thread_offset + local_prefix[local_i + 1] - row_offset;
pick_fn(rid, off, len, num_picks, indices, data, picked_idata + row_offset);
for (int64_t j = 0; j < num_picks; ++j) {
const IdxType picked = picked_idata[row_offset + j];
picked_rdata[row_offset + j] = rid;
picked_cdata[row_offset + j] = indices[picked];
picked_idata[row_offset + j] = data ? data[picked] : picked;
}
}
}
const int64_t new_len = global_prefix.back();
picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);
return COOMatrix(mat.num_rows, mat.num_cols,
picked_row, picked_col, picked_idx);
return COOMatrix(
mat.num_rows,
mat.num_cols,
picked_row.CreateView({new_len}, picked_row->dtype),
picked_col.CreateView({new_len}, picked_row->dtype),
picked_idx.CreateView({new_len}, picked_row->dtype));
}
// Template for picking non-zero values row-wise. The implementation utilizes
......@@ -347,11 +359,13 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
// row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType>
COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn) {
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
NumPicksFn<IdxType> num_picks_fn) {
using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePick<IdxType>(csr, new_rows, num_picks, replace, pick_fn);
const auto& picked = CSRRowWisePick<IdxType>(
csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);
return COOMatrix(mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col,
......
......@@ -27,17 +27,41 @@ inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data,
return ret;
}
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingPickFn(
int64_t num_samples, FloatArray prob, bool replace) {
PickFn<IdxType> pick_fn = [prob, num_samples, replace]
// Builds the NumPicksFn for probability/mask-based sampling: for one row it
// returns how many entries will actually be picked.
//
// \param num_samples Requested number of picks per row; -1 means "all".
// \param prob_or_mask Per-edge unnormalized probability or boolean mask,
//        indexed by edge ID.
// \param replace Whether sampling is done with replacement.
template <typename IdxType, typename DType>
inline NumPicksFn<IdxType> GetSamplingNumPicksFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [prob_or_mask, num_samples, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data) {
      const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
      const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
      // Count entries that can be picked at all (positive prob/mask).
      // prob_or_mask is indexed by edge ID, so map positions through the
      // data array when one is present — this mirrors how the pick function
      // slices prob_or_mask via DoubleSlice(prob_or_mask, data, ...).
      IdxType nnz = 0;
      for (IdxType i = off; i < off + len; ++i) {
        const IdxType eid = data ? data[i] : i;
        if (prob_or_mask_data[eid] > 0) {
          ++nnz;
        }
      }
      if (replace) {
        // With replacement any nonzero-probability row yields the full quota.
        return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
      } else {
        // Without replacement we cannot pick more than the eligible entries.
        return std::min(static_cast<IdxType>(max_num_picks), nnz);
      }
    };
  return num_picks_fn;
}
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetSamplingPickFn(
int64_t num_samples, NDArray prob_or_mask, bool replace) {
PickFn<IdxType> pick_fn = [prob_or_mask, num_samples, replace]
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
FloatArray prob_selected = DoubleSlice<IdxType, FloatType>(prob, data, off, len);
RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
num_samples, prob_selected, out_idx, replace);
for (int64_t j = 0; j < num_samples; ++j) {
NDArray prob_or_mask_selected = DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);
RandomEngine::ThreadLocal()->Choice<IdxType, DType>(
num_picks, prob_or_mask_selected, out_idx, replace);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] += off;
}
};
......@@ -67,16 +91,32 @@ inline RangePickFn<IdxType> GetSamplingRangePickFn(
return pick_fn;
}
// Builds the NumPicksFn for uniform sampling: computes, per row, the number
// of entries that will actually be emitted.
//
// \param num_samples Requested number of picks per row; -1 means "all".
// \param replace Whether sampling is done with replacement.
template <typename IdxType>
inline NumPicksFn<IdxType> GetSamplingUniformNumPicksFn(
    int64_t num_samples, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data) {
      // -1 requests the entire neighborhood.
      const int64_t quota = (num_samples == -1) ? len : num_samples;
      if (!replace)
        // Without replacement a row cannot yield more than its NNZ.
        return std::min(static_cast<IdxType>(quota), len);
      // With replacement: an empty row yields nothing, otherwise the quota.
      return static_cast<IdxType>(len == 0 ? 0 : quota);
    };
  return num_picks_fn;
}
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
int64_t num_samples, bool replace) {
PickFn<IdxType> pick_fn = [num_samples, replace]
(IdxType rowid, IdxType off, IdxType len,
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
num_samples, len, out_idx, replace);
for (int64_t j = 0; j < num_samples; ++j) {
num_picks, len, out_idx, replace);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] += off;
}
};
......@@ -96,17 +136,43 @@ inline RangePickFn<IdxType> GetSamplingUniformRangePickFn(
return pick_fn;
}
// Builds the NumPicksFn for tag-biased sampling: counts, per row, how many
// entries can actually be drawn given the per-tag bias values.
//
// \param num_samples Requested number of picks per row; -1 means "all".
// \param split Per-row tag offsets; row r's tag boundaries live at
//        split[r * split->shape[1] .. r * split->shape[1] + num_tags].
// \param bias One bias value per tag; tags with non-positive bias are
//        never drawn.
// \param replace Whether sampling is done with replacement.
template <typename IdxType, typename FloatType>
inline NumPicksFn<IdxType> GetSamplingBiasedNumPicksFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, split, bias, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data) {
      const int64_t quota = (num_samples == -1) ? len : num_samples;
      const int64_t num_tags = split->shape[1] - 1;
      const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
      const FloatType* bias_data = bias.Ptr<FloatType>();
      // Only entries whose tag carries positive bias are candidates.
      IdxType candidates = 0;
      for (int64_t t = 0; t < num_tags; ++t) {
        if (bias_data[t] > 0)
          candidates += tag_offset[t + 1] - tag_offset[t];
      }
      if (!replace)
        return std::min(static_cast<IdxType>(quota), candidates);
      // With replacement: nothing to draw only when no candidate exists.
      return static_cast<IdxType>(candidates == 0 ? 0 : quota);
    };
  return num_picks_fn;
}
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
PickFn<IdxType> pick_fn = [num_samples, split, bias, replace]
(IdxType rowid, IdxType off, IdxType len,
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
const IdxType *tag_offset = static_cast<IdxType *>(split->data) + rowid * split->shape[1];
const IdxType *tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
num_samples, tag_offset, bias, out_idx, replace);
for (int64_t j = 0; j < num_samples; ++j) {
num_picks, tag_offset, bias, out_idx, replace);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] += off;
}
};
......@@ -117,22 +183,35 @@ inline PickFn<IdxType> GetSamplingBiasedPickFn(
/////////////////////////////// CSR ///////////////////////////////
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) {
CHECK(prob.defined());
auto pick_fn = GetSamplingPickFn<IdxType, FloatType>(num_samples, prob, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
NDArray prob_or_mask, bool replace) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
CHECK(prob_or_mask.defined());
auto num_picks_fn = GetSamplingNumPicksFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
auto pick_fn = GetSamplingPickFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, int8_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, int8_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
......@@ -155,8 +234,11 @@ template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
auto num_picks_fn = GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
......@@ -186,9 +268,13 @@ COOMatrix CSRRowWiseSamplingBiased(
FloatArray bias,
bool replace
) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
auto num_picks_fn = GetSamplingBiasedNumPicksFn<IdxType, FloatType>(
num_samples, tag_offset, bias, replace);
auto pick_fn = GetSamplingBiasedPickFn<IdxType, FloatType>(
num_samples, tag_offset, bias, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
......@@ -206,22 +292,35 @@ template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
/////////////////////////////// COO ///////////////////////////////
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) {
CHECK(prob.defined());
auto pick_fn = GetSamplingPickFn<IdxType, FloatType>(num_samples, prob, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
NDArray prob_or_mask, bool replace) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
CHECK(prob_or_mask.defined());
auto num_picks_fn = GetSamplingNumPicksFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
auto pick_fn = GetSamplingPickFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, int8_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, int8_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, uint8_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
......@@ -244,8 +343,11 @@ template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
auto num_picks_fn = GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
......
......@@ -12,11 +12,22 @@ namespace aten {
namespace impl {
namespace {
// Builds the NumPicksFn for top-k selection: each row emits min(k, nnz)
// entries; k == -1 requests the whole row.
template <typename IdxType>
inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
  NumPicksFn<IdxType> num_picks_fn = [k]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data) {
      const int64_t quota = (k == -1) ? len : k;
      // Never take more than the row actually contains.
      return std::min(len, static_cast<IdxType>(quota));
    };
  return num_picks_fn;
}
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetTopkPickFn(int64_t k, NDArray weight, bool ascending) {
inline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) {
const DType* wdata = static_cast<DType*>(weight->data);
PickFn<IdxType> pick_fn = [k, ascending, wdata]
(IdxType rowid, IdxType off, IdxType len,
PickFn<IdxType> pick_fn = [ascending, wdata]
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
std::function<bool(IdxType, IdxType)> compare_fn;
......@@ -45,7 +56,7 @@ inline PickFn<IdxType> GetTopkPickFn(int64_t k, NDArray weight, bool ascending)
std::vector<IdxType> idx(len);
std::iota(idx.begin(), idx.end(), off);
std::sort(idx.begin(), idx.end(), compare_fn);
for (int64_t j = 0; j < k; ++j) {
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] = idx[j];
}
};
......@@ -58,8 +69,9 @@ inline PickFn<IdxType> GetTopkPickFn(int64_t k, NDArray weight, bool ascending)
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
return CSRRowWisePick(mat, rows, k, false, pick_fn);
auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);
auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);
return CSRRowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);
}
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int32_t>(
......@@ -82,8 +94,9 @@ template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, double>(
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
return COORowWisePick(mat, rows, k, false, pick_fn);
auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);
auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);
return COORowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);
}
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int32_t>(
......
......@@ -235,16 +235,14 @@ __global__ void _CSRRowWiseSampleUniformReplaceKernel(
out_row += 1;
}
}
} // namespace
} // namespace
///////////////////////////// CSR sampling //////////////////////////
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
IdArray rows,
const int64_t num_picks,
const bool replace) {
COOMatrix _CSRRowWiseSamplingUniform(
CSRMatrix mat, IdArray rows, const int64_t num_picks, const bool replace) {
const auto& ctx = rows->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
......@@ -369,6 +367,19 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
picked_col, picked_idx);
}
// GPU uniform row-wise sampling entry point.
//
// \param mat Input CSR matrix.
// \param rows Rows to sample from.
// \param num_picks Number of picks per row; -1 selects every neighbor.
// \param replace True to sample with replacement (ignored when
//        num_picks == -1, since the full neighborhood is returned).
// \return COOMatrix of the picked (row, col, data) triples.
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, const int64_t num_picks, const bool replace) {
  // num_picks == -1 needs no randomness: slicing the requested rows and
  // converting to COO yields the complete edge set of those rows.
  if (num_picks == -1) {
    // Basically this is UnitGraph::InEdges().
    COOMatrix coo = CSRToCOO(CSRSliceRows(mat, rows), false);
    // coo.row indexes into `rows`; map back to the original row IDs.
    IdArray sliced_rows = IndexSelect(rows, coo.row);
    return COOMatrix(mat.num_rows, mat.num_cols, sliced_rows, coo.col, coo.data);
  } else {
    return _CSRRowWiseSamplingUniform<XPU, IdType>(mat, rows, num_picks, replace);
  }
}
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int32_t>(
CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int64_t>(
......
......@@ -12,6 +12,7 @@
#include <numeric>
#include "./dgl_cub.cuh"
#include "./utils.h"
#include "../../array/cuda/atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
......@@ -392,6 +393,74 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
}
}
// CUDA kernel writing one flag per element: flag is true when the value at
// the (optionally idx-indirected) position differs from `criteria`.  The
// flags feed the MaskSelect compaction in COOGeneralRemoveIf.
//
// \param n Number of elements to flag.
// \param idx Optional indirection array (e.g. COO edge IDs); may be NULL,
//        in which case element i reads values[i] directly.
// \param values Values to compare against `criteria`.
// \param criteria Elements equal to this value get a false (drop) flag.
// \param output One flag per element.
template <typename IdType, typename DType, typename BoolType>
__global__ void _GenerateFlagsKernel(
    int64_t n, const IdType* idx, const DType* values, DType criteria, BoolType* output) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  // Grid-stride loop over all n elements.
  while (tx < n) {
    output[tx] = (values[idx ? idx[tx] : tx] != criteria);
    tx += stride_x;
  }
}
// Generic remove-if on a COO matrix: `maskgen` launches a kernel that fills
// a per-edge int8 flag array (nonzero = keep), then the surviving
// row/col/eid triples are compacted with MaskSelect.
//
// \param coo Input COO matrix.
// \param maskgen Callable (nb, nt, stream, nnz, data, flags) that writes one
//        flag per edge into `flags`.
// \return A COOMatrix containing only the edges whose flag is nonzero.
template <DGLDeviceType XPU, typename IdType, typename DType, typename MaskGen>
COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
  using namespace dgl::cuda;

  const auto idtype = coo.row->dtype;
  const auto ctx = coo.row->ctx;
  const int64_t nnz = coo.row->shape[0];
  const IdType* row = coo.row.Ptr<IdType>();
  const IdType* col = coo.col.Ptr<IdType>();
  // When the COO has no explicit data array, edge IDs are the identity
  // range [0, nnz).
  const IdArray& eid = COOHasData(coo) ? coo.data : Range(
      0, nnz, sizeof(IdType) * 8, ctx);
  // Read edge IDs through eid, NOT coo.data: coo.data is a null array when
  // COOHasData(coo) is false, which would hand a null pointer to maskgen
  // and to the MaskSelect on the data array below.
  const IdType* data = eid.Ptr<IdType>();

  IdArray new_row = IdArray::Empty({nnz}, idtype, ctx);
  IdArray new_col = IdArray::Empty({nnz}, idtype, ctx);
  IdArray new_eid = IdArray::Empty({nnz}, idtype, ctx);
  IdType* new_row_data = new_row.Ptr<IdType>();
  IdType* new_col_data = new_col.Ptr<IdType>();
  IdType* new_eid_data = new_eid.Ptr<IdType>();
  auto stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(ctx);

  // One int8 flag per edge: nonzero means the edge survives.
  int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));
  int nt = dgl::cuda::FindNumThreads(nnz);
  int64_t nb = (nnz + nt - 1) / nt;

  maskgen(nb, nt, stream, nnz, data, flags);

  // rst receives the number of selected entries (identical for all three
  // MaskSelect calls since they share the same flags).
  int64_t* rst = static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
  MaskSelect(device, ctx, row, flags, new_row_data, nnz, rst, stream);
  MaskSelect(device, ctx, col, flags, new_col_data, nnz, rst, stream);
  MaskSelect(device, ctx, data, flags, new_eid_data, nnz, rst, stream);

  int64_t new_len = GetCUDAScalar(device, ctx, rst);

  device->FreeWorkspace(ctx, flags);
  device->FreeWorkspace(ctx, rst);
  return COOMatrix(
      coo.num_rows,
      coo.num_cols,
      new_row.CreateView({new_len}, idtype, 0),
      new_col.CreateView({new_len}, idtype, 0),
      new_eid.CreateView({new_len}, idtype, 0));
}
/*!
 * \brief Remove every COO entry whose associated value equals \a criteria.
 */
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix _COORemoveIf(const COOMatrix& coo, const NDArray& values, DType criteria) {
  const DType* value_ptr = values.Ptr<DType>();
  // Flag generator: an entry is kept iff its value differs from criteria.
  // `data` may be nullptr, in which case the kernel uses the identity index.
  auto generate_flags = [value_ptr, criteria] (
      int nb, int nt, cudaStream_t stream, int64_t nnz, const IdType* data,
      int8_t* flags) {
    CUDA_KERNEL_CALL((_GenerateFlagsKernel<IdType, DType, int8_t>),
        nb, nt, 0, stream,
        nnz, data, value_ptr, criteria, flags);
  };
  return COOGeneralRemoveIf<XPU, IdType, DType, decltype(generate_flags)>(
      coo, generate_flags);
}
} // namespace
/////////////////////////////// CSR ///////////////////////////////
......@@ -417,11 +486,12 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat,
IdArray rows,
int64_t num_picks,
FloatArray prob,
bool replace) {
COOMatrix _CSRRowWiseSampling(
const CSRMatrix& mat,
const IdArray& rows,
int64_t num_picks,
const FloatArray& prob,
bool replace) {
const auto& ctx = rows->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
......@@ -647,8 +717,24 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);
return COOMatrix(mat.num_rows, mat.num_cols, picked_row,
picked_col, picked_idx);
return COOMatrix(mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
}
/*!
 * \brief Non-uniform row-wise sampling dispatcher; handles the "take all
 *        neighbors" fanout (-1) and filters out zero-probability entries.
 */
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob, bool replace) {
  COOMatrix sampled;
  if (num_picks != -1) {
    sampled = _CSRRowWiseSampling<XPU, IdType, DType>(mat, rows, num_picks, prob, replace);
  } else {
    // A fanout of -1 keeps every neighbor; basically UnitGraph::InEdges().
    const COOMatrix coo = CSRToCOO(CSRSliceRows(mat, rows), false);
    const IdArray orig_rows = IndexSelect(rows, coo.row);
    sampled = COOMatrix(mat.num_rows, mat.num_cols, orig_rows, coo.col, coo.data);
  }
  // NOTE(BarclayII): I'm removing the entries with zero probability after sampling.
  // Is there a better way?
  return _COORemoveIf<XPU, IdType, DType>(sampled, prob, static_cast<DType>(0));
}
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, float>(
......@@ -659,6 +745,16 @@ template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
// These are not being called, but we instantiate them anyway to prevent missing
// symbols in Debug build
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, int8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, int8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, uint8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, uint8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
} // namespace impl
} // namespace aten
......
......@@ -11,6 +11,7 @@
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/ndarray.h>
#include "../../runtime/cuda/cuda_common.h"
#include "dgl_cub.cuh"
namespace dgl {
namespace cuda {
......@@ -252,6 +253,20 @@ __device__ IdType _BinarySearch(const IdType *A, int64_t n, IdType x) {
return n; // not found
}
/*!
 * \brief Stream-compact \a input into \a output, keeping the entries whose
 *        corresponding \a mask value is nonzero, via
 *        cub::DeviceSelect::Flagged.
 *
 * \param device Device API used for workspace allocation.
 * \param ctx Device context of the arrays.
 * \param input Input array of length \a n.
 * \param mask Per-entry keep flags of length \a n.
 * \param output Output buffer; must be able to hold up to \a n entries.
 * \param n Number of input entries.
 * \param rst Device pointer receiving the number of selected entries.
 * \param stream CUDA stream to run on.
 */
template <typename DType, typename BoolType>
void MaskSelect(
    runtime::DeviceAPI* device, const DGLContext& ctx,
    const DType* input, const BoolType* mask, DType* output, int64_t n,
    int64_t* rst, cudaStream_t stream) {
  // Standard two-phase CUB call: the first invocation (null workspace) only
  // queries the required temporary storage size, the second does the work.
  size_t workspace_size = 0;
  CUDA_CALL(cub::DeviceSelect::Flagged(
      nullptr, workspace_size, input, mask, output, rst, n, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUDA_CALL(cub::DeviceSelect::Flagged(
      workspace, workspace_size, input, mask, output, rst, n, stream));
  device->FreeWorkspace(ctx, workspace);
}
} // namespace cuda
} // namespace dgl
......
......@@ -67,7 +67,7 @@ HeteroSubgraph SampleNeighbors(
const std::vector<IdArray>& nodes,
const std::vector<int64_t>& fanouts,
EdgeDir dir,
const std::vector<FloatArray>& prob,
const std::vector<NDArray>& prob_or_mask,
const std::vector<IdArray>& exclude_edges,
bool replace) {
......@@ -76,8 +76,8 @@ HeteroSubgraph SampleNeighbors(
<< "Number of node ID tensors must match the number of node types.";
CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
<< "Number of fanout values must match the number of edge types.";
CHECK_EQ(prob.size(), hg->NumEdgeTypes())
<< "Number of probability tensors must match the number of edge types.";
CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())
<< "Number of prob_or_maskability tensors must match the number of edge types.";
DGLContext ctx = aten::GetContextOf(nodes);
......@@ -89,6 +89,7 @@ HeteroSubgraph SampleNeighbors(
const dgl_type_t dst_vtype = pair.second;
const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype];
const int64_t num_nodes = nodes_ntype->shape[0];
if (num_nodes == 0 || fanouts[etype] == 0) {
// Nothing to sample for this etype, create a placeholder relation graph
subrels[etype] = UnitGraph::Empty(
......@@ -97,47 +98,37 @@ HeteroSubgraph SampleNeighbors(
hg->NumVertices(dst_vtype),
hg->DataType(), ctx);
induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);
} else if (fanouts[etype] == -1) {
const auto &earr = (dir == EdgeDir::kOut) ?
hg->OutEdges(etype, nodes_ntype) :
hg->InEdges(etype, nodes_ntype);
subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype),
earr.src,
earr.dst);
induced_edges[etype] = earr.id;
} else {
COOMatrix sampled_coo;
// sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt);
COOMatrix sampled_coo;
switch (avail_fmt) {
case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWiseSampling(
aten::COOTranspose(hg->GetCOOMatrix(etype)),
nodes_ntype, fanouts[etype], prob[etype], replace));
nodes_ntype, fanouts[etype], prob_or_mask[etype], replace));
} else {
sampled_coo = aten::COORowWiseSampling(
hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype], prob[etype], replace);
hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype], prob_or_mask[etype], replace);
}
break;
case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWiseSampling(
hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype], prob[etype], replace);
hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype], prob_or_mask[etype], replace);
break;
case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWiseSampling(
hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype], prob[etype], replace);
hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype], prob_or_mask[etype], replace);
sampled_coo = aten::COOTranspose(sampled_coo);
break;
default:
LOG(FATAL) << "Unsupported sparse format.";
}
subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols,
sampled_coo.row, sampled_coo.col);
......@@ -279,17 +270,6 @@ HeteroSubgraph SampleNeighborsTopk(
hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context());
induced_edges[etype] = aten::NullArray();
} else if (k[etype] == -1) {
const auto &earr = (dir == EdgeDir::kOut) ?
hg->OutEdges(etype, nodes_ntype) :
hg->InEdges(etype, nodes_ntype);
subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype),
earr.src,
earr.dst);
induced_edges[etype] = earr.id;
} else {
// sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
......@@ -368,17 +348,6 @@ HeteroSubgraph SampleNeighborsBiased(
hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context());
induced_edges = aten::NullArray();
} else if (fanout == -1) {
const auto &earr = (dir == EdgeDir::kOut) ?
hg->OutEdges(etype, nodes_ntype) :
hg->InEdges(etype, nodes_ntype);
subrel = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype),
earr.src,
earr.dst);
induced_edges = earr.id;
} else {
// sample from one relation graph
const auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
......@@ -444,7 +413,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
IdArray fanouts_array = args[2];
const auto& fanouts = fanouts_array.ToVector<int64_t>();
const std::string dir_str = args[3];
const auto& prob = ListValueToVector<FloatArray>(args[4]);
const auto& prob_or_mask = ListValueToVector<NDArray>(args[4]);
const auto& exclude_edges = ListValueToVector<IdArray>(args[5]);
const bool replace = args[6];
......@@ -454,7 +423,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighbors(
hg.sptr(), nodes, fanouts, dir, prob, exclude_edges, replace);
hg.sptr(), nodes, fanouts, dir, prob_or_mask, exclude_edges, replace);
*rv = HeteroSubgraphRef(subg);
});
......
......@@ -55,10 +55,22 @@ template void RandomEngine::Choice<int32_t, double>(int32_t num,
template void RandomEngine::Choice<int64_t, double>(int64_t num,
FloatArray prob,
int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, int8_t>(int32_t num, FloatArray prob,
int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, int8_t>(int64_t num, FloatArray prob,
int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, uint8_t>(int32_t num,
FloatArray prob,
int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, uint8_t>(int64_t num,
FloatArray prob,
int64_t* out, bool replace);
template <typename IdxType>
void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
bool replace) {
CHECK_GE(num, 0) << "The numbers to sample should be non-negative.";
CHECK_GE(population, 0) << "The population size should be non-negative.";
if (!replace)
CHECK_LE(num, population)
<< "Cannot take more sample than population when 'replace=false'";
......
......@@ -32,6 +32,11 @@ class BaseSampler {
}
};
// (BarclayII 2022.9.20) Changing the internal data type of probabilities to double since
// we are using non-uniform sampling to sample on boolean masks, where False represents
// probability 0. DType could be uint8 in this case, which will give incorrect arithmetic
// results due to overflowing and/or integer division.
/*
* AliasSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: Alias Method(https://en.wikipedia.org/wiki/Alias_method)
......@@ -47,9 +52,9 @@ class AliasSampler: public BaseSampler<Idx> {
private:
RandomEngine *re;
Idx N;
DType accum, taken; // accumulated likelihood
double accum, taken; // accumulated likelihood
std::vector<Idx> K; // alias table
std::vector<DType> U; // probability table
std::vector<double> U; // probability table
FloatArray _prob; // category distribution
std::vector<bool> used; // indicate availability, activated when replace=false;
std::vector<Idx> id_mapping; // index mapping, activated when replace=false;
......@@ -63,7 +68,7 @@ class AliasSampler: public BaseSampler<Idx> {
void Reconstruct(FloatArray prob) { // Reconstruct alias table
const int64_t prob_size = prob->shape[0];
const DType *prob_data = static_cast<DType *>(prob->data);
const DType *prob_data = prob.Ptr<DType>();
N = 0;
accum = 0.;
taken = 0.;
......@@ -79,11 +84,11 @@ class AliasSampler: public BaseSampler<Idx> {
if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'";
K.resize(N);
U.resize(N);
DType avg = accum / static_cast<DType>(N);
double avg = accum / static_cast<double>(N);
std::fill(U.begin(), U.end(), avg); // initialize U
std::queue<std::pair<Idx, DType> > under, over;
std::queue<std::pair<Idx, double> > under, over;
for (Idx i = 0; i < N; ++i) {
DType p = prob_data[Map(i)];
double p = prob_data[Map(i)];
if (p > avg)
over.push(std::make_pair(i, p));
else
......@@ -93,7 +98,7 @@ class AliasSampler: public BaseSampler<Idx> {
while (!under.empty() && !over.empty()) {
auto u_pair = under.front(), o_pair = over.front();
Idx i_u = u_pair.first, i_o = o_pair.first;
DType p_u = u_pair.second, p_o = o_pair.second;
double p_u = u_pair.second, p_o = o_pair.second;
K[i_u] = i_o;
U[i_u] = p_u;
if (p_o + p_u > 2 * avg)
......@@ -121,21 +126,24 @@ class AliasSampler: public BaseSampler<Idx> {
~AliasSampler() {}
Idx Draw() {
DType avg = accum / N;
if (!replace) {
const DType *_prob_data = static_cast<DType *>(_prob->data);
const DType *_prob_data = _prob.Ptr<DType>();
if (2 * taken >= accum)
Reconstruct(_prob);
if (accum <= 0)
return -1;
// accum changes after Reconstruct(), so avg should be computed after that.
double avg = accum / N;
while (true) {
DType dice = re->Uniform<DType>(0, N);
double dice = re->Uniform<double>(0, N);
Idx i = static_cast<Idx>(dice), rst;
DType p = (dice - i) * avg;
if (p <= U[Map(i)]) {
double p = (dice - i) * avg;
if (p <= U[i]) {
rst = Map(i);
} else {
rst = Map(K[i]);
}
DType cap = _prob_data[rst];
double cap = _prob_data[rst];
if (!used[rst]) {
used[rst] = true;
taken += cap;
......@@ -143,10 +151,13 @@ class AliasSampler: public BaseSampler<Idx> {
}
}
}
DType dice = re->Uniform<DType>(0, N);
if (accum <= 0)
return -1;
double avg = accum / N;
double dice = re->Uniform<double>(0, N);
Idx i = static_cast<Idx>(dice);
DType p = (dice - i) * avg;
if (p <= U[Map(i)])
double p = (dice - i) * avg;
if (p <= U[i])
return Map(i);
else
return Map(K[i]);
......@@ -169,9 +180,9 @@ class CDFSampler: public BaseSampler<Idx> {
private:
RandomEngine *re;
Idx N;
DType accum, taken;
double accum, taken;
FloatArray _prob; // categorical distribution
std::vector<DType> cdf; // cumulative distribution function
std::vector<double> cdf; // cumulative distribution function
std::vector<bool> used; // indicate availability, activated when replace=false;
std::vector<Idx> id_mapping; // indicate index mapping, activated when replace=false;
......@@ -184,7 +195,7 @@ class CDFSampler: public BaseSampler<Idx> {
void Reconstruct(FloatArray prob) { // Reconstruct CDF
int64_t prob_size = prob->shape[0];
const DType *prob_data = static_cast<DType *>(prob->data);
const DType *prob_data = prob.Ptr<DType>();
N = 0;
accum = 0.;
taken = 0.;
......@@ -219,15 +230,17 @@ class CDFSampler: public BaseSampler<Idx> {
~CDFSampler() {}
Idx Draw() {
DType eps = std::numeric_limits<DType>::min();
double eps = std::numeric_limits<double>::min();
if (!replace) {
const DType *_prob_data = static_cast<DType *>(_prob->data);
const DType *_prob_data = _prob.Ptr<DType>();
if (2 * taken >= accum)
Reconstruct(_prob);
if (accum <= 0)
return -1;
while (true) {
DType p = std::max(re->Uniform<DType>(0., accum), eps);
double p = std::max(re->Uniform<double>(0., accum), eps);
Idx rst = Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
DType cap = _prob_data[rst];
double cap = static_cast<double>(_prob_data[rst]);
if (!used[rst]) {
used[rst] = true;
taken += cap;
......@@ -235,7 +248,9 @@ class CDFSampler: public BaseSampler<Idx> {
}
}
}
DType p = std::max(re->Uniform<DType>(0., accum), eps);
if (accum <= 0)
return -1;
double p = std::max(re->Uniform<double>(0., accum), eps);
return Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
}
};
......@@ -255,7 +270,7 @@ template <
class TreeSampler: public BaseSampler<Idx> {
private:
RandomEngine *re;
std::vector<DType> weight; // accumulated likelihood of subtrees.
std::vector<double> weight; // accumulated likelihood of subtrees.
int64_t N;
int64_t num_leafs;
const DType *decrease;
......@@ -263,7 +278,7 @@ class TreeSampler: public BaseSampler<Idx> {
public:
void ResetState(FloatArray prob) {
int64_t prob_size = prob->shape[0];
const DType *prob_data = static_cast<DType *>(prob->data);
const DType *prob_data = prob.Ptr<DType>();
std::fill(weight.begin(), weight.end(), 0);
for (int64_t i = 0; i < prob_size; ++i)
weight[num_leafs + i] = prob_data[i];
......@@ -293,13 +308,15 @@ class TreeSampler: public BaseSampler<Idx> {
*
*/
Idx Draw() {
if (weight[1] <= 0)
return -1;
int64_t cur = 1;
DType p = re->Uniform<DType>(0, weight[cur]);
DType accum = 0.;
double p = re->Uniform<double>(0, weight[cur]);
double accum = 0.;
while (cur < num_leafs) {
DType w_l = weight[cur * 2], w_r = weight[cur * 2 + 1];
DType pivot = accum + w_l;
// w_r > 0 can depress some numerical problems.
double w_l = weight[cur * 2], w_r = weight[cur * 2 + 1];
double pivot = accum + w_l;
// w_r > 0 can suppress some numerical problems.
Idx shift = static_cast<Idx>(p > pivot && w_r > 0);
cur = cur * 2 + shift;
if (shift == 1)
......@@ -309,7 +326,8 @@ class TreeSampler: public BaseSampler<Idx> {
if (!replace) {
while (cur >= 1) {
if (cur >= num_leafs)
weight[cur] = this->decrease ? weight[cur] - this->decrease[rst] : 0.;
weight[cur] = this->decrease ?
weight[cur] - static_cast<double>(this->decrease[rst]) : 0.;
else
weight[cur] = weight[cur * 2] + weight[cur * 2 + 1];
cur /= 2;
......
......@@ -16,6 +16,7 @@
namespace dgl {
constexpr DGLDataType DGLDataTypeTraits<int8_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<uint8_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<int16_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<int32_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<int64_t>::dtype;
......
......@@ -272,6 +272,7 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
}, {'user': card if card is not None else 4})
g = g.to(F.ctx())
g.edata['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32)
g.edata['mask'] = F.tensor([True, True, False, True, True, False, True])
hg = dgl.heterograph({
('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2],
[1, 2, 3, 0, 2, 3, 0]),
......@@ -286,6 +287,7 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
}, {'user': card if card is not None else 4})
g = g.to(F.ctx())
g.edata['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32)
g.edata['mask'] = F.tensor([True, True, False, True, True, False, True])
hg = dgl.heterograph({
('user', 'follow', 'user'): ([1, 2, 3, 0, 2, 3, 0],
[0, 0, 0, 1, 1, 1, 2]),
......@@ -295,7 +297,9 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
}, num_nodes_dict)
hg = hg.to(F.ctx())
hg.edges['follow'].data['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32)
hg.edges['follow'].data['mask'] = F.tensor([True, True, False, True, True, False, True])
hg.edges['play'].data['prob'] = F.tensor([.8, .5, .5, .5], dtype=F.float32)
# Leave out the mask of play and liked-by since all of them are True anyway.
hg.edges['liked-by'].data['prob'] = F.tensor([.3, .5, .2, .5, .1, .1], dtype=F.float32)
return g, hg
......@@ -344,7 +348,13 @@ def _test_sample_neighbors(hypersparse, prob):
subg = dgl.sampling.sample_neighbors(g, [0, 1], -1, prob=p, replace=replace)
assert subg.number_of_nodes() == g.number_of_nodes()
u, v = subg.edges()
u_ans, v_ans = subg.in_edges([0, 1])
u_ans, v_ans, e_ans = g.in_edges([0, 1], form='all')
if p is not None:
emask = F.gather_row(g.edata[p], e_ans)
if p == 'prob':
emask = (emask != 0)
u_ans = F.boolean_mask(u_ans, emask)
v_ans = F.boolean_mask(v_ans, emask)
uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
assert uv == uv_ans
......@@ -371,7 +381,13 @@ def _test_sample_neighbors(hypersparse, prob):
subg = dgl.sampling.sample_neighbors(g, [0, 2], -1, prob=p, replace=replace)
assert subg.number_of_nodes() == g.number_of_nodes()
u, v = subg.edges()
u_ans, v_ans = subg.in_edges([0, 2])
u_ans, v_ans, e_ans = g.in_edges([0, 2], form='all')
if p is not None:
emask = F.gather_row(g.edata[p], e_ans)
if p == 'prob':
emask = (emask != 0)
u_ans = F.boolean_mask(u_ans, emask)
v_ans = F.boolean_mask(v_ans, emask)
uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
assert uv == uv_ans
......@@ -398,7 +414,7 @@ def _test_sample_neighbors(hypersparse, prob):
subg = dgl.sampling.sample_neighbors(hg, {'user': [0, 1], 'game': 0}, -1, prob=p, replace=replace)
assert len(subg.ntypes) == 3
assert len(subg.etypes) == 4
assert subg['follow'].number_of_edges() == 6
assert subg['follow'].number_of_edges() == 6 if p is None else 4
assert subg['play'].number_of_edges() == 1
assert subg['liked-by'].number_of_edges() == 4
assert subg['flips'].number_of_edges() == 0
......@@ -436,7 +452,13 @@ def _test_sample_neighbors_outedge(hypersparse):
subg = dgl.sampling.sample_neighbors(g, [0, 1], -1, prob=p, replace=replace, edge_dir='out')
assert subg.number_of_nodes() == g.number_of_nodes()
u, v = subg.edges()
u_ans, v_ans = subg.out_edges([0, 1])
u_ans, v_ans, e_ans = g.out_edges([0, 1], form='all')
if p is not None:
emask = F.gather_row(g.edata[p], e_ans)
if p == 'prob':
emask = (emask != 0)
u_ans = F.boolean_mask(u_ans, emask)
v_ans = F.boolean_mask(v_ans, emask)
uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
assert uv == uv_ans
......@@ -465,7 +487,13 @@ def _test_sample_neighbors_outedge(hypersparse):
subg = dgl.sampling.sample_neighbors(g, [0, 2], -1, prob=p, replace=replace, edge_dir='out')
assert subg.number_of_nodes() == g.number_of_nodes()
u, v = subg.edges()
u_ans, v_ans = subg.out_edges([0, 2])
u_ans, v_ans, e_ans = g.out_edges([0, 2], form='all')
if p is not None:
emask = F.gather_row(g.edata[p], e_ans)
if p == 'prob':
emask = (emask != 0)
u_ans = F.boolean_mask(u_ans, emask)
v_ans = F.boolean_mask(v_ans, emask)
uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
assert uv == uv_ans
......@@ -494,7 +522,7 @@ def _test_sample_neighbors_outedge(hypersparse):
subg = dgl.sampling.sample_neighbors(hg, {'user': [0, 1], 'game': 0}, -1, prob=p, replace=replace, edge_dir='out')
assert len(subg.ntypes) == 3
assert len(subg.etypes) == 4
assert subg['follow'].number_of_edges() == 6
assert subg['follow'].number_of_edges() == 6 if p is None else 4
assert subg['play'].number_of_edges() == 1
assert subg['liked-by'].number_of_edges() == 4
assert subg['flips'].number_of_edges() == 0
......@@ -651,6 +679,11 @@ def test_sample_neighbors_outedge():
_test_sample_neighbors_outedge(False)
#_test_sample_neighbors_outedge(True)
@unittest.skipIf(F.backend_name == 'mxnet', reason='MXNet has problem converting bool arrays')
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors with mask not implemented")
def test_sample_neighbors_mask():
    # Exercise neighbor sampling with a boolean edge mask ('mask') instead of
    # a float probability tensor, on the non-hypersparse test graph.
    _test_sample_neighbors(False, 'mask')
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_topk():
    # Top-k neighbor sampling on the non-hypersparse test graph (CPU only).
    _test_sample_neighbors_topk(False)
......@@ -1026,12 +1059,15 @@ def test_global_uniform_negative_sampling(dtype):
if __name__ == '__main__':
from itertools import product
test_sample_neighbors_noprob()
test_sample_neighbors_prob()
test_sample_neighbors_mask()
for args in product(['coo', 'csr', 'csc'], ['in', 'out'], [False, True]):
test_sample_neighbors_etype_homogeneous(*args)
test_non_uniform_random_walk()
test_non_uniform_random_walk(False)
test_uniform_random_walk(False)
test_pack_traces()
test_pinsage_sampling()
test_pinsage_sampling(False)
test_sample_neighbors_outedge()
test_sample_neighbors_topk()
test_sample_neighbors_topk_outedge()
......
......@@ -48,7 +48,7 @@ def check_sort(spm, tag_arr=None, tag_pos=None):
# tag value is equal to `tag_pos_ptr`
return False
if tag_arr[dst[j]] > tag_arr[dst[j + 1]]:
# The tag should be in descending order after sorting
# The tag should be in ascending order after sorting
return False
if tag_pos is not None and tag_arr[dst[j]] < tag_arr[dst[j + 1]]:
if j + 1 != int(tag_pos_row[tag_pos_ptr + 1]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment