utils.h 1.6 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
#pragma once

#include <torch/extension.h>

// Argument-validation helpers built on ATen's AT_ASSERTM macro.
//
// CHECK_CPU(x): asserts that `(x).device().is_cpu()` holds. On failure the
// message is the stringified argument followed by " must be CPU tensor".
// The argument is wrapped in parentheses so that a compound expression
// (e.g. `a + b` or `cond ? a : b`) is evaluated as a whole before the
// `.device()` member call is applied.
// NOTE(review): newer PyTorch releases appear to deprecate AT_ASSERTM in
// favour of TORCH_CHECK -- confirm against the targeted version before
// migrating.
#define CHECK_CPU(x)                                                           \
  AT_ASSERTM((x).device().is_cpu(), #x " must be CPU tensor")

// CHECK_INPUT(x): asserts that the condition `x` is true and reports
// "Input mismatch" otherwise.
#define CHECK_INPUT(x) AT_ASSERTM((x), "Input mismatch")
rusty1s's avatar
rusty1s committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

// Helper for AT_DISPATCH_DEGREE_TYPES: emits one `case` of the dispatch
// switch. It binds the literal `value` to a block-scoped constant named
// DEGREE and returns the result of invoking the callable in __VA_ARGS__.
#define DEGREE_DISPATCH_CASE(value, ...)                                       \
  case value: {                                                                \
    const int64_t DEGREE = value;                                              \
    return __VA_ARGS__();                                                      \
  }

// Dispatches on the runtime value `degree`.
//
// The macro expands to an immediately-invoked `[&]` lambda, so it can be used
// as an expression whose value is whatever the supplied callable returns.
// For degree 1, 2 or 3 the callable (passed as the trailing macro arguments,
// typically a lambda) is invoked with a `const int64_t DEGREE` of the same
// value visible in its enclosing scope. Any other degree is reported through
// AT_ERROR with the message "Basis degree not implemented".
#define AT_DISPATCH_DEGREE_TYPES(degree, ...)                                  \
  [&] {                                                                        \
    switch (degree) {                                                          \
      DEGREE_DISPATCH_CASE(1, __VA_ARGS__)                                     \
      DEGREE_DISPATCH_CASE(2, __VA_ARGS__)                                     \
      DEGREE_DISPATCH_CASE(3, __VA_ARGS__)                                     \
    default:                                                                   \
      AT_ERROR("Basis degree not implemented");                                \
    }                                                                          \
  }()