Unverified commit 3685000a, authored by Xin Yao and committed by GitHub

[Bugfix] Fix pinning empty tensors and graphs (#4393)

parent 49c81795
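
For context, a sketch that is not part of the commit: DGL pins a graph in place by registering each of its backing arrays with CUDA, so a graph with zero nodes or zero edges used to hand cudaHostRegister an empty range and fail. Assuming a CUDA-enabled DGL build with the PyTorch backend, the calls below are what this fix makes safe:

    import dgl

    # An empty homogeneous graph: zero nodes, zero edges, empty index arrays.
    g2 = dgl.graph(([], []))
    g2.pin_memory_()        # now a no-op per empty array instead of a CUDA error
    assert g2.is_pinned()
    g2.unpin_memory_()
    assert not g2.is_pinned()

    # A heterograph where one relation type has no edges: only that
    # relation's arrays are empty; the others are pinned normally.
    g3 = dgl.heterograph({
        ('a', 'b', 'c'): ([0, 1], [1, 2]),
        ('c', 'd', 'c'): ([], []),
    })
    g3.pin_memory_()
    assert g3.is_pinned()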
@@ -197,10 +197,15 @@ class CUDADeviceAPI final : public DeviceAPI {
    * not just the one that performed the allocation
    */
   void PinData(void* ptr, size_t nbytes) {
+    // prevent users from pinning empty tensors or graphs
+    if (ptr == nullptr || nbytes == 0)
+      return;
     CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
   }
 
   void UnpinData(void* ptr) {
+    if (ptr == nullptr)
+      return;
     CUDA_CALL(cudaHostUnregister(ptr));
   }
...
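
The early returns above are needed because cudaHostRegister and cudaHostUnregister reject null pointers and empty ranges, and CUDA_CALL escalates the resulting error to a fatal one. A minimal tensor-level sketch of the path that reaches PinData, again assuming a CUDA-enabled build with the PyTorch backend:

    import dgl
    import torch

    # Non-empty tensor: its storage is registered via cudaHostRegister.
    t = torch.arange(8)
    nd = dgl.utils.pin_memory_inplace(t)  # keep the returned handle to unpin later
    assert t.is_pinned()

    # Empty tensor: PinData sees nbytes == 0 and returns early, so no
    # CUDA call is made and the tensor stays unpinned.
    e = torch.tensor([])
    nd_e = dgl.utils.pin_memory_inplace(e)
    assert not e.is_pinned()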
@@ -1008,52 +1008,70 @@ def test_pin_memory_(idtype):
     g = g.to(F.cpu())
     assert not g.is_pinned()
-    if F.is_cuda_available():
-        # unpin an unpinned CPU graph, directly return
-        g.unpin_memory_()
-        assert not g.is_pinned()
-        assert g.device == F.cpu()
-        # pin a CPU graph
-        g.pin_memory_()
-        assert g.is_pinned()
-        assert g.device == F.cpu()
-        assert F.context(g.nodes['user'].data['h']) == F.cpu()
-        assert F.context(g.nodes['game'].data['i']) == F.cpu()
-        assert F.context(g.edges['plays'].data['e']) == F.cpu()
-        for ntype in g.ntypes:
-            assert F.context(g.batch_num_nodes(ntype)) == F.cpu()
-        for etype in g.canonical_etypes:
-            assert F.context(g.batch_num_edges(etype)) == F.cpu()
-        # it's fine to clone with new formats, but new graphs are not pinned
-        # >>> g.formats()
-        # {'created': ['coo'], 'not created': ['csr', 'csc']}
-        assert not g.formats('csc').is_pinned()
-        assert not g.formats('csr').is_pinned()
-        # 'coo' formats is already created and thus not cloned
-        assert g.formats('coo').is_pinned()
-        # pin a pinned graph, directly return
-        g.pin_memory_()
-        assert g.is_pinned()
-        assert g.device == F.cpu()
-        # unpin a pinned graph
-        g.unpin_memory_()
-        assert not g.is_pinned()
-        assert g.device == F.cpu()
-        g1 = g.to(F.cuda())
-        # unpin an unpinned GPU graph, directly return
-        g1.unpin_memory_()
-        assert not g1.is_pinned()
-        assert g1.device == F.cuda()
-        # error pinning a GPU graph
-        with pytest.raises(DGLError):
-            g1.pin_memory_()
+    # unpin an unpinned CPU graph, directly return
+    g.unpin_memory_()
+    assert not g.is_pinned()
+    assert g.device == F.cpu()
+    # pin a CPU graph
+    g.pin_memory_()
+    assert g.is_pinned()
+    assert g.device == F.cpu()
+    assert g.nodes['user'].data['h'].is_pinned()
+    assert g.nodes['game'].data['i'].is_pinned()
+    assert g.edges['plays'].data['e'].is_pinned()
+    assert F.context(g.nodes['user'].data['h']) == F.cpu()
+    assert F.context(g.nodes['game'].data['i']) == F.cpu()
+    assert F.context(g.edges['plays'].data['e']) == F.cpu()
+    for ntype in g.ntypes:
+        assert F.context(g.batch_num_nodes(ntype)) == F.cpu()
+    for etype in g.canonical_etypes:
+        assert F.context(g.batch_num_edges(etype)) == F.cpu()
+    # it's fine to clone with new formats, but new graphs are not pinned
+    # >>> g.formats()
+    # {'created': ['coo'], 'not created': ['csr', 'csc']}
+    assert not g.formats('csc').is_pinned()
+    assert not g.formats('csr').is_pinned()
+    # 'coo' formats is already created and thus not cloned
+    assert g.formats('coo').is_pinned()
+    # pin a pinned graph, directly return
+    g.pin_memory_()
+    assert g.is_pinned()
+    assert g.device == F.cpu()
+    # unpin a pinned graph
+    g.unpin_memory_()
+    assert not g.is_pinned()
+    assert g.device == F.cpu()
+    g1 = g.to(F.cuda())
+    # unpin an unpinned GPU graph, directly return
+    g1.unpin_memory_()
+    assert not g1.is_pinned()
+    assert g1.device == F.cuda()
+    # error pinning a GPU graph
+    with pytest.raises(DGLError):
+        g1.pin_memory_()
+
+    # test pin empty homograph
+    g2 = dgl.graph(([], []))
+    g2.pin_memory_()
+    assert g2.is_pinned()
+    g2.unpin_memory_()
+    assert not g2.is_pinned()
+
+    # test pin heterograph with 0 edge of one relation type
+    g3 = dgl.heterograph({
+        ('a', 'b', 'c'): ([0, 1], [1, 2]),
+        ('c', 'd', 'c'): ([], [])}).astype(idtype)
+    g3.pin_memory_()
+    assert g3.is_pinned()
+    g3.unpin_memory_()
+    assert not g3.is_pinned()
 
 
 @parametrize_idtype
 def test_convert_bound(idtype):
...
@@ -25,7 +25,7 @@ def test_pin_unpin():
         F.to_dgl_nd(t_pin).unpin_memory_()
     else:
         with pytest.raises(dgl.DGLError):
-            # tensorflow and mxnet should throw an erro
+            # tensorflow and mxnet should throw an error
             dgl.utils.pin_memory_inplace(t)
 
 
 if __name__ == "__main__":
...
@@ -57,6 +57,17 @@ def test_pin_unpin_column():
     assert col._data_nd is None
     assert not g.ndata['x'].is_pinned()
 
 
+@pytest.mark.skipif(F._default_context_str == 'cpu', reason='Need gpu for this test.')
+def test_pin_empty():
+    t = torch.tensor([])
+    assert not t.is_pinned()
+    # Empty tensors will not be pinned or unpinned. It's a no-op.
+    # This is also the default behavior in PyTorch.
+    # We just check that it won't raise an error.
+    nd = dgl.utils.pin_memory_inplace(t)
+    assert not t.is_pinned()
+
+
 if __name__ == "__main__":
     test_pin_noncontiguous()
     test_pin_view()
...