Unverified Commit e6eefd1a authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Bugfix] Fix the uninitialized pin_memory flag for empty graph (#5609)

parent 8ecbfa57
...@@ -64,9 +64,11 @@ struct COOMatrix { ...@@ -64,9 +64,11 @@ struct COOMatrix {
data(darr), data(darr),
row_sorted(rsorted), row_sorted(rsorted),
col_sorted(csorted) { col_sorted(csorted) {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) && if (!IsEmpty()) {
(aten::IsNullArray(col) || col.IsPinned()) && is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned()); (aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity(); CheckValidity();
} }
...@@ -127,6 +129,11 @@ struct COOMatrix { ...@@ -127,6 +129,11 @@ struct COOMatrix {
CHECK_NO_OVERFLOW(row->dtype, num_cols); CHECK_NO_OVERFLOW(row->dtype, num_cols);
} }
inline bool IsEmpty() const {
return aten::IsNullArray(row) && aten::IsNullArray(col) &&
aten::IsNullArray(data);
}
/** @brief Return a copy of this matrix on the give device context. */ /** @brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DGLContext& ctx) const { inline COOMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == row->ctx) return *this; if (ctx == row->ctx) return *this;
...@@ -138,16 +145,21 @@ struct COOMatrix { ...@@ -138,16 +145,21 @@ struct COOMatrix {
/** @brief Return a copy of this matrix in pinned (page-locked) memory. */ /** @brief Return a copy of this matrix in pinned (page-locked) memory. */
inline COOMatrix PinMemory() { inline COOMatrix PinMemory() {
if (is_pinned) return *this; if (!IsEmpty()) {
auto new_coo = COOMatrix( if (is_pinned) return *this;
num_rows, num_cols, row.PinMemory(), col.PinMemory(), auto new_coo = COOMatrix(
aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted, num_rows, num_cols, row.PinMemory(), col.PinMemory(),
col_sorted); aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted,
CHECK(new_coo.is_pinned) col_sorted);
<< "An internal DGL error has occured while trying to pin a COO " CHECK(new_coo.is_pinned)
"matrix. Please file a bug at 'https://github.com/dmlc/dgl/issues' " << "An internal DGL error has occured while trying to pin a COO "
"with the above stacktrace."; "matrix. Please file a bug at "
return new_coo; "'https://github.com/dmlc/dgl/issues' "
"with the above stacktrace.";
return new_coo;
}
is_pinned = true;
return *this;
} }
/** /**
...@@ -159,13 +171,17 @@ struct COOMatrix { ...@@ -159,13 +171,17 @@ struct COOMatrix {
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
if (is_pinned) return; if (!IsEmpty()) {
row.PinMemory_(); if (is_pinned) return;
col.PinMemory_(); row.PinMemory_();
if (!aten::IsNullArray(data)) { col.PinMemory_();
data.PinMemory_(); if (!aten::IsNullArray(data)) {
data.PinMemory_();
}
is_pinned = true;
} }
is_pinned = true; is_pinned = true;
return;
} }
/** /**
...@@ -176,13 +192,17 @@ struct COOMatrix { ...@@ -176,13 +192,17 @@ struct COOMatrix {
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
inline void UnpinMemory_() { inline void UnpinMemory_() {
if (!is_pinned) return; if (!IsEmpty()) {
row.UnpinMemory_(); if (!is_pinned) return;
col.UnpinMemory_(); row.UnpinMemory_();
if (!aten::IsNullArray(data)) { col.UnpinMemory_();
data.UnpinMemory_(); if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
}
is_pinned = false;
} }
is_pinned = false; is_pinned = false;
return;
} }
/** /**
......
...@@ -60,9 +60,11 @@ struct CSRMatrix { ...@@ -60,9 +60,11 @@ struct CSRMatrix {
indices(iarr), indices(iarr),
data(darr), data(darr),
sorted(sorted_flag) { sorted(sorted_flag) {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) && if (!IsEmpty()) {
(aten::IsNullArray(indices) || indices.IsPinned()) && is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned()); (aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity(); CheckValidity();
} }
...@@ -121,6 +123,11 @@ struct CSRMatrix { ...@@ -121,6 +123,11 @@ struct CSRMatrix {
CHECK_EQ(indptr->shape[0], num_rows + 1); CHECK_EQ(indptr->shape[0], num_rows + 1);
} }
inline bool IsEmpty() const {
return aten::IsNullArray(indptr) && aten::IsNullArray(indices) &&
aten::IsNullArray(data);
}
/** @brief Return a copy of this matrix on the give device context. */ /** @brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DGLContext& ctx) const { inline CSRMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == indptr->ctx) return *this; if (ctx == indptr->ctx) return *this;
...@@ -131,15 +138,20 @@ struct CSRMatrix { ...@@ -131,15 +138,20 @@ struct CSRMatrix {
/** @brief Return a copy of this matrix in pinned (page-locked) memory. */ /** @brief Return a copy of this matrix in pinned (page-locked) memory. */
inline CSRMatrix PinMemory() { inline CSRMatrix PinMemory() {
if (is_pinned) return *this; if (!IsEmpty()) {
auto new_csr = CSRMatrix( if (is_pinned) return *this;
num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(), auto new_csr = CSRMatrix(
aten::IsNullArray(data) ? data : data.PinMemory(), sorted); num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(),
CHECK(new_csr.is_pinned) aten::IsNullArray(data) ? data : data.PinMemory(), sorted);
<< "An internal DGL error has occured while trying to pin a CSR " CHECK(new_csr.is_pinned)
"matrix. Please file a bug at 'https://github.com/dmlc/dgl/issues' " << "An internal DGL error has occured while trying to pin a CSR "
"with the above stacktrace."; "matrix. Please file a bug at "
return new_csr; "'https://github.com/dmlc/dgl/issues' "
"with the above stacktrace.";
return new_csr;
}
is_pinned = true;
return *this;
} }
/** /**
...@@ -151,13 +163,17 @@ struct CSRMatrix { ...@@ -151,13 +163,17 @@ struct CSRMatrix {
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
if (is_pinned) return; if (!IsEmpty()) {
indptr.PinMemory_(); if (is_pinned) return;
indices.PinMemory_(); indptr.PinMemory_();
if (!aten::IsNullArray(data)) { indices.PinMemory_();
data.PinMemory_(); if (!aten::IsNullArray(data)) {
data.PinMemory_();
}
is_pinned = true;
} }
is_pinned = true; is_pinned = true;
return;
} }
/** /**
...@@ -168,13 +184,17 @@ struct CSRMatrix { ...@@ -168,13 +184,17 @@ struct CSRMatrix {
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
inline void UnpinMemory_() { inline void UnpinMemory_() {
if (!is_pinned) return; if (!IsEmpty()) {
indptr.UnpinMemory_(); if (!is_pinned) return;
indices.UnpinMemory_(); indptr.UnpinMemory_();
if (!aten::IsNullArray(data)) { indices.UnpinMemory_();
data.UnpinMemory_(); if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
}
is_pinned = false;
} }
is_pinned = false; is_pinned = false;
return;
} }
/** /**
......
...@@ -68,6 +68,7 @@ def test_pin_memory(idtype): ...@@ -68,6 +68,7 @@ def test_pin_memory(idtype):
# Test pinning an empty homograph # Test pinning an empty homograph
g2 = dgl.graph(([], [])) g2 = dgl.graph(([], []))
assert not g2.is_pinned()
g2._graph = g2._graph.pin_memory() g2._graph = g2._graph.pin_memory()
assert g2.is_pinned() assert g2.is_pinned()
......
...@@ -1259,6 +1259,7 @@ def test_pin_memory_(idtype): ...@@ -1259,6 +1259,7 @@ def test_pin_memory_(idtype):
# test pin empty homograph # test pin empty homograph
g2 = dgl.graph(([], [])) g2 = dgl.graph(([], []))
assert not g2.is_pinned()
g2.pin_memory_() g2.pin_memory_()
assert g2.is_pinned() assert g2.is_pinned()
g2.unpin_memory_() g2.unpin_memory_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment