cuda_view.cu 1.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <torch/all.h>
#include <torch/cuda.h>
#include <cuda_runtime.h>

// Returns a CUDA tensor that aliases the memory of `cpu_tensor` without copying.
//
// Preconditions:
//   * `cpu_tensor` is a strided CPU tensor (checked).
//   * Its memory is pinned (page-locked) host memory that is mapped into the
//     device address space, e.g. because UVA (Unified Virtual Addressing) is
//     enabled. If it is not, cudaHostGetDevicePointer fails and an error is
//     raised.
//
// The result has the same sizes, strides and dtype as `cpu_tensor`. No data is
// copied: accesses through the CUDA view go to the same host memory, so the two
// tensors alias each other (subject to the usual stream ordering).
//
// Lifetime: the returned tensor's deleter holds a reference to `cpu_tensor`.
// The host allocation therefore stays alive for as long as the CUDA view does,
// even if the caller drops the original CPU tensor first.
//
// Raises (via TORCH_CHECK) if the input is not a strided CPU tensor or if the
// device pointer cannot be obtained.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
  TORCH_CHECK(cpu_tensor.layout() == torch::kStrided,
              "Input tensor must have strided layout");

  // The view reuses the CPU tensor's geometry and dtype; only the device
  // changes. The exact CUDA device is inferred from the pointer by from_blob.
  auto sizes = cpu_tensor.sizes();
  auto strides = cpu_tensor.strides();
  auto options = cpu_tensor.options().device(torch::kCUDA);

  // A zero-element tensor has no bytes to map, and its data pointer may be
  // null, which cudaHostGetDevicePointer rejects. Return an empty CUDA tensor
  // of the same shape instead of failing.
  if (cpu_tensor.numel() == 0) {
    return torch::empty_strided(sizes, strides, options);
  }

  // data_ptr() points at the first element, i.e. it already accounts for the
  // tensor's storage offset.
  void* host_ptr = cpu_tensor.data_ptr();

  // Translate the pinned host address into the device address space.
  void* device_ptr = nullptr;
  cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
  if (err != cudaSuccess) {
    // Reset the thread's last-error state. Otherwise this failure can
    // resurface in an unrelated later cudaGetLastError() check (e.g. after a
    // kernel launch) once the exception below has been handled.
    (void)cudaGetLastError();
  }
  TORCH_CHECK(err == cudaSuccess,
              "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

  // The deleter frees nothing: the memory belongs to the CPU tensor's
  // allocator. Capturing a copy of `cpu_tensor` (a refcounted handle) in the
  // deleter ties the host allocation's lifetime to the CUDA view. The handle
  // is released only when the view's storage is destroyed.
  torch::Tensor cuda_tensor = torch::from_blob(
      device_ptr, sizes, strides,
      [owner = cpu_tensor](void* /*unused*/) { (void)owner; }, options);

  TORCH_CHECK(cuda_tensor.device().is_cuda(),
              "Resulting tensor is not on CUDA device");

  return cuda_tensor;
}