// pytorch_cpp_helper.hpp: tensor argument-validation macros for C++ extensions.
#ifndef PYTORCH_CPP_HELPER
#define PYTORCH_CPP_HELPER
#include <torch/types.h>

#include <vector>

// NOTE(review): a using-directive in a header leaks every `at::` name into
// each translation unit that includes this file. It is kept only because
// existing includers may rely on unqualified names (e.g. `Tensor`).
using namespace at;

// Argument-validation macros for extension entry points. Every check forwards
// to TORCH_CHECK, so a failing check raises the same error TORCH_CHECK raises,
// with a message that names the offending argument via `#x`.
//
// The argument is parenthesized so that expressions such as `*ptr` are
// checked as a whole. The *_INPUT macros evaluate it more than once, so pass a
// side-effect-free expression.

// Requires `x` to be on a CUDA device.
#define CHECK_CUDA(x) \
  TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")

// Requires `x` to be on a device whose type is at::kMLU.
#define CHECK_MLU(x) \
  TORCH_CHECK((x).device().type() == at::kMLU, #x " must be a MLU tensor")

// Requires `x` to be on the CPU. Tests is_cpu() rather than "not CUDA" so
// tensors on other devices (e.g. MLU) are rejected instead of being accepted
// as CPU tensors.
#define CHECK_CPU(x) \
  TORCH_CHECK((x).device().is_cpu(), #x " must be a CPU tensor")

// Requires `x` to be contiguous (Tensor::is_contiguous with its default
// memory-format argument).
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")

// Combined device + contiguity checks. Each is wrapped in do { ... } while (0)
// so it expands to exactly one statement: inside an unbraced `if`/`else` both
// checks stay under the condition. Callers terminate it with a semicolon.

// Requires `x` to be a contiguous CUDA tensor.
#define CHECK_CUDA_INPUT(x) \
  do {                      \
    CHECK_CUDA(x);          \
    CHECK_CONTIGUOUS(x);    \
  } while (0)

// Requires `x` to be a contiguous MLU tensor.
#define CHECK_MLU_INPUT(x) \
  do {                     \
    CHECK_MLU(x);          \
    CHECK_CONTIGUOUS(x);   \
  } while (0)

// Requires `x` to be a contiguous CPU tensor.
#define CHECK_CPU_INPUT(x) \
  do {                     \
    CHECK_CPU(x);          \
    CHECK_CONTIGUOUS(x);   \
  } while (0)

#endif  // PYTORCH_CPP_HELPER