Unverified commit 2b766740, authored by Xin Yao and committed by GitHub

[Feature] Make TensorAdapter Stream Aware (#4472)

* Allocate tensors in DGL's current stream

* make tensoradaptor stream-aware

* replace TAempty with cpu allocator

* fix typo

* try fix cpu allocation

* clean header

* redirect AllocDataSpace as well

* resolve comments
parent 468c0ca4
 /*!
- * Copyright (c) 2020 by Contributors
+ * Copyright (c) 2020-2022 by Contributors
  * \file array/tensordispatch.h
  * \brief This file defines the dispatcher of tensor operators to framework-specific
  *  implementations.
@@ -8,8 +8,6 @@
  * one separately-built shared library per supported backend.
  *
  * Those shared libraries contain wrappers of the framework-specific operators.
- * The wrappers have almost the same signatures as functions in aten namespace,
- * except that they accept and return DLManagedTensors instead of NDArrays.
  * The wrappers are defined with extern "C", meaning that the C++ compiler will
  * not do name mangling for those functions so that DGL can conveniently locate
  * them using dlsym(3) (or GetProcAddress in Windows).
@@ -23,21 +21,21 @@
  *
  * A tensor operator in TensorDispatcher first checks whether the corresponding symbol
  * address is found in the mapping. If so, it calls the function located at the
- * symbol address instead, translating NDArrays to DLManagedTensors using
- * NDArray::ToDLPack(), and translates the DLManagedTensors in the return values
- * back to NDArrays using NDArray::FromDLPack(). If not, it falls back to the
- * implementation in dgl::aten namespace.
+ * symbol address instead to allocate/free pieces of memory on CPU/GPU.
+ * If not, it falls back to DeviceAPI::AllocWorkspace/FreeWorkspace.
  */
 #ifndef DGL_RUNTIME_TENSORDISPATCH_H_
 #define DGL_RUNTIME_TENSORDISPATCH_H_
-#include <dlpack/dlpack.h>
+#include <stddef.h>
 #include <tensoradapter.h>
 #if defined(WIN32) || defined(_WIN32)
 #include <windows.h>
 #endif  // WIN32
-#include <vector>
+#ifdef DGL_USE_CUDA
+#include <cuda_runtime.h>
+#endif  // DGL_USE_CUDA
 #include "ndarray.h"

 /*! \brief Casts a pointer \c entry to a function pointer with signature of \c func */
@@ -68,47 +66,57 @@ class TensorDispatcher {
   bool Load(const char *path_cstr);

   /*!
-   * \brief Allocate an empty tensor.
-   *        Used in NDArray::Empty().
+   * \brief Allocate a piece of CPU memory via
+   *        PyTorch's CPUAllocator.
+   *        Used in CPUDeviceAPI::AllocWorkspace().
+   *
+   * \param nbytes The size to be allocated.
+   * \return Pointer to the allocated memory.
+   */
+  inline void* CPUAllocWorkspace(size_t nbytes) {
+    auto entry = entrypoints_[Op::kCPURawAlloc];
+    return FUNCCAST(tensoradapter::CPURawAlloc, entry)(nbytes);
+  }
+
+  /*!
+   * \brief Free the CPU memory.
+   *        Used in CPUDeviceAPI::FreeWorkspace().
    *
-   * \param shape The shape
-   * \param dtype The data type
-   * \param ctx The device
-   * \return An empty NDArray.
+   * \param ptr Pointer to the memory to be freed.
    */
-  inline NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) const {
-    auto entry = entrypoints_[Op::kEmpty];
-    auto result = FUNCCAST(tensoradapter::TAempty, entry)(shape, dtype, ctx);
-    return NDArray::FromDLPack(result);
+  inline void CPUFreeWorkspace(void* ptr) {
+    auto entry = entrypoints_[Op::kCPURawDelete];
+    FUNCCAST(tensoradapter::CPURawDelete, entry)(ptr);
   }

 #ifdef DGL_USE_CUDA
   /*!
    * \brief Allocate a piece of GPU memory via
    *        PyTorch's THCCachingAllocator.
    *        Used in CUDADeviceAPI::AllocWorkspace().
    *
    * \note THCCachingAllocator specify the device to allocate on
    *       via cudaGetDevice(). Make sure to call cudaSetDevice()
    *       before invoking this function.
    *
    * \param nbytes The size to be allocated.
+   * \param stream The stream to be allocated on.
    * \return Pointer to the allocated memory.
    */
-  inline void* AllocWorkspace(size_t nbytes) {
-    auto entry = entrypoints_[Op::kRawAlloc];
-    return FUNCCAST(tensoradapter::RawAlloc, entry)(nbytes);
+  inline void* CUDAAllocWorkspace(size_t nbytes, cudaStream_t stream) {
+    auto entry = entrypoints_[Op::kCUDARawAlloc];
+    return FUNCCAST(tensoradapter::CUDARawAlloc, entry)(nbytes, stream);
   }

   /*!
    * \brief Free the GPU memory.
    *        Used in CUDADeviceAPI::FreeWorkspace().
    *
    * \param ptr Pointer to the memory to be freed.
    */
-  inline void FreeWorkspace(void* ptr) {
-    auto entry = entrypoints_[Op::kRawDelete];
-    FUNCCAST(tensoradapter::RawDelete, entry)(ptr);
+  inline void CUDAFreeWorkspace(void* ptr) {
+    auto entry = entrypoints_[Op::kCUDARawDelete];
+    FUNCCAST(tensoradapter::CUDARawDelete, entry)(ptr);
   }
 #endif  // DGL_USE_CUDA
@@ -124,20 +132,22 @@ class TensorDispatcher {
    * Must match the functions in tensoradapter/include/tensoradapter.h.
    */
   static constexpr const char *names_[] = {
-    "TAempty",
+    "CPURawAlloc",
+    "CPURawDelete",
 #ifdef DGL_USE_CUDA
-    "RawAlloc",
-    "RawDelete",
+    "CUDARawAlloc",
+    "CUDARawDelete",
 #endif  // DGL_USE_CUDA
   };

   /*! \brief Index of each function to the symbol list */
   class Op {
    public:
-    static constexpr int kEmpty = 0;
+    static constexpr int kCPURawAlloc = 0;
+    static constexpr int kCPURawDelete = 1;
 #ifdef DGL_USE_CUDA
-    static constexpr int kRawAlloc = 1;
-    static constexpr int kRawDelete = 2;
+    static constexpr int kCUDARawAlloc = 2;
+    static constexpr int kCUDARawDelete = 3;
 #endif  // DGL_USE_CUDA
   };

@@ -147,6 +157,7 @@ class TensorDispatcher {
   /*! \brief Entrypoints of each function */
   void* entrypoints_[num_entries_] = {
     nullptr,
+    nullptr,
 #ifdef DGL_USE_CUDA
     nullptr,
     nullptr,
......
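For readers unfamiliar with the mechanism the header comment describes, the sketch below (not part of this commit) shows how an extern "C" entrypoint exported by the adapter library can be resolved with dlsym(3) and invoked through a casted function pointer, which is essentially what Load() and FUNCCAST do together. The library name is hypothetical and error handling is minimal.

// Illustrative only: resolve the adapter's CPU allocation entrypoints at
// runtime and call them through casted function pointers.
#include <dlfcn.h>   // dlopen/dlsym on POSIX; Windows uses LoadLibrary/GetProcAddress
#include <cstddef>
#include <cstdio>

using cpu_raw_alloc_t = void* (*)(size_t);
using cpu_raw_delete_t = void (*)(void*);

int main() {
  // Hypothetical library name; DGL discovers the real adapter path at runtime.
  void* handle = dlopen("libtensoradapter_pytorch.so", RTLD_NOW | RTLD_LOCAL);
  if (!handle) {
    std::puts("adapter not available; a caller would fall back to plain malloc");
    return 0;
  }
  auto alloc_fn = reinterpret_cast<cpu_raw_alloc_t>(dlsym(handle, "CPURawAlloc"));
  auto free_fn = reinterpret_cast<cpu_raw_delete_t>(dlsym(handle, "CPURawDelete"));
  if (alloc_fn && free_fn) {
    void* buf = alloc_fn(1024);   // served by PyTorch's CPUAllocator
    std::printf("allocated 1 KiB at %p via the adapter\n", buf);
    free_fn(buf);
  }
  dlclose(handle);
  return 0;
}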
 /*!
- * Copyright (c) 2016 by Contributors
+ * Copyright (c) 2016-2022 by Contributors
  * \file cpu_device_api.cc
  */
 #include <dmlc/logging.h>
 #include <dmlc/thread_local.h>
 #include <dgl/runtime/registry.h>
 #include <dgl/runtime/device_api.h>
+#include <dgl/runtime/tensordispatch.h>
 #include <cstdlib>
 #include <cstring>
 #include "workspace_pool.h"
@@ -24,6 +25,10 @@ class CPUDeviceAPI final : public DeviceAPI {
                        size_t nbytes,
                        size_t alignment,
                        DGLType type_hint) final {
+    TensorDispatcher* td = TensorDispatcher::Global();
+    if (td->IsAvailable())
+      return td->CPUAllocWorkspace(nbytes);
+
     void* ptr;
 #if _MSC_VER || defined(__MINGW32__)
     ptr = _aligned_malloc(nbytes, alignment);
@@ -39,6 +44,10 @@ class CPUDeviceAPI final : public DeviceAPI {
   }

   void FreeDataSpace(DGLContext ctx, void* ptr) final {
+    TensorDispatcher* td = TensorDispatcher::Global();
+    if (td->IsAvailable())
+      return td->CPUFreeWorkspace(ptr);
+
 #if _MSC_VER || defined(__MINGW32__)
     _aligned_free(ptr);
 #else
@@ -83,11 +92,18 @@ struct CPUWorkspacePool : public WorkspacePool {
 void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx,
                                    size_t size,
                                    DGLType type_hint) {
-  return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
-      ->AllocWorkspace(ctx, size);
+  TensorDispatcher* td = TensorDispatcher::Global();
+  if (td->IsAvailable())
+    return td->CPUAllocWorkspace(size);
+
+  return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(ctx, size);
 }

 void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
+  TensorDispatcher* td = TensorDispatcher::Global();
+  if (td->IsAvailable())
+    return td->CPUFreeWorkspace(data);
+
   dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
 }
......
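Taken together, the CPU-side changes make every allocation entry point try the dispatcher first and keep the original path only as a fallback. Below is a condensed sketch (not the verbatim DGL source) of how AllocDataSpace() reads after this change; the non-Windows fallback branch is abbreviated and assumes posix_memalign, as in the pre-existing code.

// Condensed sketch of the redirect-or-fallback pattern; illustrative only.
void* AllocDataSpace(DGLContext ctx, size_t nbytes, size_t alignment,
                     DGLType type_hint) {
  TensorDispatcher* td = TensorDispatcher::Global();
  if (td->IsAvailable())                  // adapter loaded: PyTorch's CPUAllocator
    return td->CPUAllocWorkspace(nbytes);

  void* ptr = nullptr;                    // original aligned-allocation fallback
#if _MSC_VER || defined(__MINGW32__)
  ptr = _aligned_malloc(nbytes, alignment);
#else
  if (posix_memalign(&ptr, alignment, nbytes) != 0)  // assumed fallback branch
    ptr = nullptr;
#endif
  return ptr;
}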
 /*!
- * Copyright (c) 2017 by Contributors
+ * Copyright (c) 2017-2022 by Contributors
  * \file cuda_device_api.cc
  * \brief GPU specific API
  */
@@ -107,7 +107,12 @@ class CUDADeviceAPI final : public DeviceAPI {
                        size_t nbytes,
                        size_t alignment,
                        DGLType type_hint) final {
-    CUDA_CALL(cudaSetDevice(ctx.device_id));
+    SetDevice(ctx);
+    // Redirect to PyTorch's allocator when available.
+    TensorDispatcher* td = TensorDispatcher::Global();
+    if (td->IsAvailable())
+      return td->CUDAAllocWorkspace(nbytes, CUDAThreadEntry::ThreadLocal()->stream);
+
     CHECK_EQ(256 % alignment, 0U)
         << "CUDA space is aligned at 256 bytes";
     void *ret;
@@ -116,7 +121,11 @@ class CUDADeviceAPI final : public DeviceAPI {
   }

   void FreeDataSpace(DGLContext ctx, void* ptr) final {
-    CUDA_CALL(cudaSetDevice(ctx.device_id));
+    SetDevice(ctx);
+    TensorDispatcher* td = TensorDispatcher::Global();
+    if (td->IsAvailable())
+      return td->CUDAFreeWorkspace(ptr);
+
     CUDA_CALL(cudaFree(ptr));
   }
@@ -246,21 +255,22 @@ class CUDADeviceAPI final : public DeviceAPI {
   }

   void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final {
-    // Redirect to PyTorch's allocator when available.
     SetDevice(ctx);
+    // Redirect to PyTorch's allocator when available.
     TensorDispatcher* td = TensorDispatcher::Global();
     if (td->IsAvailable())
-      return td->AllocWorkspace(size);
-    else
-      return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
+      return td->CUDAAllocWorkspace(size, CUDAThreadEntry::ThreadLocal()->stream);
+
+    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
   }

   void FreeWorkspace(DGLContext ctx, void* data) final {
+    SetDevice(ctx);
     TensorDispatcher* td = TensorDispatcher::Global();
     if (td->IsAvailable())
-      td->FreeWorkspace(data);
-    else
-      CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
+      return td->CUDAFreeWorkspace(data);
+
+    CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
   }

   static const std::shared_ptr<CUDADeviceAPI>& Global() {
......
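The key difference from the previous version is that GPU allocations now carry the stream they are requested on (CUDAThreadEntry::ThreadLocal()->stream) into PyTorch's caching allocator, so memory allocated by DGL is tracked against the same stream its kernels run on. The sketch below (assuming a libtorch build with CUDA; not part of this commit) shows the equivalent call pattern against the c10 caching allocator directly.

// Illustrative sketch of stream-aware allocation through PyTorch's caching
// allocator, mirroring what CUDAAllocWorkspace()/CUDARawAlloc() forward to.
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <cstddef>

void* AllocOnCurrentStream(size_t nbytes, int device_id) {
  cudaSetDevice(device_id);  // the caching allocator picks the device via cudaGetDevice()
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device_id).stream();
  // The returned block is associated with `stream`, so the caching allocator
  // can recycle it safely with respect to work queued on that stream.
  return c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(nbytes, stream);
}

void FreeBlock(void* ptr) {
  c10::cuda::CUDACachingAllocator::raw_delete(ptr);
}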
 /*!
- * Copyright (c) 2017 by Contributors
+ * Copyright (c) 2017-2022 by Contributors
  * \file ndarray.cc
  * \brief NDArray container infratructure.
  */
@@ -214,10 +214,6 @@ NDArray NDArray::EmptyShared(const std::string &name,
 NDArray NDArray::Empty(std::vector<int64_t> shape,
                        DLDataType dtype,
                        DLContext ctx) {
-  TensorDispatcher* td = TensorDispatcher::Global();
-  if (td->IsAvailable())
-    return td->Empty(shape, dtype, ctx);
-
   NDArray ret = Internal::Create(shape, dtype, ctx);
   // setup memory content
   size_t size = GetDataSize(ret.data_->dl_tensor);
......
 /*!
- * Copyright (c) 2020 by Contributors
+ * Copyright (c) 2020-2022 by Contributors
  * \file tensoradapter.h
  * \brief Header file for functions exposed by the adapter library.
  *
@@ -10,23 +10,29 @@
 #ifndef TENSORADAPTER_H_
 #define TENSORADAPTER_H_

-#include <dlpack/dlpack.h>
-#include <vector>
+#ifdef DGL_USE_CUDA
+#include <cuda_runtime.h>
+#endif  // DGL_USE_CUDA

 namespace tensoradapter {

 extern "C" {

 /*!
- * \brief Allocate an empty tensor.
+ * \brief Allocate a piece of CPU memory via
+ *        PyTorch's CPUAllocator.
  *
- * \param shape The shape
- * \param dtype The data type
- * \param ctx The device
- * \return The allocated tensor
+ * \param nbytes The size to be allocated.
+ * \return Pointer to the allocated memory.
  */
-DLManagedTensor* TAempty(
-    std::vector<int64_t> shape, DLDataType dtype, DLContext ctx);
+void* CPURawAlloc(size_t nbytes);
+
+/*!
+ * \brief Free the CPU memory.
+ *
+ * \param ptr Pointer to the memory to be freed.
+ */
+void CPURawDelete(void* ptr);

 #ifdef DGL_USE_CUDA
 /*!
@@ -34,16 +40,17 @@ DLManagedTensor* TAempty(
  *        PyTorch's THCCachingAllocator.
  *
  * \param nbytes The size to be allocated.
+ * \param stream The stream to be allocated on.
  * \return Pointer to the allocated memory.
  */
-void* RawAlloc(size_t nbytes);
+void* CUDARawAlloc(size_t nbytes, cudaStream_t stream);

 /*!
  * \brief Free the GPU memory.
  *
  * \param ptr Pointer to the memory to be freed.
  */
-void RawDelete(void* ptr);
+void CUDARawDelete(void* ptr);

 #endif  // DGL_USE_CUDA

 }
......
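Since the header above only declares a plain extern "C" surface (and, after this change, no longer depends on DLPack), any backend can ship its own adapter library. A hypothetical minimal backend for the CPU half of the interface might look like the following; this is illustrative only, and TA_EXPORTS is assumed to come from tensoradapter_exports.h as in the PyTorch adapter below.

// Hypothetical non-PyTorch adapter implementing the CPU entrypoints with
// plain malloc/free. Not part of this commit.
#include <tensoradapter_exports.h>  // provides TA_EXPORTS
#include <cstddef>
#include <cstdlib>

namespace tensoradapter {

extern "C" {

TA_EXPORTS void* CPURawAlloc(size_t nbytes) {
  return std::malloc(nbytes);
}

TA_EXPORTS void CPURawDelete(void* ptr) {
  std::free(ptr);
}

}  // extern "C"

}  // namespace tensoradapter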
 /*!
- * Copyright (c) 2020 by Contributors
+ * Copyright (c) 2020-2022 by Contributors
  * \file torch/torch.cpp
  * \brief Implementation of PyTorch adapter library.
  */
 #include <tensoradapter_exports.h>
-#include <torch/torch.h>
-#include <ATen/DLConvertor.h>
+#include <c10/core/CPUAllocator.h>
 #ifdef DGL_USE_CUDA
 #include <c10/cuda/CUDACachingAllocator.h>
+#include <cuda_runtime.h>
 #endif  // DGL_USE_CUDA
-#include <vector>
-#include <iostream>
-
-#if DLPACK_VERSION > 040
-// Compatibility across DLPack - note that this assumes that the ABI stays the same.
-#define kDLGPU kDLCUDA
-#define DLContext DLDevice
-#endif

 namespace tensoradapter {

-static at::Device get_device(DLContext ctx) {
-  switch (ctx.device_type) {
-   case kDLCPU:
-    return at::Device(torch::kCPU);
-    break;
-   case kDLGPU:
-    return at::Device(torch::kCUDA, ctx.device_id);
-    break;
-   default:
-    // fallback to CPU
-    return at::Device(torch::kCPU);
-    break;
-  }
-}
-
 extern "C" {

-TA_EXPORTS DLManagedTensor* TAempty(
-    std::vector<int64_t> shape,
-    DLDataType dtype,
-    DLContext ctx) {
-  auto options = torch::TensorOptions()
-    .layout(torch::kStrided)
-    .device(get_device(ctx))
-    .dtype(at::toScalarType(dtype));
-  torch::Tensor tensor = torch::empty(shape, options);
-  return at::toDLPack(tensor);
+TA_EXPORTS void* CPURawAlloc(size_t nbytes) {
+  return c10::GetCPUAllocator()->raw_allocate(nbytes);
+}
+
+TA_EXPORTS void CPURawDelete(void* ptr) {
+  c10::GetCPUAllocator()->raw_deallocate(ptr);
 }

 #ifdef DGL_USE_CUDA
-TA_EXPORTS void* RawAlloc(size_t nbytes) {
-  return c10::cuda::CUDACachingAllocator::raw_alloc(nbytes);
+TA_EXPORTS void* CUDARawAlloc(size_t nbytes, cudaStream_t stream) {
+  return c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(
+      nbytes, stream);
 }

-TA_EXPORTS void RawDelete(void* ptr) {
+TA_EXPORTS void CUDARawDelete(void* ptr) {
   c10::cuda::CUDACachingAllocator::raw_delete(ptr);
 }
 #endif  // DGL_USE_CUDA
......