Unverified Commit 1b3f14b0 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Misc.] Avoid calling `IsPinned` in the coo/csr constructor from every sampling process (#6568)

parent 2d8d6fbb
...@@ -64,11 +64,6 @@ struct COOMatrix { ...@@ -64,11 +64,6 @@ struct COOMatrix {
data(darr), data(darr),
row_sorted(rsorted), row_sorted(rsorted),
col_sorted(csorted) { col_sorted(csorted) {
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity(); CheckValidity();
} }
...@@ -134,6 +129,15 @@ struct COOMatrix { ...@@ -134,6 +129,15 @@ struct COOMatrix {
aten::IsNullArray(data); aten::IsNullArray(data);
} }
// Check and update the internal flag is_pinned.
// This function will initialize a cuda context.
inline bool CheckIfPinnedInCUDA() {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
return is_pinned;
}
/** @brief Return a copy of this matrix on the give device context. */ /** @brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DGLContext& ctx) const { inline COOMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == row->ctx) return *this; if (ctx == row->ctx) return *this;
...@@ -151,7 +155,7 @@ struct COOMatrix { ...@@ -151,7 +155,7 @@ struct COOMatrix {
num_rows, num_cols, row.PinMemory(), col.PinMemory(), num_rows, num_cols, row.PinMemory(), col.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted, aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted,
col_sorted); col_sorted);
CHECK(new_coo.is_pinned) CHECK(new_coo.CheckIfPinnedInCUDA())
<< "An internal DGL error has occured while trying to pin a COO " << "An internal DGL error has occured while trying to pin a COO "
"matrix. Please file a bug at " "matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' " "'https://github.com/dmlc/dgl/issues' "
......
...@@ -60,11 +60,6 @@ struct CSRMatrix { ...@@ -60,11 +60,6 @@ struct CSRMatrix {
indices(iarr), indices(iarr),
data(darr), data(darr),
sorted(sorted_flag) { sorted(sorted_flag) {
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity(); CheckValidity();
} }
...@@ -128,6 +123,15 @@ struct CSRMatrix { ...@@ -128,6 +123,15 @@ struct CSRMatrix {
aten::IsNullArray(data); aten::IsNullArray(data);
} }
// Check and update the internal flag is_pinned.
// This function will initialize a cuda context.
inline bool CheckIfPinnedInCUDA() {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
return is_pinned;
}
/** @brief Return a copy of this matrix on the give device context. */ /** @brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DGLContext& ctx) const { inline CSRMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == indptr->ctx) return *this; if (ctx == indptr->ctx) return *this;
...@@ -143,7 +147,7 @@ struct CSRMatrix { ...@@ -143,7 +147,7 @@ struct CSRMatrix {
auto new_csr = CSRMatrix( auto new_csr = CSRMatrix(
num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(), num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), sorted); aten::IsNullArray(data) ? data : data.PinMemory(), sorted);
CHECK(new_csr.is_pinned) CHECK(new_csr.CheckIfPinnedInCUDA())
<< "An internal DGL error has occured while trying to pin a CSR " << "An internal DGL error has occured while trying to pin a CSR "
"matrix. Please file a bug at " "matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' " "'https://github.com/dmlc/dgl/issues' "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment