utils.cuh 1.67 KB
Newer Older
quyuanhao123's avatar
quyuanhao123 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#pragma once

#include <torch/extension.h>

// Throws c10::Error with "<expr> must be CUDA tensor" unless `x` is on a CUDA
// device. `x` must expose `.device()` (e.g. an at::Tensor); the argument is
// parenthesized so compound expressions such as `*ptr` parse correctly.
// TORCH_CHECK (not the deprecated AT_ASSERTM / TORCH_INTERNAL_ASSERT) is used
// because a failure here is a caller error, not an internal PyTorch bug.
#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK((x).device().is_cuda(), #x " must be CUDA tensor")
// Throws c10::Error with "Input mismatch" unless the condition `x` holds.
// Note: `x` is used directly as the boolean condition, not as a tensor.
#define CHECK_INPUT(x) TORCH_CHECK((x), "Input mismatch")

// Turns a runtime basis degree (1, 2 or 3) into the compile-time constant
// `DEGREE`, visible inside the callable passed as the remaining arguments.
//
// Usage:
//   AT_DISPATCH_DEGREE_TYPES(degree, [&] { some_kernel<DEGREE>(...); });
//
// - `degree` is evaluated exactly once.
// - `__VA_ARGS__` must be a callable taking no arguments. It is invoked inside
//   an immediately-called `[&]` lambda, and its result is the value of the
//   whole expression.
// - Any other degree throws c10::Error via TORCH_CHECK (replacing the
//   deprecated AT_ERROR). The message includes the rejected value.
#define AT_DISPATCH_DEGREE_TYPES(degree, ...)                                  \
  [&] {                                                                        \
    const auto _dispatch_degree = (degree);                                    \
    switch (_dispatch_degree) {                                                \
    case 1: {                                                                  \
      const int64_t DEGREE = 1;                                                \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case 2: {                                                                  \
      const int64_t DEGREE = 2;                                                \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case 3: {                                                                  \
      const int64_t DEGREE = 3;                                                \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    default:                                                                   \
      TORCH_CHECK(false, "Basis degree not implemented (got degree ",          \
                  _dispatch_degree, ")");                                      \
    }                                                                          \
  }()