Unverified Commit 9ee7ced5 authored by Xin Yao, committed by GitHub

[Performance] Redirect `AllocWorkspace` to PyTorch's allocator if available (#4199)

parent 9ae117d3
......@@ -304,6 +304,7 @@ if(BUILD_TORCH)
${CMAKE_COMMAND} -E env
CMAKE_COMMAND=${CMAKE_CMD}
CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
USE_CUDA=${USE_CUDA}
BINDIR=${BINDIR}
cmd /e:on /c ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}
DEPENDS ${BUILD_SCRIPT}
......@@ -315,6 +316,7 @@ if(BUILD_TORCH)
${CMAKE_COMMAND} -E env
CMAKE_COMMAND=${CMAKE_CMD}
CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
USE_CUDA=${USE_CUDA}
BINDIR=${CMAKE_CURRENT_BINARY_DIR}
bash ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}
DEPENDS ${BUILD_SCRIPT}
......
......@@ -69,8 +69,12 @@ class TensorDispatcher {
/*!
* \brief Allocate an empty tensor.
*
* Used in NDArray::Empty().
* \param shape The shape
* \param dtype The data type
* \param ctx The device
* \return An empty NDArray.
*/
inline NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) const {
auto entry = entrypoints_[Op::kEmpty];
......@@ -78,6 +82,36 @@ class TensorDispatcher {
return NDArray::FromDLPack(result);
}
#ifdef DGL_USE_CUDA
/*!
* \brief Allocate a piece of GPU memory via
* PyTorch's THCCachingAllocator.
* Used in CUDADeviceAPI::AllocWorkspace().
*
* \note THCCachingAllocator selects the device to allocate on
* via cudaGetDevice(). Make sure to call cudaSetDevice()
* before invoking this function.
*
* \param nbytes The size to be allocated.
* \return Pointer to the allocated memory.
*/
inline void* AllocWorkspace(size_t nbytes) {
auto entry = entrypoints_[Op::kRawAlloc];
return FUNCCAST(tensoradapter::RawAlloc, entry)(nbytes);
}
/*!
* \brief Free GPU memory previously allocated by AllocWorkspace().
* Used in CUDADeviceAPI::FreeWorkspace().
*
* \param ptr Pointer to the memory to be freed.
*/
inline void FreeWorkspace(void* ptr) {
auto entry = entrypoints_[Op::kRawDelete];
FUNCCAST(tensoradapter::RawDelete, entry)(ptr);
}
#endif // DGL_USE_CUDA
private:
/*! \brief ctor */
TensorDispatcher() = default;
......@@ -91,19 +125,33 @@ class TensorDispatcher {
*/
static constexpr const char *names_[] = {
"TAempty",
#ifdef DGL_USE_CUDA
"RawAlloc",
"RawDelete",
#endif // DGL_USE_CUDA
};
/*! \brief Index of each function to the symbol list */
class Op {
public:
static constexpr int kEmpty = 0;
#ifdef DGL_USE_CUDA
static constexpr int kRawAlloc = 1;
static constexpr int kRawDelete = 2;
#endif // DGL_USE_CUDA
};
/*! \brief Number of functions */
static constexpr int num_entries_ = sizeof(names_) / sizeof(names_[0]);
/*! \brief Entrypoints of each function */
void* entrypoints_[num_entries_] = {nullptr};
void* entrypoints_[num_entries_] = {
nullptr,
#ifdef DGL_USE_CUDA
nullptr,
nullptr,
#endif // DGL_USE_CUDA
};
bool available_ = false;
#if defined(WIN32) || defined(_WIN32)
......
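Aside, for illustration (not part of the change): a caller inside DGL would use the new dispatcher entry points roughly as below. This is a minimal sketch assuming the dgl::runtime namespace and a plain CUDA device id; per the \note above, the device must be selected with cudaSetDevice() before allocating.

#include <cuda_runtime.h>
#include <dgl/runtime/tensordispatch.h>

// Allocate a temporary GPU buffer through PyTorch's caching allocator when the
// tensoradapter library is loaded.
void* AllocScratch(int device_id, size_t nbytes) {
  cudaSetDevice(device_id);  // the caching allocator allocates on the current device
  auto* td = dgl::runtime::TensorDispatcher::Global();
  return td->IsAvailable() ? td->AllocWorkspace(nbytes)
                           : nullptr;  // real code falls back to DGL's own workspace pool here
}

void FreeScratch(void* ptr) {
  auto* td = dgl::runtime::TensorDispatcher::Global();
  if (td->IsAvailable())
    td->FreeWorkspace(ptr);  // otherwise return it to DGL's pool
}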
......@@ -4,7 +4,7 @@
* \brief GPU specific API
*/
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <cuda_runtime.h>
......@@ -224,10 +224,20 @@ class CUDADeviceAPI final : public DeviceAPI {
}
void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final {
// Redirect to PyTorch's allocator when available.
SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
return td->AllocWorkspace(size);
else
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
void FreeWorkspace(DGLContext ctx, void* data) final {
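// Mirror AllocWorkspace: release through PyTorch's allocator when it is available.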
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
td->FreeWorkspace(data);
else
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
......
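For illustration (not part of the diff): the practical effect is that any CUDA op in DGL that grabs a temporary buffer through the device API now gets it from PyTorch's caching pool when DGL runs inside a PyTorch process. A rough sketch of such a caller, assuming the TVM-style DeviceAPI::Get(ctx) lookup and the base-class default for the type_hint argument:

#include <dgl/runtime/device_api.h>

// Hypothetical op-side helper: allocate scratch space, run kernels, release it.
void RunWithScratch(DGLContext ctx, size_t temp_bytes) {
  auto* device = dgl::runtime::DeviceAPI::Get(ctx);
  void* temp = device->AllocWorkspace(ctx, temp_bytes);  // served by PyTorch's allocator when available
  // ... launch CUDA kernels that use `temp` ...
  device->FreeWorkspace(ctx, temp);
}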
......@@ -18,7 +18,7 @@ namespace tensoradapter {
extern "C" {
/*!
* \brief Allocate an empty tensor
* \brief Allocate an empty tensor.
*
* \param shape The shape
* \param dtype The data type
......@@ -28,6 +28,24 @@ extern "C" {
DLManagedTensor* TAempty(
std::vector<int64_t> shape, DLDataType dtype, DLContext ctx);
#ifdef DGL_USE_CUDA
/*!
* \brief Allocate a piece of GPU memory via
* PyTorch's THCCachingAllocator.
*
* \param nbytes The size to be allocated.
* \return Pointer to the allocated memory.
*/
void* RawAlloc(size_t nbytes);
/*!
* \brief Free GPU memory previously allocated by RawAlloc().
*
* \param ptr Pointer to the memory to be freed.
*/
void RawDelete(void* ptr);
#endif // DGL_USE_CUDA
}
}; // namespace tensoradapter
......
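Aside: these extern "C" declarations are what TensorDispatcher resolves at runtime via its names_ table ("TAempty", "RawAlloc", "RawDelete"). A standalone POSIX sketch of that symbol lookup follows; the library file name is a placeholder, since the real name depends on platform and PyTorch version, and DGL locates the library itself.

#include <dlfcn.h>
#include <cstdio>

int main() {
  // Placeholder path; adjust to wherever the tensoradapter library was built.
  void* lib = dlopen("libtensoradapter_pytorch.so", RTLD_NOW | RTLD_LOCAL);
  if (!lib) { std::printf("tensoradapter not found: %s\n", dlerror()); return 1; }
  auto raw_alloc = reinterpret_cast<void* (*)(size_t)>(dlsym(lib, "RawAlloc"));
  auto raw_delete = reinterpret_cast<void (*)(void*)>(dlsym(lib, "RawDelete"));
  if (raw_alloc && raw_delete) {
    void* p = raw_alloc(1 << 20);  // 1 MiB from PyTorch's caching allocator
    raw_delete(p);
    std::printf("RawAlloc/RawDelete resolved and round-tripped\n");
  }
  dlclose(lib);
  return 0;
}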
......@@ -17,6 +17,10 @@ list(GET TORCH_PREFIX_VER 0 TORCH_PREFIX)
list(GET TORCH_PREFIX_VER 1 TORCH_VER)
message(STATUS "Configuring for PyTorch ${TORCH_VER}")
if(USE_CUDA)
add_definitions(-DDGL_USE_CUDA)
endif()
set(Torch_DIR "${TORCH_PREFIX}/Torch")
message(STATUS "Setting directory to ${Torch_DIR}")
find_package(Torch REQUIRED)
......
......@@ -11,7 +11,7 @@ IF x%1x == xx GOTO single
FOR %%X IN (%*) DO (
DEL /S /Q *
"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release -DCUDA_TOOLKIT_ROOT_DIR="%CUDA_TOOLKIT_ROOT_DIR%" -DTORCH_CUDA_ARCH_LIST=%TORCH_CUDA_ARCH_LIST% -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release -DCUDA_TOOLKIT_ROOT_DIR="%CUDA_TOOLKIT_ROOT_DIR%" -DTORCH_CUDA_ARCH_LIST=%TORCH_CUDA_ARCH_LIST% -DUSE_CUDA=%USE_CUDA% -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
msbuild tensoradapter_pytorch.sln /m /nr:false || EXIT /B 1
COPY /Y Release\*.dll "%BINDIR%\tensoradapter\pytorch" || EXIT /B 1
)
......@@ -21,7 +21,7 @@ GOTO end
:single
DEL /S /Q *
"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release -DCUDA_TOOLKIT_ROOT_DIR="%CUDA_TOOLKIT_ROOT_DIR%" -DTORCH_CUDA_ARCH_LIST=%TORCH_CUDA_ARCH_LIST% .. -G "Visual Studio 16 2019" || EXIT /B 1
"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release -DCUDA_TOOLKIT_ROOT_DIR="%CUDA_TOOLKIT_ROOT_DIR%" -DTORCH_CUDA_ARCH_LIST=%TORCH_CUDA_ARCH_LIST% -DUSE_CUDA=%USE_CUDA% .. -G "Visual Studio 16 2019" || EXIT /B 1
msbuild tensoradapter_pytorch.sln /m /nr:false || EXIT /B 1
COPY /Y Release\*.dll "%BINDIR%\tensoradapter\pytorch" || EXIT /B 1
......
......@@ -13,7 +13,7 @@ else
CPSOURCE=*.so
fi
CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST"
CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST -DUSE_CUDA=$USE_CUDA"
if [ $# -eq 0 ]; then
$CMAKE_COMMAND $CMAKE_FLAGS ..
......
......@@ -7,6 +7,9 @@
#include <tensoradapter_exports.h>
#include <torch/torch.h>
#include <ATen/DLConvertor.h>
#ifdef DGL_USE_CUDA
#include <c10/cuda/CUDACachingAllocator.h>
#endif // DGL_USE_CUDA
#include <vector>
#include <iostream>
......@@ -47,6 +50,16 @@ TA_EXPORTS DLManagedTensor* TAempty(
return at::toDLPack(tensor);
}
#ifdef DGL_USE_CUDA
TA_EXPORTS void* RawAlloc(size_t nbytes) {
return c10::cuda::CUDACachingAllocator::raw_alloc(nbytes);
}
TA_EXPORTS void RawDelete(void* ptr) {
c10::cuda::CUDACachingAllocator::raw_delete(ptr);
}
#endif // DGL_USE_CUDA
};
}; // namespace tensoradapter
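Design note (illustrative, not part of the diff): routing workspace allocations through c10::cuda::CUDACachingAllocator means DGL and PyTorch draw from a single caching pool, avoiding repeated cudaMalloc/cudaFree calls (and the synchronization cudaFree implies) and reducing fragmentation between two separate allocators. A small check of that sharing, assuming a CUDA-enabled libtorch and at least one visible GPU; bytes obtained via raw_alloc show up in the allocator's device statistics because both sides use the same pool:

#include <c10/cuda/CUDACachingAllocator.h>
#include <iostream>

int main() {
  // Allocate 1 MiB through the same entry point the tensoradapter exports use.
  void* p = c10::cuda::CUDACachingAllocator::raw_alloc(1 << 20);
  auto stats = c10::cuda::CUDACachingAllocator::getDeviceStats(0);
  std::cout << "device 0 bytes currently allocated: "
            << stats.allocated_bytes[0].current << std::endl;
  c10::cuda::CUDACachingAllocator::raw_delete(p);
  return 0;
}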