Unverified Commit 3685000a authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Bugfix] Fix pinning empty tensors and graphs (#4393)

parent 49c81795
...@@ -197,10 +197,15 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -197,10 +197,15 @@ class CUDADeviceAPI final : public DeviceAPI {
* not just the one that performed the allocation * not just the one that performed the allocation
*/ */
void PinData(void* ptr, size_t nbytes) { void PinData(void* ptr, size_t nbytes) {
// prevent users from pinning empty tensors or graphs
if (ptr == nullptr || nbytes == 0)
return;
CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault)); CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
} }
void UnpinData(void* ptr) { void UnpinData(void* ptr) {
if (ptr == nullptr)
return;
CUDA_CALL(cudaHostUnregister(ptr)); CUDA_CALL(cudaHostUnregister(ptr));
} }
......
...@@ -1008,7 +1008,6 @@ def test_pin_memory_(idtype): ...@@ -1008,7 +1008,6 @@ def test_pin_memory_(idtype):
g = g.to(F.cpu()) g = g.to(F.cpu())
assert not g.is_pinned() assert not g.is_pinned()
if F.is_cuda_available():
# unpin an unpinned CPU graph, directly return # unpin an unpinned CPU graph, directly return
g.unpin_memory_() g.unpin_memory_()
assert not g.is_pinned() assert not g.is_pinned()
...@@ -1018,6 +1017,9 @@ def test_pin_memory_(idtype): ...@@ -1018,6 +1017,9 @@ def test_pin_memory_(idtype):
g.pin_memory_() g.pin_memory_()
assert g.is_pinned() assert g.is_pinned()
assert g.device == F.cpu() assert g.device == F.cpu()
assert g.nodes['user'].data['h'].is_pinned()
assert g.nodes['game'].data['i'].is_pinned()
assert g.edges['plays'].data['e'].is_pinned()
assert F.context(g.nodes['user'].data['h']) == F.cpu() assert F.context(g.nodes['user'].data['h']) == F.cpu()
assert F.context(g.nodes['game'].data['i']) == F.cpu() assert F.context(g.nodes['game'].data['i']) == F.cpu()
assert F.context(g.edges['plays'].data['e']) == F.cpu() assert F.context(g.edges['plays'].data['e']) == F.cpu()
...@@ -1055,6 +1057,22 @@ def test_pin_memory_(idtype): ...@@ -1055,6 +1057,22 @@ def test_pin_memory_(idtype):
with pytest.raises(DGLError): with pytest.raises(DGLError):
g1.pin_memory_() g1.pin_memory_()
# test pin empty homograph
g2 = dgl.graph(([], []))
g2.pin_memory_()
assert g2.is_pinned()
g2.unpin_memory_()
assert not g2.is_pinned()
# test pin heterograph with 0 edge of one relation type
g3 = dgl.heterograph({
('a','b','c'): ([0, 1], [1, 2]),
('c','d','c'): ([], [])}).astype(idtype)
g3.pin_memory_()
assert g3.is_pinned()
g3.unpin_memory_()
assert not g3.is_pinned()
@parametrize_idtype @parametrize_idtype
def test_convert_bound(idtype): def test_convert_bound(idtype):
def _test_bipartite_bound(data, card): def _test_bipartite_bound(data, card):
......
...@@ -25,7 +25,7 @@ def test_pin_unpin(): ...@@ -25,7 +25,7 @@ def test_pin_unpin():
F.to_dgl_nd(t_pin).unpin_memory_() F.to_dgl_nd(t_pin).unpin_memory_()
else: else:
with pytest.raises(dgl.DGLError): with pytest.raises(dgl.DGLError):
# tensorflow and mxnet should throw an erro # tensorflow and mxnet should throw an error
dgl.utils.pin_memory_inplace(t) dgl.utils.pin_memory_inplace(t)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -57,6 +57,17 @@ def test_pin_unpin_column(): ...@@ -57,6 +57,17 @@ def test_pin_unpin_column():
assert col._data_nd is None assert col._data_nd is None
assert not g.ndata['x'].is_pinned() assert not g.ndata['x'].is_pinned()
@pytest.mark.skipif(F._default_context_str == 'cpu', reason='Need gpu for this test.')
def test_pin_empty():
t = torch.tensor([])
assert not t.is_pinned()
# Empty tensors will not be pinned or unpinned. It's a no-op.
# This is also the default behavior in PyTorch.
# We just check that it won't raise an error.
nd = dgl.utils.pin_memory_inplace(t)
assert not t.is_pinned()
if __name__ == "__main__": if __name__ == "__main__":
test_pin_noncontiguous() test_pin_noncontiguous()
test_pin_view() test_pin_view()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment