Unverified Commit 077e002f authored by Xin Yao, committed by GitHub

[Bugfix][Rework] Automatically unpin tensors pinned by DGL (rework #3997) (#4135)



* Explicitly unpin tensoradapter allocated arrays

* Undo unrelated change

* Add unit test

* update unit test

* add pinned_by_dgl flag to NDArray::Container

* use dgl.ndarray for holding the pinning status

* update multi-gpu uva inference

* reinterpret cast NDArray::Container* to DLTensor* in MoveAsDLTensor

* update unpin column and examples

* add unit test for unpin column
Co-authored-by: Dominique LaSalle <dlasalle@nvidia.com>
Co-authored-by: nv-dlasalle <63612878+nv-dlasalle@users.noreply.github.com>
parent 1ad65879
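
A minimal usage sketch of the reworked API (PyTorch backend, CUDA build required; tensor names are illustrative): pin_memory_inplace now returns a dgl.ndarray handle, and the tensor stays pinned only while that handle is alive.

    import torch
    import dgl

    t = torch.ones(10000, 10)             # CPU tensor to be pinned in-place
    nd = dgl.utils.pin_memory_inplace(t)  # pins t and returns the handle holding the pin status
    assert t.is_pinned()
    del nd                                # dropping the last reference unpins t automatically
    assert not t.is_pinned()
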
......@@ -6,7 +6,6 @@ import torch.distributed.optim
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
from dgl.utils import pin_memory_inplace, unpin_memory_inplace
from dgl.multiprocessing import shared_tensor
import time
import numpy as np
......@@ -64,18 +63,16 @@ class SAGE(nn.Module):
"""
g.ndata['h'] = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
dataloader = dgl.dataloading.DataLoader(
for l, layer in enumerate(self.layers):
dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes(), device=device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False,
num_workers=0, use_ddp=True, use_uva=True)
for l, layer in enumerate(self.layers):
# in order to prevent running out of GPU memory, we allocate a
# shared output tensor 'y' in host memory, pin it to allow UVA
# access from each GPU during forward propagation.
# shared output tensor 'y' in host memory
y = shared_tensor(
(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes))
pin_memory_inplace(y)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader) \
if dist.get_rank() == 0 else dataloader:
......@@ -84,8 +81,6 @@ class SAGE(nn.Module):
y[output_nodes] = h.to(y.device)
# make sure all GPUs are done writing to 'y'
dist.barrier()
if l > 0:
unpin_memory_inplace(g.ndata['h'])
if l + 1 < len(self.layers):
# assign the output features of this layer as the new input
# features for the next layer
......
......@@ -172,8 +172,8 @@ class NDArray {
*/
inline NDArray Clone(const DGLStreamHandle &stream = nullptr) const;
/*!
* \brief In-place method to pin the current array by calling PinData
* on the underlying DLTensor.
* \brief In-place method to pin the current array by calling PinContainer
* on the underlying NDArray::Container.
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
......@@ -181,8 +181,8 @@ class NDArray {
*/
inline void PinMemory_();
/*!
* \brief In-place method to unpin the current array by calling UnpinData
* on the underlying DLTensor.
* \brief In-place method to unpin the current array by calling UnpinContainer
* on the underlying NDArray::Container.
* \note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
......@@ -294,32 +294,32 @@ class NDArray {
DLTensor* from, DLTensor* to, DGLStreamHandle stream = nullptr);
/*!
* \brief Function to pin the data of a DLTensor.
* \param tensor The array to be pinned.
* \brief Function to pin the DLTensor of a Container.
* \param ptr The container to be pinned.
* \note Data of the given array will be pinned inplace.
* Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
*/
DGL_DLL static void PinData(DLTensor* tensor);
DGL_DLL static void PinContainer(Container* ptr);
/*!
* \brief Function to unpin the data of a DLTensor.
* \param tensor The array to be unpinned.
* \brief Function to unpin the DLTensor of a Container.
* \param ptr The container to be unpinned.
* \note Data of the given array will be unpinned inplace.
* Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
*/
DGL_DLL static void UnpinData(DLTensor* tensor);
DGL_DLL static void UnpinContainer(Container* ptr);
/*!
* \brief Function check if the data of a DLTensor is pinned.
* \param tensor The array to be checked.
* \brief Function to check if the DLTensor of a Container is pinned.
* \param ptr The container to be checked.
* \return true if pinned.
*/
DGL_DLL static bool IsDataPinned(DLTensor* tensor);
DGL_DLL static bool IsContainerPinned(Container* ptr);
// internal namespace
struct Internal;
......@@ -361,8 +361,6 @@ struct NDArray::Container {
* The head ptr of this struct can be viewed as DLTensor*.
*/
DLTensor dl_tensor;
std::shared_ptr<SharedMemory> mem;
/*!
* \brief additional context, reserved for recycling
* \note We can attach additional content here
......@@ -386,6 +384,8 @@ struct NDArray::Container {
dl_tensor.strides = nullptr;
dl_tensor.byte_offset = 0;
}
/*! \brief pointer to shared memory */
std::shared_ptr<SharedMemory> mem;
/*! \brief developer function, increases reference counter */
void IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed);
......@@ -415,6 +415,8 @@ struct NDArray::Container {
std::vector<int64_t> stride_;
/*! \brief The internal array object */
std::atomic<int> ref_counter_{0};
bool pinned_by_dgl_{false};
};
// implementations of inline functions
......@@ -482,17 +484,17 @@ inline NDArray NDArray::Clone(const DGLStreamHandle &stream) const {
inline void NDArray::PinMemory_() {
CHECK(data_ != nullptr);
PinData(&(data_->dl_tensor));
PinContainer(data_);
}
inline void NDArray::UnpinMemory_() {
CHECK(data_ != nullptr);
UnpinData(&(data_->dl_tensor));
UnpinContainer(data_);
}
inline bool NDArray::IsPinned() const {
CHECK(data_ != nullptr);
return IsDataPinned(&(data_->dl_tensor));
return IsContainerPinned(data_);
}
inline int NDArray::use_count() const {
......
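
At the Python level, PinMemory_()/UnpinMemory_() surface through dgl.ndarray's pin_memory_()/unpin_memory_() methods; a sketch (PyTorch backend, CUDA build required) of driving the container-level pinning directly through the backend's zero-copy wrapper:

    import torch
    import dgl.backend as F

    t = torch.ones(100, 10)                      # CPU tensor
    nd = F.zerocopy_to_dgl_ndarray_for_write(t)  # shares the same memory as t
    nd.pin_memory_()                             # NDArray::PinMemory_ -> PinContainer
    assert t.is_pinned()
    nd.unpin_memory_()                           # NDArray::UnpinMemory_ -> UnpinContainer
    assert not t.is_pinned()
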
......@@ -8,7 +8,7 @@ from . import backend as F
from .base import DGLError, dgl_warning
from .init import zero_initializer
from .storages import TensorStorage
from .utils import gather_pinned_tensor_rows, pin_memory_inplace, unpin_memory_inplace
from .utils import gather_pinned_tensor_rows, pin_memory_inplace
class _LazyIndex(object):
def __init__(self, index):
......@@ -189,6 +189,7 @@ class Column(TensorStorage):
self.device = device
self.deferred_dtype = deferred_dtype
self.pinned_by_dgl = False
self._data_nd = None
def __len__(self):
"""The number of features (number of rows) in this column."""
......@@ -243,6 +244,7 @@ class Column(TensorStorage):
"""Update the column data."""
self.index = None
self.storage = val
self._data_nd = None # should unpin data if it was pinned.
self.pinned_by_dgl = False
def to(self, device, **kwargs): # pylint: disable=invalid-name
......@@ -444,7 +446,7 @@ class Column(TensorStorage):
Does nothing if the storage is already pinned.
"""
if not self.pinned_by_dgl and not F.is_pinned(self.data):
pin_memory_inplace(self.data)
self._data_nd = pin_memory_inplace(self.data)
self.pinned_by_dgl = True
def unpin_memory_(self):
......@@ -454,7 +456,8 @@ class Column(TensorStorage):
it is actually in page-locked memory.
"""
if self.pinned_by_dgl:
unpin_memory_inplace(self.data)
self._data_nd.unpin_memory_()
self._data_nd = None
self.pinned_by_dgl = False
class Frame(MutableMapping):
......
......@@ -5545,11 +5545,10 @@ class DGLHeteroGraph(object):
>>> g.in_degrees()
tensor([0, 1, 1])
"""
if self._graph.is_pinned():
return self
if F.device_type(self.device) != 'cpu':
raise DGLError("The graph structure must be on CPU to be pinned.")
self._graph.pin_memory_()
if not self._graph.is_pinned():
if F.device_type(self.device) != 'cpu':
raise DGLError("The graph structure must be on CPU to be pinned.")
self._graph.pin_memory_()
for frame in itertools.chain(self._node_frames, self._edge_frames):
for col in frame._columns.values():
col.pin_memory_()
......@@ -5567,9 +5566,8 @@ class DGLHeteroGraph(object):
DGLGraph
The unpinned graph.
"""
if not self._graph.is_pinned():
return self
self._graph.unpin_memory_()
if self._graph.is_pinned():
self._graph.unpin_memory_()
for frame in itertools.chain(self._node_frames, self._edge_frames):
for col in frame._columns.values():
col.unpin_memory_()
......
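
A user-level sketch of what this hunk, together with the Column changes above, gives (CUDA build required): pinning a graph pins its feature columns, replacing a pinned column releases the old data's pin, and pin/unpin can be called safely regardless of the current pinning state.

    import torch
    import dgl

    g = dgl.graph(([1, 2, 3, 4], [0, 0, 0, 0]))
    g.ndata['x'] = torch.randn(g.num_nodes())
    g.pin_memory_()                             # pins the structure and every feature column
    assert g.is_pinned() and g.ndata['x'].is_pinned()
    g.pin_memory_()                             # already pinned: effectively a no-op
    g.ndata['x'] = torch.randn(g.num_nodes())   # old column data is unpinned automatically
    assert not g.ndata['x'].is_pinned()
    g.unpin_memory_()                           # unpins the structure and remaining columns
    assert not g.is_pinned()
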
......@@ -5,25 +5,32 @@ from .. import backend as F
from .._ffi.function import _init_api
def pin_memory_inplace(tensor):
"""Register the tensor into pinned memory in-place (i.e. without copying)."""
"""Register the tensor into pinned memory in-place (i.e. without copying).
Users are required to save the returned dgl.ndarray object to prevent the tensor from being unpinned.
Parameters
----------
tensor : Tensor
The tensor to be pinned.
Returns
-------
dgl.ndarray
The dgl.ndarray object that holds the pinning status and shares the same
underlying data with the tensor.
"""
if F.backend_name in ['mxnet', 'tensorflow']:
raise DGLError("The {} backend does not support pinning " \
"tensors in-place.".format(F.backend_name))
# needs to be writable to allow in-place modification
try:
F.zerocopy_to_dgl_ndarray_for_write(tensor).pin_memory_()
nd_array = F.zerocopy_to_dgl_ndarray_for_write(tensor)
nd_array.pin_memory_()
return nd_array
except Exception as e:
raise DGLError("Failed to pin memory in-place due to: {}".format(e))
def unpin_memory_inplace(tensor):
"""Unregister the tensor from pinned memory in-place (i.e. without copying)."""
# needs to be writable to allow in-place modification
try:
F.zerocopy_to_dgl_ndarray_for_write(tensor).unpin_memory_()
except Exception as e:
raise DGLError("Failed to unpin memory in-place due to: {}".format(e))
def gather_pinned_tensor_rows(tensor, rows):
"""Directly gather rows from a CPU tensor given an indices array on CUDA devices,
and returns the result on the same CUDA device without copying.
......
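
A sketch of the UVA gather these helpers are meant for (CUDA device required; PyTorch backend assumed): the feature table stays in pinned host memory while the indices and the gathered result live on the GPU.

    import torch
    from dgl.utils import pin_memory_inplace, gather_pinned_tensor_rows

    feat = torch.randn(100000, 128)                 # large feature table in host memory
    nd = pin_memory_inplace(feat)                   # keep 'nd' alive while UVA access is needed
    rows = torch.randint(0, 100000, (1024,), device='cuda')
    out = gather_pinned_tensor_rows(feat, rows)     # gathers over UVA, result stays on the GPU
    assert out.device.type == 'cuda'
    del nd                                          # done: the feature table is unpinned again
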
......@@ -64,9 +64,7 @@ struct NDArray::Internal {
ptr->mem = nullptr;
} else if (ptr->dl_tensor.data != nullptr) {
// if the array is still pinned before freeing, unpin it.
if (IsDataPinned(&(ptr->dl_tensor))) {
UnpinData(&(ptr->dl_tensor));
}
UnpinContainer(ptr);
dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace(
ptr->dl_tensor.ctx, ptr->dl_tensor.data);
}
......@@ -78,6 +76,9 @@ struct NDArray::Internal {
// This enables us to create NDArray from memory allocated by other
// frameworks that are DLPack compatible
static void DLPackDeleter(NDArray::Container* ptr) {
// if the array is pinned by dgl, unpin it before freeing
if (ptr->pinned_by_dgl_)
UnpinContainer(ptr);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor);
......@@ -115,8 +116,8 @@ struct NDArray::Internal {
}
// Implementation of API function
static DLTensor* MoveAsDLTensor(NDArray arr) {
DLTensor* tensor = const_cast<DLTensor*>(arr.operator->());
CHECK(reinterpret_cast<DLTensor*>(arr.data_) == tensor);
DLTensor* tensor = reinterpret_cast<DLTensor*>(arr.data_);
CHECK(tensor == const_cast<DLTensor*>(arr.operator->()));
arr.data_ = nullptr;
return tensor;
}
......@@ -256,16 +257,26 @@ void NDArray::CopyFromTo(DLTensor* from,
from_size, from->ctx, to->ctx, from->dtype, stream);
}
void NDArray::PinData(DLTensor* tensor) {
if (IsDataPinned(tensor)) return;
void NDArray::PinContainer(NDArray::Container* ptr) {
if (IsContainerPinned(ptr)) return;
auto* tensor = &(ptr->dl_tensor);
CHECK_EQ(tensor->ctx.device_type, kDLCPU)
<< "Only NDArray on CPU can be pinned";
DeviceAPI::Get(kDLGPU)->PinData(tensor->data, GetDataSize(*tensor));
ptr->pinned_by_dgl_ = true;
}
void NDArray::UnpinData(DLTensor* tensor) {
if (!IsDataPinned(tensor)) return;
DeviceAPI::Get(kDLGPU)->UnpinData(tensor->data);
void NDArray::UnpinContainer(NDArray::Container* ptr) {
auto container_is_pinned = IsContainerPinned(ptr);
// The tensor may be pinned outside of DGL via a different CUDA API,
// so we cannot unpin it with cudaHostUnregister.
CHECK(ptr->pinned_by_dgl_ || !container_is_pinned)
<< "Cannot unpin a tensor that is pinned outside of DGL.";
// 1. not pinned, do nothing
if (!container_is_pinned) return;
// 2. pinned by DGL, unpin it
DeviceAPI::Get(kDLGPU)->UnpinData(ptr->dl_tensor.data);
ptr->pinned_by_dgl_ = false;
}
template<typename T>
......@@ -327,7 +338,10 @@ std::shared_ptr<SharedMemory> NDArray::GetSharedMem() const {
return this->data_->mem;
}
bool NDArray::IsDataPinned(DLTensor* tensor) {
bool NDArray::IsContainerPinned(NDArray::Container* ptr) {
if (ptr->pinned_by_dgl_)
return true;
auto* tensor = &(ptr->dl_tensor);
// Can only be pinned if on CPU...
if (tensor->ctx.device_type != kDLCPU)
return false;
......@@ -533,13 +547,15 @@ int DGLArrayCopyToBytes(DGLArrayHandle handle,
int DGLArrayPinData(DGLArrayHandle handle,
DLContext ctx) {
API_BEGIN();
NDArray::PinData(handle);
auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);
NDArray::PinContainer(nd_container);
API_END();
}
int DGLArrayUnpinData(DGLArrayHandle handle,
DLContext ctx) {
API_BEGIN();
NDArray::UnpinData(handle);
auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);
NDArray::UnpinContainer(nd_container);
API_END();
}
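
The guard added to UnpinContainer, seen from Python (CUDA build required; PyTorch backend assumed): memory pinned by the framework rather than by DGL is reported as pinned but is refused by unpin_memory_().

    import torch
    import dgl
    import dgl.backend as F

    t = torch.ones(100).pin_memory()    # pinned by PyTorch (not by DGL)
    nd = F.zerocopy_to_dgl_ndarray(t)   # wraps the data; pinned_by_dgl_ stays false
    try:
        nd.unpin_memory_()              # UnpinContainer's CHECK fails
    except dgl.DGLError as err:
        print('refused to unpin:', err)
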
......@@ -9,10 +9,20 @@ def test_pin_unpin():
assert not F.is_pinned(t)
if F.backend_name == 'pytorch':
dgl.utils.pin_memory_inplace(t)
nd = dgl.utils.pin_memory_inplace(t)
assert F.is_pinned(t)
dgl.utils.unpin_memory_inplace(t)
nd.unpin_memory_()
assert not F.is_pinned(t)
del nd
# tensor will be unpinned immediately if the returned ndarray is not saved
dgl.utils.pin_memory_inplace(t)
assert not F.is_pinned(t)
t_pin = t.pin_memory()
# cannot unpin a tensor that is pinned outside of DGL
with pytest.raises(dgl.DGLError):
F.to_dgl_nd(t_pin).unpin_memory_()
else:
with pytest.raises(dgl.DGLError):
# tensorflow and mxnet should throw an error
......
......@@ -3,7 +3,7 @@ import dgl
import pytest
import torch
@pytest.mark.skipif(F._default_context_str == 'cpu', reason="Need gpu for this test")
@pytest.mark.skipif(F._default_context_str == 'cpu', reason="Need gpu for this test.")
def test_pin_noncontiguous():
t = torch.empty([10, 100]).transpose(0, 1)
......@@ -13,7 +13,7 @@ def test_pin_noncontiguous():
with pytest.raises(dgl.DGLError):
dgl.utils.pin_memory_inplace(t)
@pytest.mark.skipif(F._default_context_str == 'cpu', reason="Need gpu for this test")
@pytest.mark.skipif(F._default_context_str == 'cpu', reason="Need gpu for this test.")
def test_pin_view():
t = torch.empty([100, 10])
v = t[10:20]
......@@ -24,7 +24,41 @@ def test_pin_view():
with pytest.raises(dgl.DGLError):
dgl.utils.pin_memory_inplace(v)
@pytest.mark.skipif(F._default_context_str == 'cpu', reason='Need gpu for this test.')
def test_unpin_automatically():
# run enough iterations for the memory pool to be re-used
for j in range(10):
t = torch.ones(10000, 10)
assert not F.is_pinned(t)
nd = dgl.utils.pin_memory_inplace(t)
assert F.is_pinned(t)
del nd
# dgl.ndarray will unpin its data upon destruction
assert not F.is_pinned(t)
del t
@pytest.mark.skipif(F._default_context_str == 'cpu', reason='Need gpu for this test.')
def test_pin_unpin_column():
g = dgl.graph(([1, 2, 3, 4], [0, 0, 0, 0]))
g.ndata['x'] = torch.randn(g.num_nodes())
g.pin_memory_()
assert g.is_pinned()
assert g.ndata['x'].is_pinned()
for col in g._node_frames[0].values():
assert col.pinned_by_dgl
assert col._data_nd is not None
g.ndata['x'] = torch.randn(g.num_nodes()) # unpin the old ndata['x']
assert g.is_pinned()
for col in g._node_frames[0].values():
assert not col.pinned_by_dgl
assert col._data_nd is None
assert not g.ndata['x'].is_pinned()
if __name__ == "__main__":
test_pin_noncontiguous()
test_pin_view()
test_unpin_automatically()
test_pin_unpin_column()