Unverified Commit c577dc9f authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

add sanity check (#4050)

parent 7ec165c2
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <vector> #include <vector>
#include <unordered_set> #include <unordered_set>
#include <numeric> #include <numeric>
#include <atomic>
#include "array_utils.h" #include "array_utils.h"
namespace dgl { namespace dgl {
...@@ -380,6 +381,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -380,6 +381,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
std::vector<IdType> sums; std::vector<IdType> sums;
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
bool err = false;
std::stringstream err_msg_stream;
// Perform two-round parallel prefix sum using OpenMP // Perform two-round parallel prefix sum using OpenMP
#pragma omp parallel #pragma omp parallel
{ {
...@@ -398,9 +403,17 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -398,9 +403,17 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
#pragma omp for schedule(static) nowait #pragma omp for schedule(static) nowait
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
int64_t rid = rows_data[i]; int64_t rid = rows_data[i];
if (rid >= csr.num_rows) {
if (!err_flag.test_and_set()) {
err_msg_stream << "expect row ID " << rid << " to be less than number of rows "
<< csr.num_rows;
err = true;
}
} else {
sum += indptr_data[rid + 1] - indptr_data[rid]; sum += indptr_data[rid + 1] - indptr_data[rid];
ret_indptr_data[i + 1] = sum; ret_indptr_data[i + 1] = sum;
} }
}
sums[tid + 1] = sum; sums[tid + 1] = sum;
#pragma omp barrier #pragma omp barrier
...@@ -417,6 +430,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -417,6 +430,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
for (int64_t i = 0; i < len; ++i) for (int64_t i = 0; i < len; ++i)
ret_indptr_data[i + 1] += offset; ret_indptr_data[i + 1] += offset;
} }
if (err) {
LOG(FATAL) << err_msg_stream.str();
return ret;
}
// After the prefix sum, the last element of ret_indptr_data holds the // After the prefix sum, the last element of ret_indptr_data holds the
// sum of all elements // sum of all elements
......
...@@ -427,6 +427,22 @@ def test_empty_query(idtype): ...@@ -427,6 +427,22 @@ def test_empty_query(idtype):
assert F.shape(g.in_degrees([])) == (0,) assert F.shape(g.in_degrees([])) == (0,)
assert F.shape(g.out_degrees([])) == (0,) assert F.shape(g.out_degrees([])) == (0,)
g = dgl.graph(([], []), idtype=idtype, device=F.ctx())
error_thrown = True
try:
g.in_degrees([0])
fail = False
except:
pass
assert error_thrown
error_thrown = True
try:
g.out_degrees([0])
fail = False
except:
pass
assert error_thrown
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU does not have COO impl.") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU does not have COO impl.")
def _test_hypersparse(): def _test_hypersparse():
N1 = 1 << 50 # should crash if allocated a CSR N1 = 1 << 50 # should crash if allocated a CSR
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment