"tests/python/common/test_heterograph-specialization.py" did not exist on "b9631912dbe8e10198e909f15d13370afca789ce"
torch.cpp 1.69 KB
Newer Older
1
/*!
 *  Copyright (c) 2020-2022 by Contributors
 * \file torch/torch.cpp
 * \brief Implementation of PyTorch adapter library.
 */

7
#include <tensoradapter_exports.h>
8
#include <c10/core/CPUAllocator.h>
9
#ifdef DGL_USE_CUDA
10
11
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
12
#include <c10/cuda/CUDACachingAllocator.h>
13
#include <c10/cuda/CUDAStream.h>
14
#include <cuda_runtime.h>
15
#endif  // DGL_USE_CUDA
16

17
18
19
20
namespace tensoradapter {

extern "C" {

21
22
23
24
25
26
/*!
 * \brief Allocate a raw host buffer through PyTorch's CPU allocator.
 * \param nbytes Size of the requested buffer in bytes.
 * \return Pointer returned by the allocator's raw_allocate. The matching
 *         release function in this file is CPURawDelete.
 */
TA_EXPORTS void* CPURawAlloc(size_t nbytes) {
  auto* const cpu_allocator = c10::GetCPUAllocator();
  return cpu_allocator->raw_allocate(nbytes);
}

/*!
 * \brief Release a raw host buffer back to PyTorch's CPU allocator.
 * \param ptr Buffer to free; expected to come from CPURawAlloc, which uses
 *            the same allocator (c10::GetCPUAllocator()).
 */
TA_EXPORTS void CPURawDelete(void* ptr) {
  // Must be returned through the same allocator that produced it.
  c10::GetCPUAllocator()->raw_deallocate(ptr);
}

29
#ifdef DGL_USE_CUDA
30
/*!
 * \brief Allocate device memory from PyTorch's CUDA caching allocator.
 * \param nbytes Size of the requested allocation in bytes.
 * \param stream CUDA stream handed to the caching allocator's
 *               raw_alloc_with_stream.
 * \return Device pointer from the caching allocator. The matching release
 *         function in this file is CUDARawDelete.
 */
TA_EXPORTS void* CUDARawAlloc(size_t nbytes, cudaStream_t stream) {
  // Ensure PyTorch's CUDA context is initialized before using the allocator.
  at::globalContext().lazyInitCUDA();
  return c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(
    nbytes, stream);
}

36
/*!
 * \brief Return device memory to PyTorch's CUDA caching allocator.
 * \param ptr Device pointer to release; expected to come from CUDARawAlloc.
 */
TA_EXPORTS void CUDARawDelete(void* ptr) {
  c10::cuda::CUDACachingAllocator::raw_delete(ptr);
}
39
40
41
42

/*!
 * \brief Get the raw CUDA handle of PyTorch's current CUDA stream.
 * \return The cudaStream_t wrapped by at::cuda::getCurrentCUDAStream().
 */
TA_EXPORTS cudaStream_t CUDACurrentStream() {
  const auto current_stream = at::cuda::getCurrentCUDAStream();
  // Explicitly unwrap the CUDAStream instead of relying on its implicit
  // conversion operator to cudaStream_t.
  return current_stream.stream();
}
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

/*!
 * \brief Record that a caching-allocator block is used on the given stream.
 *
 * Builds a temporary c10::DataPtr around \p ptr (context = ptr, deleter =
 * CUDACachingAllocator::raw_delete) so it can be passed to
 * CUDACachingAllocator::recordStream. Afterwards the context is released so
 * the temporary presumably does not invoke raw_delete on \p ptr when it
 * goes out of scope.
 *
 * \param ptr Device pointer; presumably obtained from CUDARawAlloc — TODO confirm.
 * \param stream Raw CUDA stream handle on which \p ptr is being used.
 * \param device_id CUDA device index used for both the DataPtr and the stream.
 *
 * NOTE(review): the deleter is likely required to be exactly raw_delete
 * because recordStream uses it to recognize caching-allocator blocks; confirm
 * against the targeted PyTorch version before changing it.
 * NOTE(review): if recordStream throws, data_ptr is destroyed before
 * release_context() runs, which would free ptr while the caller still owns
 * it. Since this function is extern "C", an escaping exception is already
 * problematic; consider guarding.
 */
TA_EXPORTS void RecordStream(void* ptr, cudaStream_t stream, int device_id) {
  c10::DataPtr data_ptr{
    ptr, ptr, &c10::cuda::CUDACachingAllocator::raw_delete,
    c10::Device(c10::DeviceType::CUDA, device_id)};
  c10::cuda::CUDACachingAllocator::recordStream(
    data_ptr,
    // getStreamFromExternal doesn't exist before PyTorch 1.10, just copy it here
    // NOTE(review): the stream id is the raw handle value; newer PyTorch
    // releases may encode external stream ids differently — verify.
    c10::cuda::CUDAStream(
      c10::cuda::CUDAStream::UNCHECKED,
      c10::Stream(
          c10::Stream::UNSAFE,
          c10::Device(c10::DeviceType::CUDA, device_id),
          reinterpret_cast<int64_t>(stream)))
  );
  // Detach the context so data_ptr's destructor does not free ptr;
  // the caller keeps ownership of the allocation.
  data_ptr.release_context();
}
60
61
#endif  // DGL_USE_CUDA

62
63
64
};

};  // namespace tensoradapter