// utils.h: argument-check and basis-degree dispatch helper macros.
#pragma once

#include <torch/extension.h>

// Argument-validation helpers for the extension entry points.
//
// Both use TORCH_CHECK, PyTorch's macro for user-facing argument checks;
// AT_ASSERTM is deprecated in its favour.

// Fails with "<arg> must be CPU tensor" unless tensor `x` is on the CPU.
// `#x` stringifies the argument expression so the message names it.
#define CHECK_CPU(x)                                                           \
  TORCH_CHECK((x).device().is_cpu(), #x " must be CPU tensor")

// Fails with "Input mismatch" unless the boolean condition `x` holds.
#define CHECK_INPUT(x) TORCH_CHECK((x), "Input mismatch")

// Turns a runtime basis degree into a compile-time constant.
//
// Expands to an immediately-invoked lambda that switches on `degree`. For the
// supported degrees 1, 2 and 3 it declares `static constexpr int64_t DEGREE`
// in the enclosing case scope, then returns the result of invoking the
// callable passed as `__VA_ARGS__`. That callable can therefore use DEGREE as
// a constant expression (e.g. a template argument). Any other degree fails
// with "Basis degree not implemented" (TORCH_CHECK(false, ...) is the
// documented replacement for the deprecated AT_ERROR).
//
// Usage:
//   AT_DISPATCH_DEGREE_TYPES(degree, [&] { kernel<DEGREE>(...); });
#define AT_DISPATCH_DEGREE_TYPES(degree, ...)                                  \
  [&] {                                                                        \
    switch (degree) {                                                          \
    case 1: {                                                                  \
      static constexpr int64_t DEGREE = 1;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case 2: {                                                                  \
      static constexpr int64_t DEGREE = 2;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case 3: {                                                                  \
      static constexpr int64_t DEGREE = 3;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    default:                                                                   \
      TORCH_CHECK(false, "Basis degree not implemented");                      \
    }                                                                          \
  }()