"...pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "72b3e078af32b9b237d024526f8f25fffe088c03"
Unverified commit 5e64481b, authored by Muhammed Fatih BALIN and committed by GitHub

[Graphbolt][CUDA] Simplify allocator class by discarding tensors (#6654)

parent 2968c9b2
@@ -6,6 +6,7 @@
 #ifndef GRAPHBOLT_CUDA_COMMON_H_
 #define GRAPHBOLT_CUDA_COMMON_H_
+#include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAException.h>
 #include <cuda_runtime.h>
 #include <torch/script.h>
@@ -34,29 +35,22 @@ namespace cuda {
  * int_array.get() gives the raw pointer.
  */
-class CUDAWorkspaceAllocator {
-  using TensorPtrMapType = std::unordered_map<void*, torch::Tensor>;
-  std::shared_ptr<TensorPtrMapType> ptr_map_;
-
- public:
+struct CUDAWorkspaceAllocator {
   // Required by thrust to satisfy allocator requirements.
   using value_type = char;

-  explicit CUDAWorkspaceAllocator()
-      : ptr_map_(std::make_shared<TensorPtrMapType>()) {}
+  explicit CUDAWorkspaceAllocator() { at::globalContext().lazyInitCUDA(); }

   CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;

-  void operator()(void* ptr) const { ptr_map_->erase(ptr); }
+  void operator()(void* ptr) const {
+    c10::cuda::CUDACachingAllocator::raw_delete(ptr);
+  }

   // Required by thrust to satisfy allocator requirements.
   value_type* allocate(std::ptrdiff_t size) const {
-    auto tensor = torch::empty(
-        size, torch::TensorOptions()
-                  .dtype(torch::kByte)
-                  .device(c10::DeviceType::CUDA));
-    ptr_map_->operator[](tensor.data_ptr()) = tensor;
-    return reinterpret_cast<value_type*>(tensor.data_ptr());
+    return reinterpret_cast<value_type*>(
+        c10::cuda::CUDACachingAllocator::raw_alloc(size));
   }

   // Required by thrust to satisfy allocator requirements.
...
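
For context, below is a minimal usage sketch (not part of this commit) of how a thrust-compatible allocator such as CUDAWorkspaceAllocator is typically handed to a thrust execution policy, so that temporary buffers used by algorithms like thrust::sort come from PyTorch's CUDA caching allocator rather than separate cudaMalloc/cudaFree calls. The include path, the graphbolt::cuda namespace, and the deallocate member implied by the trailing "Required by thrust" comment (the diff is truncated there) are assumptions.

#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>

#include <cstddef>

#include "graphbolt/cuda_common.h"  // assumed include path for CUDAWorkspaceAllocator

// Sorts device-resident keys in place. thrust::cuda::par(allocator) routes
// thrust's temporary-storage requests through allocator.allocate()/deallocate(),
// i.e. through c10's CUDACachingAllocator after this commit.
void SortKeys(int* keys, std::size_t num_keys) {
  graphbolt::cuda::CUDAWorkspaceAllocator allocator;  // namespace assumed from the hunk header
  thrust::sort(thrust::cuda::par(allocator), keys, keys + num_keys);
}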