"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8adc6003ba4dbf5b61bb4f1ce571e9e55e145a99"
Unverified Commit 2b766740 authored by Xin Yao, committed by GitHub

[Feature] Make TensorAdapter Stream Aware (#4472)

* Allocate tensors in DGL's current stream

* make TensorAdapter stream-aware

* replace TAempty with CPU allocator

* fix typo

* try to fix CPU allocation

* clean header

* redirect AllocDataSpace as well

* resolve comments
parent 468c0ca4
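At a high level, the change routes DGL's device-memory calls through the PyTorch adapter whenever it is loaded, and tags every GPU allocation with DGL's current CUDA stream. Below is a minimal caller-side sketch of the new path; it assumes a CUDA build, the helper name and the cudaMalloc fallback are hypothetical, and the dgl::runtime namespace is an assumption, while Global(), IsAvailable() and CUDAAllocWorkspace(nbytes, stream) come from the diff itself.

// Illustrative sketch only (not part of the commit): prefer the stream-aware
// PyTorch allocator exposed by TensorDispatcher, otherwise fall back to cudaMalloc.
#include <cuda_runtime.h>
#include <dgl/runtime/tensordispatch.h>

void* AllocOnStream(size_t nbytes, cudaStream_t stream) {
  using dgl::runtime::TensorDispatcher;  // namespace assumed
  TensorDispatcher* td = TensorDispatcher::Global();
  if (td->IsAvailable())
    return td->CUDAAllocWorkspace(nbytes, stream);  // PyTorch caching allocator, on `stream`
  void* ptr = nullptr;
  cudaMalloc(&ptr, nbytes);                         // plain CUDA fallback
  return ptr;
}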
/*!
- * Copyright (c) 2020 by Contributors
+ * Copyright (c) 2020-2022 by Contributors
* \file array/tensordispatch.h
* \brief This file defines the dispatcher of tensor operators to framework-specific
* implementations.
......@@ -8,8 +8,6 @@
* one separately-built shared library per supported backend.
*
* Those shared libraries contain wrappers of the framework-specific operators.
- * The wrappers have almost the same signatures as functions in aten namespace,
- * except that they accept and return DLManagedTensors instead of NDArrays.
* The wrappers are defined with extern "C", meaning that the C++ compiler will
* not do name mangling for those functions so that DGL can conveniently locate
* them using dlsym(3) (or GetProcAddress in Windows).
......@@ -23,21 +21,21 @@
*
* A tensor operator in TensorDispatcher first checks whether the corresponding symbol
* address is found in the mapping. If so, it calls the function located at the
- * symbol address instead, translating NDArrays to DLManagedTensors using
- * NDArray::ToDLPack(), and translates the DLManagedTensors in the return values
- * back to NDArrays using NDArray::FromDLPack(). If not, it falls back to the
- * implementation in dgl::aten namespace.
+ * symbol address instead, allocating/freeing pieces of memory on CPU/GPU.
+ * If not, it falls back to DeviceAPI::AllocWorkspace/FreeWorkspace.
*/
#ifndef DGL_RUNTIME_TENSORDISPATCH_H_
#define DGL_RUNTIME_TENSORDISPATCH_H_
#include <dlpack/dlpack.h>
#include <stddef.h>
#include <tensoradapter.h>
#if defined(WIN32) || defined(_WIN32)
#include <windows.h>
#endif // WIN32
#include <vector>
#ifdef DGL_USE_CUDA
#include <cuda_runtime.h>
#endif // DGL_USE_CUDA
#include "ndarray.h"
/*! \brief Casts a pointer \c entry to a function pointer with signature of \c func */
......@@ -68,47 +66,57 @@ class TensorDispatcher {
bool Load(const char *path_cstr);
-  /*!
-   * \brief Allocate an empty tensor.
-   *        Used in NDArray::Empty().
-   *
-   * \param shape The shape
-   * \param dtype The data type
-   * \param ctx The device
-   * \return An empty NDArray.
-   */
-  inline NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) const {
-    auto entry = entrypoints_[Op::kEmpty];
-    auto result = FUNCCAST(tensoradapter::TAempty, entry)(shape, dtype, ctx);
-    return NDArray::FromDLPack(result);
-  }
+  /*!
+   * \brief Allocate a piece of CPU memory via
+   *        PyTorch's CPUAllocator.
+   *        Used in CPUDeviceAPI::AllocWorkspace().
+   *
+   * \param nbytes The size to be allocated.
+   * \return Pointer to the allocated memory.
+   */
+  inline void* CPUAllocWorkspace(size_t nbytes) {
+    auto entry = entrypoints_[Op::kCPURawAlloc];
+    return FUNCCAST(tensoradapter::CPURawAlloc, entry)(nbytes);
+  }
+
+  /*!
+   * \brief Free the CPU memory.
+   *        Used in CPUDeviceAPI::FreeWorkspace().
+   *
+   * \param ptr Pointer to the memory to be freed.
+   */
+  inline void CPUFreeWorkspace(void* ptr) {
+    auto entry = entrypoints_[Op::kCPURawDelete];
+    FUNCCAST(tensoradapter::CPURawDelete, entry)(ptr);
+  }
#ifdef DGL_USE_CUDA
-  /*!
-   * \brief Allocate a piece of GPU memory via
-   *        PyTorch's THCCachingAllocator.
-   *        Used in CUDADeviceAPI::AllocWorkspace().
-   *
-   * \note THCCachingAllocator specify the device to allocate on
-   *       via cudaGetDevice(). Make sure to call cudaSetDevice()
-   *       before invoking this function.
-   *
-   * \param nbytes The size to be allocated.
-   * \return Pointer to the allocated memory.
-   */
-  inline void* AllocWorkspace(size_t nbytes) {
-    auto entry = entrypoints_[Op::kRawAlloc];
-    return FUNCCAST(tensoradapter::RawAlloc, entry)(nbytes);
-  }
+  /*!
+   * \brief Allocate a piece of GPU memory via
+   *        PyTorch's THCCachingAllocator.
+   *        Used in CUDADeviceAPI::AllocWorkspace().
+   *
+   * \note THCCachingAllocator specify the device to allocate on
+   *       via cudaGetDevice(). Make sure to call cudaSetDevice()
+   *       before invoking this function.
+   *
+   * \param nbytes The size to be allocated.
+   * \param stream The stream to be allocated on.
+   * \return Pointer to the allocated memory.
+   */
+  inline void* CUDAAllocWorkspace(size_t nbytes, cudaStream_t stream) {
+    auto entry = entrypoints_[Op::kCUDARawAlloc];
+    return FUNCCAST(tensoradapter::CUDARawAlloc, entry)(nbytes, stream);
+  }
-  /*!
-   * \brief Free the GPU memory.
-   *        Used in CUDADeviceAPI::FreeWorkspace().
-   *
-   * \param ptr Pointer to the memory to be freed.
-   */
-  inline void FreeWorkspace(void* ptr) {
-    auto entry = entrypoints_[Op::kRawDelete];
-    FUNCCAST(tensoradapter::RawDelete, entry)(ptr);
-  }
+  /*!
+   * \brief Free the GPU memory.
+   *        Used in CUDADeviceAPI::FreeWorkspace().
+   *
+   * \param ptr Pointer to the memory to be freed.
+   */
+  inline void CUDAFreeWorkspace(void* ptr) {
+    auto entry = entrypoints_[Op::kCUDARawDelete];
+    FUNCCAST(tensoradapter::CUDARawDelete, entry)(ptr);
+  }
#endif // DGL_USE_CUDA
......@@ -124,20 +132,22 @@ class TensorDispatcher {
* Must match the functions in tensoradapter/include/tensoradapter.h.
*/
static constexpr const char *names_[] = {
"TAempty",
"CPURawAlloc",
"CPURawDelete",
#ifdef DGL_USE_CUDA
"RawAlloc",
"RawDelete",
"CUDARawAlloc",
"CUDARawDelete",
#endif // DGL_USE_CUDA
};
/*! \brief Index of each function to the symbol list */
class Op {
public:
-    static constexpr int kEmpty = 0;
+    static constexpr int kCPURawAlloc = 0;
+    static constexpr int kCPURawDelete = 1;
#ifdef DGL_USE_CUDA
-    static constexpr int kRawAlloc = 1;
-    static constexpr int kRawDelete = 2;
+    static constexpr int kCUDARawAlloc = 2;
+    static constexpr int kCUDARawDelete = 3;
#endif // DGL_USE_CUDA
};
......@@ -147,6 +157,7 @@ class TensorDispatcher {
/*! \brief Entrypoints of each function */
void* entrypoints_[num_entries_] = {
nullptr,
+    nullptr,
#ifdef DGL_USE_CUDA
nullptr,
nullptr,
......
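The header above resolves each name in names_ to an address stored in entrypoints_ and calls it through FUNCCAST. A rough sketch of that dlsym(3)-based lookup, assuming a POSIX loader (the Windows path would use LoadLibrary/GetProcAddress), is shown here; the function name is illustrative and this is not the actual Load() implementation.

// Illustrative sketch of the symbol lookup described in tensordispatch.h.
#include <dlfcn.h>

bool LoadAdapterSketch(const char* path, const char* const* names,
                       void** entrypoints, int num_entries) {
  void* handle = dlopen(path, RTLD_LAZY | RTLD_LOCAL);
  if (handle == nullptr)
    return false;                              // adapter library missing: keep the fallbacks
  for (int i = 0; i < num_entries; ++i)
    entrypoints[i] = dlsym(handle, names[i]);  // stays nullptr if a symbol is not exported
  return true;
}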
/*!
- * Copyright (c) 2016 by Contributors
+ * Copyright (c) 2016-2022 by Contributors
* \file cpu_device_api.cc
*/
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
+#include <dgl/runtime/tensordispatch.h>
#include <cstdlib>
#include <cstring>
#include "workspace_pool.h"
......@@ -24,6 +25,10 @@ class CPUDeviceAPI final : public DeviceAPI {
size_t nbytes,
size_t alignment,
DGLType type_hint) final {
+    TensorDispatcher* td = TensorDispatcher::Global();
+    if (td->IsAvailable())
+      return td->CPUAllocWorkspace(nbytes);
void* ptr;
#if _MSC_VER || defined(__MINGW32__)
ptr = _aligned_malloc(nbytes, alignment);
......@@ -39,6 +44,10 @@ class CPUDeviceAPI final : public DeviceAPI {
}
void FreeDataSpace(DGLContext ctx, void* ptr) final {
+    TensorDispatcher* td = TensorDispatcher::Global();
+    if (td->IsAvailable())
+      return td->CPUFreeWorkspace(ptr);
#if _MSC_VER || defined(__MINGW32__)
_aligned_free(ptr);
#else
......@@ -83,11 +92,18 @@ struct CPUWorkspacePool : public WorkspacePool {
void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx,
size_t size,
DGLType type_hint) {
-  return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
-      ->AllocWorkspace(ctx, size);
+  TensorDispatcher* td = TensorDispatcher::Global();
+  if (td->IsAvailable())
+    return td->CPUAllocWorkspace(size);
+  return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(ctx, size);
}
void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
+  TensorDispatcher* td = TensorDispatcher::Global();
+  if (td->IsAvailable())
+    return td->CPUFreeWorkspace(data);
dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
}
......
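Both CPU hunks above apply the same redirect-then-fallback pattern: use PyTorch's CPUAllocator through the dispatcher when the adapter is loaded, otherwise keep DGL's own path. A self-contained sketch of the pattern, with hypothetical helper names and a plain malloc/free fallback standing in for the aligned allocator and the thread-local workspace pool:

// Illustrative only; the dgl::runtime namespace is assumed.
#include <cstddef>
#include <cstdlib>
#include <dgl/runtime/tensordispatch.h>

void* AllocCPU(size_t nbytes) {
  auto* td = dgl::runtime::TensorDispatcher::Global();
  return td->IsAvailable() ? td->CPUAllocWorkspace(nbytes) : std::malloc(nbytes);
}

void FreeCPU(void* ptr) {
  auto* td = dgl::runtime::TensorDispatcher::Global();
  if (td->IsAvailable())
    td->CPUFreeWorkspace(ptr);
  else
    std::free(ptr);
}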
/*!
- * Copyright (c) 2017 by Contributors
+ * Copyright (c) 2017-2022 by Contributors
* \file cuda_device_api.cc
* \brief GPU specific API
*/
......@@ -107,7 +107,12 @@ class CUDADeviceAPI final : public DeviceAPI {
size_t nbytes,
size_t alignment,
DGLType type_hint) final {
-    CUDA_CALL(cudaSetDevice(ctx.device_id));
+    SetDevice(ctx);
+    // Redirect to PyTorch's allocator when available.
+    TensorDispatcher* td = TensorDispatcher::Global();
+    if (td->IsAvailable())
+      return td->CUDAAllocWorkspace(nbytes, CUDAThreadEntry::ThreadLocal()->stream);
CHECK_EQ(256 % alignment, 0U)
<< "CUDA space is aligned at 256 bytes";
void *ret;
......@@ -116,7 +121,11 @@ class CUDADeviceAPI final : public DeviceAPI {
}
void FreeDataSpace(DGLContext ctx, void* ptr) final {
-    CUDA_CALL(cudaSetDevice(ctx.device_id));
+    SetDevice(ctx);
+    TensorDispatcher* td = TensorDispatcher::Global();
+    if (td->IsAvailable())
+      return td->CUDAFreeWorkspace(ptr);
CUDA_CALL(cudaFree(ptr));
}
......@@ -246,21 +255,22 @@ class CUDADeviceAPI final : public DeviceAPI {
}
void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final {
-    // Redirect to PyTorch's allocator when available.
-    TensorDispatcher* td = TensorDispatcher::Global();
-    if (td->IsAvailable())
-      return td->AllocWorkspace(size);
-    else
-      return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
+    SetDevice(ctx);
+    // Redirect to PyTorch's allocator when available.
+    TensorDispatcher* td = TensorDispatcher::Global();
+    if (td->IsAvailable())
+      return td->CUDAAllocWorkspace(size, CUDAThreadEntry::ThreadLocal()->stream);
+    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
void FreeWorkspace(DGLContext ctx, void* data) final {
SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable())
td->FreeWorkspace(data);
else
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
return td->CUDAFreeWorkspace(data);
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
static const std::shared_ptr<CUDADeviceAPI>& Global() {
......
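Per the \note in tensordispatch.h, PyTorch's caching allocator picks the device from cudaGetDevice(), which is why both hunks above call SetDevice(ctx) before asking the dispatcher for memory. A minimal sketch of that ordering; the wrapper and its function-pointer parameter are hypothetical:

// Illustrative: select the device first, because the caching allocator reads
// the current device internally, then do the stream-aware allocation.
#include <cuda_runtime.h>

void* AllocOnDevice(int device_id, size_t nbytes, cudaStream_t stream,
                    void* (*alloc_on_stream)(size_t, cudaStream_t)) {
  cudaSetDevice(device_id);                // device selection must precede allocation
  return alloc_on_stream(nbytes, stream);  // e.g. tensoradapter::CUDARawAlloc
}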
/*!
- * Copyright (c) 2017 by Contributors
+ * Copyright (c) 2017-2022 by Contributors
* \file ndarray.cc
* \brief NDArray container infrastructure.
*/
......@@ -214,10 +214,6 @@ NDArray NDArray::EmptyShared(const std::string &name,
NDArray NDArray::Empty(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx) {
-  TensorDispatcher* td = TensorDispatcher::Global();
-  if (td->IsAvailable())
-    return td->Empty(shape, dtype, ctx);
NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor);
......
/*!
- * Copyright (c) 2020 by Contributors
+ * Copyright (c) 2020-2022 by Contributors
* \file tensoradapter.h
* \brief Header file for functions exposed by the adapter library.
*
......@@ -10,23 +10,29 @@
#ifndef TENSORADAPTER_H_
#define TENSORADAPTER_H_
#include <dlpack/dlpack.h>
#include <vector>
+#ifdef DGL_USE_CUDA
+#include <cuda_runtime.h>
+#endif  // DGL_USE_CUDA
namespace tensoradapter {
extern "C" {
-/*!
- * \brief Allocate an empty tensor.
- *
- * \param shape The shape
- * \param dtype The data type
- * \param ctx The device
- * \return The allocated tensor
- */
-DLManagedTensor* TAempty(
-    std::vector<int64_t> shape, DLDataType dtype, DLContext ctx);
+/*!
+ * \brief Allocate a piece of CPU memory via
+ *        PyTorch's CPUAllocator.
+ *
+ * \param nbytes The size to be allocated.
+ * \return Pointer to the allocated memory.
+ */
+void* CPURawAlloc(size_t nbytes);
+
+/*!
+ * \brief Free the CPU memory.
+ *
+ * \param ptr Pointer to the memory to be freed.
+ */
+void CPURawDelete(void* ptr);
#ifdef DGL_USE_CUDA
/*!
......@@ -34,16 +40,17 @@ DLManagedTensor* TAempty(
* PyTorch's THCCachingAllocator.
*
* \param nbytes The size to be allocated.
+ * \param stream The stream to be allocated on.
* \return Pointer to the allocated memory.
*/
-void* RawAlloc(size_t nbytes);
+void* CUDARawAlloc(size_t nbytes, cudaStream_t stream);
/*!
* \brief Free the GPU memory.
*
* \param ptr Pointer to the memory to be freed.
*/
-void RawDelete(void* ptr);
+void CUDARawDelete(void* ptr);
#endif // DGL_USE_CUDA
}
......
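tensordispatch.h calls these extern "C" symbols through raw addresses, so the FUNCCAST cast has to match the signatures declared here. A hedged sketch of what that cast amounts to for the new stream-aware allocator; the typedef and helper names are illustrative:

// Illustrative: turn a raw symbol address (e.g. obtained via dlsym) back into a
// callable with the CUDARawAlloc signature declared above, then invoke it.
#include <cstddef>
#include <cuda_runtime.h>

using cuda_raw_alloc_fn = void* (*)(size_t, cudaStream_t);

void* CallCUDARawAlloc(void* entry, size_t nbytes, cudaStream_t stream) {
  auto fn = reinterpret_cast<cuda_raw_alloc_fn>(entry);
  return fn(nbytes, stream);
}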
/*!
- * Copyright (c) 2020 by Contributors
+ * Copyright (c) 2020-2022 by Contributors
* \file torch/torch.cpp
* \brief Implementation of PyTorch adapter library.
*/
#include <tensoradapter_exports.h>
#include <torch/torch.h>
#include <ATen/DLConvertor.h>
#include <c10/core/CPUAllocator.h>
#ifdef DGL_USE_CUDA
#include <c10/cuda/CUDACachingAllocator.h>
#include <cuda_runtime.h>
#endif // DGL_USE_CUDA
#include <vector>
#include <iostream>
#if DLPACK_VERSION > 040
// Compatibility across DLPack - note that this assumes that the ABI stays the same.
#define kDLGPU kDLCUDA
#define DLContext DLDevice
#endif
namespace tensoradapter {
static at::Device get_device(DLContext ctx) {
switch (ctx.device_type) {
case kDLCPU:
return at::Device(torch::kCPU);
break;
case kDLGPU:
return at::Device(torch::kCUDA, ctx.device_id);
break;
default:
// fallback to CPU
return at::Device(torch::kCPU);
break;
}
}
extern "C" {
-TA_EXPORTS DLManagedTensor* TAempty(
-    std::vector<int64_t> shape,
-    DLDataType dtype,
-    DLContext ctx) {
-  auto options = torch::TensorOptions()
-    .layout(torch::kStrided)
-    .device(get_device(ctx))
-    .dtype(at::toScalarType(dtype));
-  torch::Tensor tensor = torch::empty(shape, options);
-  return at::toDLPack(tensor);
-}
+TA_EXPORTS void* CPURawAlloc(size_t nbytes) {
+  return c10::GetCPUAllocator()->raw_allocate(nbytes);
+}
+TA_EXPORTS void CPURawDelete(void* ptr) {
+  c10::GetCPUAllocator()->raw_deallocate(ptr);
+}
#ifdef DGL_USE_CUDA
-TA_EXPORTS void* RawAlloc(size_t nbytes) {
-  return c10::cuda::CUDACachingAllocator::raw_alloc(nbytes);
-}
+TA_EXPORTS void* CUDARawAlloc(size_t nbytes, cudaStream_t stream) {
+  return c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(
+      nbytes, stream);
+}
-TA_EXPORTS void RawDelete(void* ptr) {
-  c10::cuda::CUDACachingAllocator::raw_delete(ptr);
-}
+TA_EXPORTS void CUDARawDelete(void* ptr) {
+  c10::cuda::CUDACachingAllocator::raw_delete(ptr);
+}
#endif // DGL_USE_CUDA
......
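The switch from raw_alloc to raw_alloc_with_stream is what makes the adapter stream aware: the caching allocator records which stream a block belongs to, so DGL can pass its current stream instead of relying on the default one. A hypothetical standalone caller of the adapter, assuming a CUDA build linked directly against the library (DGL itself goes through TensorDispatcher):

// Illustrative usage: allocate a 1 MiB buffer on a given stream, then free it.
#ifdef DGL_USE_CUDA
#include <cuda_runtime.h>
#include <tensoradapter.h>

void DemoAlloc(cudaStream_t stream) {
  void* buf = tensoradapter::CUDARawAlloc(1 << 20, stream);
  // ... enqueue kernels that use `buf` on `stream` ...
  tensoradapter::CUDARawDelete(buf);
}
#endif  // DGL_USE_CUDA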