Unverified Commit 1b3f14b0 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Misc.] Avoid calling `IsPinned` in the coo/csr constructor from every sampling process (#6568)

parent 2d8d6fbb
......@@ -64,11 +64,6 @@ struct COOMatrix {
data(darr),
row_sorted(rsorted),
col_sorted(csorted) {
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity();
}
......@@ -134,6 +129,15 @@ struct COOMatrix {
aten::IsNullArray(data);
}
// Check and update the internal flag is_pinned.
// This function will initialize a cuda context.
inline bool CheckIfPinnedInCUDA() {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
return is_pinned;
}
/** @brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == row->ctx) return *this;
......@@ -151,7 +155,7 @@ struct COOMatrix {
num_rows, num_cols, row.PinMemory(), col.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted,
col_sorted);
CHECK(new_coo.is_pinned)
CHECK(new_coo.CheckIfPinnedInCUDA())
<< "An internal DGL error has occured while trying to pin a COO "
"matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' "
......
......@@ -60,11 +60,6 @@ struct CSRMatrix {
indices(iarr),
data(darr),
sorted(sorted_flag) {
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity();
}
......@@ -128,6 +123,15 @@ struct CSRMatrix {
aten::IsNullArray(data);
}
// Check and update the internal flag is_pinned.
// This function will initialize a cuda context.
inline bool CheckIfPinnedInCUDA() {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
return is_pinned;
}
/** @brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == indptr->ctx) return *this;
......@@ -143,7 +147,7 @@ struct CSRMatrix {
auto new_csr = CSRMatrix(
num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), sorted);
CHECK(new_csr.is_pinned)
CHECK(new_csr.CheckIfPinnedInCUDA())
<< "An internal DGL error has occured while trying to pin a CSR "
"matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment