Unverified Commit e6eefd1a authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Bugfix] Fix the uninitialized pin_memory flag for empty graph (#5609)

parent 8ecbfa57
......@@ -64,9 +64,11 @@ struct COOMatrix {
data(darr),
row_sorted(rsorted),
col_sorted(csorted) {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity();
}
......@@ -127,6 +129,11 @@ struct COOMatrix {
CHECK_NO_OVERFLOW(row->dtype, num_cols);
}
inline bool IsEmpty() const {
return aten::IsNullArray(row) && aten::IsNullArray(col) &&
aten::IsNullArray(data);
}
/** @brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == row->ctx) return *this;
......@@ -138,16 +145,21 @@ struct COOMatrix {
/** @brief Return a copy of this matrix in pinned (page-locked) memory. */
inline COOMatrix PinMemory() {
if (is_pinned) return *this;
auto new_coo = COOMatrix(
num_rows, num_cols, row.PinMemory(), col.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted,
col_sorted);
CHECK(new_coo.is_pinned)
<< "An internal DGL error has occured while trying to pin a COO "
"matrix. Please file a bug at 'https://github.com/dmlc/dgl/issues' "
"with the above stacktrace.";
return new_coo;
if (!IsEmpty()) {
if (is_pinned) return *this;
auto new_coo = COOMatrix(
num_rows, num_cols, row.PinMemory(), col.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted,
col_sorted);
CHECK(new_coo.is_pinned)
<< "An internal DGL error has occured while trying to pin a COO "
"matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' "
"with the above stacktrace.";
return new_coo;
}
is_pinned = true;
return *this;
}
/**
......@@ -159,13 +171,17 @@ struct COOMatrix {
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
if (is_pinned) return;
row.PinMemory_();
col.PinMemory_();
if (!aten::IsNullArray(data)) {
data.PinMemory_();
if (!IsEmpty()) {
if (is_pinned) return;
row.PinMemory_();
col.PinMemory_();
if (!aten::IsNullArray(data)) {
data.PinMemory_();
}
is_pinned = true;
}
is_pinned = true;
return;
}
/**
......@@ -176,13 +192,17 @@ struct COOMatrix {
* The context check is deferred to unpinning the NDArray.
*/
inline void UnpinMemory_() {
if (!is_pinned) return;
row.UnpinMemory_();
col.UnpinMemory_();
if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
if (!IsEmpty()) {
if (!is_pinned) return;
row.UnpinMemory_();
col.UnpinMemory_();
if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
}
is_pinned = false;
}
is_pinned = false;
return;
}
/**
......
......@@ -60,9 +60,11 @@ struct CSRMatrix {
indices(iarr),
data(darr),
sorted(sorted_flag) {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity();
}
......@@ -121,6 +123,11 @@ struct CSRMatrix {
CHECK_EQ(indptr->shape[0], num_rows + 1);
}
inline bool IsEmpty() const {
return aten::IsNullArray(indptr) && aten::IsNullArray(indices) &&
aten::IsNullArray(data);
}
/** @brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == indptr->ctx) return *this;
......@@ -131,15 +138,20 @@ struct CSRMatrix {
/** @brief Return a copy of this matrix in pinned (page-locked) memory. */
inline CSRMatrix PinMemory() {
if (is_pinned) return *this;
auto new_csr = CSRMatrix(
num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), sorted);
CHECK(new_csr.is_pinned)
<< "An internal DGL error has occured while trying to pin a CSR "
"matrix. Please file a bug at 'https://github.com/dmlc/dgl/issues' "
"with the above stacktrace.";
return new_csr;
if (!IsEmpty()) {
if (is_pinned) return *this;
auto new_csr = CSRMatrix(
num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), sorted);
CHECK(new_csr.is_pinned)
<< "An internal DGL error has occured while trying to pin a CSR "
"matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' "
"with the above stacktrace.";
return new_csr;
}
is_pinned = true;
return *this;
}
/**
......@@ -151,13 +163,17 @@ struct CSRMatrix {
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
if (is_pinned) return;
indptr.PinMemory_();
indices.PinMemory_();
if (!aten::IsNullArray(data)) {
data.PinMemory_();
if (!IsEmpty()) {
if (is_pinned) return;
indptr.PinMemory_();
indices.PinMemory_();
if (!aten::IsNullArray(data)) {
data.PinMemory_();
}
is_pinned = true;
}
is_pinned = true;
return;
}
/**
......@@ -168,13 +184,17 @@ struct CSRMatrix {
* The context check is deferred to unpinning the NDArray.
*/
inline void UnpinMemory_() {
if (!is_pinned) return;
indptr.UnpinMemory_();
indices.UnpinMemory_();
if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
if (!IsEmpty()) {
if (!is_pinned) return;
indptr.UnpinMemory_();
indices.UnpinMemory_();
if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
}
is_pinned = false;
}
is_pinned = false;
return;
}
/**
......
......@@ -68,6 +68,7 @@ def test_pin_memory(idtype):
# Test pinning an empty homograph
g2 = dgl.graph(([], []))
assert not g2.is_pinned()
g2._graph = g2._graph.pin_memory()
assert g2.is_pinned()
......
......@@ -1259,6 +1259,7 @@ def test_pin_memory_(idtype):
# test pin empty homograph
g2 = dgl.graph(([], []))
assert not g2.is_pinned()
g2.pin_memory_()
assert g2.is_pinned()
g2.unpin_memory_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment