/*!
 *  Copyright (c) 2019 by Contributors
 * \file kernel/common.h
 * \brief Kernel common utilities: device qualifiers and device/dtype/index
 *        type dispatch macros.
 */
#ifndef DGL_KERNEL_COMMON_H_
#define DGL_KERNEL_COMMON_H_

#include <dgl/runtime/ndarray.h>

#include <cstdint>
#include "../c_api_common.h"
namespace dgl {
namespace kernel {

// Note: everything below is a preprocessor macro, so none of it is actually
// scoped to dgl::kernel; the namespace only groups it with the kernel code.

// Function qualifiers for kernel helpers. When compiled by nvcc (__CUDACC__
// defined) they expand to `__device__` and `__forceinline__`; in host-only
// builds they expand to nothing and `inline` respectively.
#ifdef __CUDACC__
#define DGLDEVICE __device__
#define DGLINLINE __forceinline__
#else
#define DGLDEVICE
#define DGLINLINE inline
#endif  // __CUDACC__

// Dispatch a runtime device type to a template instantiation:
//
//   DGL_XPU_SWITCH(device_type, Method, args...);
//
// calls Method<kDLCPU>(args...) for kDLCPU and, only when built with
// DGL_USE_CUDA, Method<kDLGPU>(args...) for kDLGPU. Any other value is
// reported with LOG(FATAL).
//
// Caveats shared by every *_SWITCH macro in this file:
//  * Each expands to a bare if/else chain, so do not use one as the unbraced
//    body of an `if` statement that has an `else`.
//  * The selector argument may be evaluated more than once; pass an
//    expression without side effects.
#ifdef DGL_USE_CUDA
#define DGL_XPU_SWITCH(val, Method, ...)                \
  if ((val) == kDLCPU) {                                \
    Method<kDLCPU>(__VA_ARGS__);                        \
  } else if ((val) == kDLGPU) {                         \
    Method<kDLGPU>(__VA_ARGS__);                        \
  } else {                                              \
    LOG(FATAL) << "Unsupported device type: " << (val); \
  }
#else  // DGL_USE_CUDA
#define DGL_XPU_SWITCH(val, Method, ...)                \
  if ((val) == kDLCPU) {                                \
    Method<kDLCPU>(__VA_ARGS__);                        \
  } else {                                              \
    LOG(FATAL) << "Unsupported device type: " << (val); \
  }
#endif  // DGL_USE_CUDA

// MSVC does not expand __VA_ARGS__ correctly when it is forwarded into another
// macro call; wrapping that call in this identity macro works around it.
#define MSVC_EXPAND(x) x

// Dispatch a runtime dtype to a compile-time type alias:
//
//   DGL_DTYPE_SWITCH(dtype, DType, {
//     ... code that uses DType ...
//   });
//
// `dtype` must provide `.code` and `.bits` members and be streamable into
// LOG (presumably a DLDataType -- confirm against callers). Currently only
// float32 (code kDLFloat, 32 bits -> `float`) is supported; anything else is
// reported with LOG(FATAL).
#define DGL_DTYPE_SWITCH(val, DType, ...)             \
  if ((val).code == kDLFloat && (val).bits == 32) {   \
    typedef float DType;                              \
    { __VA_ARGS__ }                                   \
  } else {                                            \
    LOG(FATAL) << "Unsupported dtype: " << (val);     \
  }

// Expand GEN(args..., T) once for every dtype that DGL_DTYPE_SWITCH handles
// (currently only float).
#define GEN_DTYPE(GEN, ...)  \
  MSVC_EXPAND(GEN(__VA_ARGS__, float))

// Dispatch an index bit-width to a compile-time integer type alias:
//
//   DGL_IDX_TYPE_SWITCH(bits, Idx, { ... code that uses Idx ... });
//
// Host builds accept 32 (int32_t) and 64 (int64_t); CUDA builds (__CUDACC__)
// only provide the 32-bit path. Any other width is reported with LOG(FATAL).
#ifdef __CUDACC__
#define DGL_IDX_TYPE_SWITCH(bits, Idx, ...)            \
  if ((bits) == 32) {                                  \
    typedef int32_t Idx;                               \
    {__VA_ARGS__}                                      \
  } else {                                             \
    LOG(FATAL) << "Unsupported idx bits: " << (bits);  \
  }
#else
#define DGL_IDX_TYPE_SWITCH(bits, Idx, ...)            \
  if ((bits) == 32) {                                  \
    typedef int32_t Idx;                               \
    {__VA_ARGS__}                                      \
  } else if ((bits) == 64) {                           \
    typedef int64_t Idx;                               \
    {__VA_ARGS__}                                      \
  } else {                                             \
    LOG(FATAL) << "Unsupported idx bits: " << (bits);  \
  }
#endif  // __CUDACC__

}  // namespace kernel
}  // namespace dgl

#endif  // DGL_KERNEL_COMMON_H_