"tests/python/vscode:/vscode.git/clone" did not exist on "109aed560f9382fc476fb558b1d9f75478d49457"
Unverified Commit e6eefd1a authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Bugfix] Fix the uninitialized pin_memory flag for empty graph (#5609)

parent 8ecbfa57
...@@ -64,9 +64,11 @@ struct COOMatrix { ...@@ -64,9 +64,11 @@ struct COOMatrix {
data(darr), data(darr),
row_sorted(rsorted), row_sorted(rsorted),
col_sorted(csorted) { col_sorted(csorted) {
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) && is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(col) || col.IsPinned()) && (aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned()); (aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity(); CheckValidity();
} }
...@@ -127,6 +129,11 @@ struct COOMatrix { ...@@ -127,6 +129,11 @@ struct COOMatrix {
CHECK_NO_OVERFLOW(row->dtype, num_cols); CHECK_NO_OVERFLOW(row->dtype, num_cols);
} }
inline bool IsEmpty() const {
return aten::IsNullArray(row) && aten::IsNullArray(col) &&
aten::IsNullArray(data);
}
/** @brief Return a copy of this matrix on the give device context. */ /** @brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DGLContext& ctx) const { inline COOMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == row->ctx) return *this; if (ctx == row->ctx) return *this;
...@@ -138,6 +145,7 @@ struct COOMatrix { ...@@ -138,6 +145,7 @@ struct COOMatrix {
/** @brief Return a copy of this matrix in pinned (page-locked) memory. */ /** @brief Return a copy of this matrix in pinned (page-locked) memory. */
inline COOMatrix PinMemory() { inline COOMatrix PinMemory() {
if (!IsEmpty()) {
if (is_pinned) return *this; if (is_pinned) return *this;
auto new_coo = COOMatrix( auto new_coo = COOMatrix(
num_rows, num_cols, row.PinMemory(), col.PinMemory(), num_rows, num_cols, row.PinMemory(), col.PinMemory(),
...@@ -145,10 +153,14 @@ struct COOMatrix { ...@@ -145,10 +153,14 @@ struct COOMatrix {
col_sorted); col_sorted);
CHECK(new_coo.is_pinned) CHECK(new_coo.is_pinned)
<< "An internal DGL error has occured while trying to pin a COO " << "An internal DGL error has occured while trying to pin a COO "
"matrix. Please file a bug at 'https://github.com/dmlc/dgl/issues' " "matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' "
"with the above stacktrace."; "with the above stacktrace.";
return new_coo; return new_coo;
} }
is_pinned = true;
return *this;
}
/** /**
* @brief Pin the row, col and data (if not Null) of the matrix. * @brief Pin the row, col and data (if not Null) of the matrix.
...@@ -159,6 +171,7 @@ struct COOMatrix { ...@@ -159,6 +171,7 @@ struct COOMatrix {
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
if (!IsEmpty()) {
if (is_pinned) return; if (is_pinned) return;
row.PinMemory_(); row.PinMemory_();
col.PinMemory_(); col.PinMemory_();
...@@ -167,6 +180,9 @@ struct COOMatrix { ...@@ -167,6 +180,9 @@ struct COOMatrix {
} }
is_pinned = true; is_pinned = true;
} }
is_pinned = true;
return;
}
/** /**
* @brief Unpin the row, col and data (if not Null) of the matrix. * @brief Unpin the row, col and data (if not Null) of the matrix.
...@@ -176,6 +192,7 @@ struct COOMatrix { ...@@ -176,6 +192,7 @@ struct COOMatrix {
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
inline void UnpinMemory_() { inline void UnpinMemory_() {
if (!IsEmpty()) {
if (!is_pinned) return; if (!is_pinned) return;
row.UnpinMemory_(); row.UnpinMemory_();
col.UnpinMemory_(); col.UnpinMemory_();
...@@ -184,6 +201,9 @@ struct COOMatrix { ...@@ -184,6 +201,9 @@ struct COOMatrix {
} }
is_pinned = false; is_pinned = false;
} }
is_pinned = false;
return;
}
/** /**
* @brief Record stream for the row, col and data (if not Null) of the matrix. * @brief Record stream for the row, col and data (if not Null) of the matrix.
......
...@@ -60,9 +60,11 @@ struct CSRMatrix { ...@@ -60,9 +60,11 @@ struct CSRMatrix {
indices(iarr), indices(iarr),
data(darr), data(darr),
sorted(sorted_flag) { sorted(sorted_flag) {
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) && is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(indices) || indices.IsPinned()) && (aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned()); (aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity(); CheckValidity();
} }
...@@ -121,6 +123,11 @@ struct CSRMatrix { ...@@ -121,6 +123,11 @@ struct CSRMatrix {
CHECK_EQ(indptr->shape[0], num_rows + 1); CHECK_EQ(indptr->shape[0], num_rows + 1);
} }
inline bool IsEmpty() const {
return aten::IsNullArray(indptr) && aten::IsNullArray(indices) &&
aten::IsNullArray(data);
}
/** @brief Return a copy of this matrix on the give device context. */ /** @brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DGLContext& ctx) const { inline CSRMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == indptr->ctx) return *this; if (ctx == indptr->ctx) return *this;
...@@ -131,16 +138,21 @@ struct CSRMatrix { ...@@ -131,16 +138,21 @@ struct CSRMatrix {
/** @brief Return a copy of this matrix in pinned (page-locked) memory. */ /** @brief Return a copy of this matrix in pinned (page-locked) memory. */
inline CSRMatrix PinMemory() { inline CSRMatrix PinMemory() {
if (!IsEmpty()) {
if (is_pinned) return *this; if (is_pinned) return *this;
auto new_csr = CSRMatrix( auto new_csr = CSRMatrix(
num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(), num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), sorted); aten::IsNullArray(data) ? data : data.PinMemory(), sorted);
CHECK(new_csr.is_pinned) CHECK(new_csr.is_pinned)
<< "An internal DGL error has occured while trying to pin a CSR " << "An internal DGL error has occured while trying to pin a CSR "
"matrix. Please file a bug at 'https://github.com/dmlc/dgl/issues' " "matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' "
"with the above stacktrace."; "with the above stacktrace.";
return new_csr; return new_csr;
} }
is_pinned = true;
return *this;
}
/** /**
* @brief Pin the indptr, indices and data (if not Null) of the matrix. * @brief Pin the indptr, indices and data (if not Null) of the matrix.
...@@ -151,6 +163,7 @@ struct CSRMatrix { ...@@ -151,6 +163,7 @@ struct CSRMatrix {
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
if (!IsEmpty()) {
if (is_pinned) return; if (is_pinned) return;
indptr.PinMemory_(); indptr.PinMemory_();
indices.PinMemory_(); indices.PinMemory_();
...@@ -159,6 +172,9 @@ struct CSRMatrix { ...@@ -159,6 +172,9 @@ struct CSRMatrix {
} }
is_pinned = true; is_pinned = true;
} }
is_pinned = true;
return;
}
/** /**
* @brief Unpin the indptr, indices and data (if not Null) of the matrix. * @brief Unpin the indptr, indices and data (if not Null) of the matrix.
...@@ -168,6 +184,7 @@ struct CSRMatrix { ...@@ -168,6 +184,7 @@ struct CSRMatrix {
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
inline void UnpinMemory_() { inline void UnpinMemory_() {
if (!IsEmpty()) {
if (!is_pinned) return; if (!is_pinned) return;
indptr.UnpinMemory_(); indptr.UnpinMemory_();
indices.UnpinMemory_(); indices.UnpinMemory_();
...@@ -176,6 +193,9 @@ struct CSRMatrix { ...@@ -176,6 +193,9 @@ struct CSRMatrix {
} }
is_pinned = false; is_pinned = false;
} }
is_pinned = false;
return;
}
/** /**
* @brief Record stream for the indptr, indices and data (if not Null) of the * @brief Record stream for the indptr, indices and data (if not Null) of the
......
...@@ -68,6 +68,7 @@ def test_pin_memory(idtype): ...@@ -68,6 +68,7 @@ def test_pin_memory(idtype):
# Test pinning an empty homograph # Test pinning an empty homograph
g2 = dgl.graph(([], [])) g2 = dgl.graph(([], []))
assert not g2.is_pinned()
g2._graph = g2._graph.pin_memory() g2._graph = g2._graph.pin_memory()
assert g2.is_pinned() assert g2.is_pinned()
......
...@@ -1259,6 +1259,7 @@ def test_pin_memory_(idtype): ...@@ -1259,6 +1259,7 @@ def test_pin_memory_(idtype):
# test pin empty homograph # test pin empty homograph
g2 = dgl.graph(([], [])) g2 = dgl.graph(([], []))
assert not g2.is_pinned()
g2.pin_memory_() g2.pin_memory_()
assert g2.is_pinned() assert g2.is_pinned()
g2.unpin_memory_() g2.unpin_memory_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment