common.h 5.63 KB
Newer Older
1
2
/**
 *  Copyright (c) 2017-2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/common.h
 * @brief Common utilities for CUDA
 */
#ifndef GRAPHBOLT_CUDA_COMMON_H_
#define GRAPHBOLT_CUDA_COMMON_H_

10
#include <c10/cuda/CUDACachingAllocator.h>
11
12
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
13
14
15
16
#include <torch/script.h>

#include <memory>
#include <unordered_map>
17
18
19
20

namespace graphbolt {
namespace cuda {

21
22
23
24
/**
 * @brief This class is designed to allocate workspace storage
 * and to get a nonblocking thrust execution policy
 * that uses torch's CUDA memory pool and the current cuda stream:
25
 *
26
27
28
 * cuda::CUDAWorkspaceAllocator allocator;
 * const auto stream = torch::cuda::getDefaultCUDAStream();
 * const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
29
 *
30
 * Now, one can pass exec_policy to thrust functions
31
 *
32
33
 * To get an integer array of size 1000 whose lifetime is managed by unique_ptr,
 * use:
34
 *
35
 * auto int_array = allocator.AllocateStorage<int>(1000);
36
 *
37
38
 * int_array.get() gives the raw pointer.
 */
39
struct CUDAWorkspaceAllocator {
40
41
42
  // Required by thrust to satisfy allocator requirements.
  using value_type = char;

43
  explicit CUDAWorkspaceAllocator() { at::globalContext().lazyInitCUDA(); }
44
45
46

  CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;

47
48
49
  void operator()(void* ptr) const {
    c10::cuda::CUDACachingAllocator::raw_delete(ptr);
  }
50
51
52

  // Required by thrust to satisfy allocator requirements.
  value_type* allocate(std::ptrdiff_t size) const {
53
54
    return reinterpret_cast<value_type*>(
        c10::cuda::CUDACachingAllocator::raw_alloc(size));
55
56
57
58
59
60
61
62
63
64
65
66
67
  }

  // Required by thrust to satisfy allocator requirements.
  void deallocate(value_type* ptr, std::size_t) const { operator()(ptr); }

  template <typename T>
  std::unique_ptr<T, CUDAWorkspaceAllocator> AllocateStorage(
      std::size_t size) const {
    return std::unique_ptr<T, CUDAWorkspaceAllocator>(
        reinterpret_cast<T*>(allocate(sizeof(T) * size)), *this);
  }
};

68
inline auto GetAllocator() { return CUDAWorkspaceAllocator{}; }
69
70

// Returns whatever c10::cuda::getCurrentCUDAStream() reports at call time.
inline auto GetCurrentStream() {
  auto stream = c10::cuda::getCurrentCUDAStream();
  return stream;
}
71

72
73
74
75
76
77
78
79
80
81
// Generic emptiness test for a launch dimension: true exactly when `size`
// compares equal to 0.
template <typename T>
inline bool is_zero(T size) {
  const bool empty = (size == 0);
  return empty;
}

// dim3 overload of the emptiness test: the dimension is considered zero as
// soon as any one of its x, y or z components is 0.
template <>
inline bool is_zero<dim3>(dim3 size) {
  const bool all_nonzero = size.x != 0 && size.y != 0 && size.z != 0;
  return !all_nonzero;
}

82
83
// Wraps a CUDA call expression and forwards it, parenthesized, to
// C10_CUDA_CHECK. How a failing status is reported is decided by
// C10_CUDA_CHECK (see c10/cuda/CUDAException.h) -- presumably it throws;
// confirm there.
#define CUDA_CALL(func) C10_CUDA_CHECK((func))

84
85
86
87
88
89
90
91
92
// Launches `kernel` with grid size `nblks`, block size `nthrs`, `shmem` bytes
// of dynamic shared memory on `stream`, passing the remaining arguments to the
// kernel, and then invokes C10_CUDA_KERNEL_LAUNCH_CHECK().
//
// Nothing is launched (and nothing is checked) when graphbolt::cuda::is_zero
// reports that either `nblks` or `nthrs` is zero.
//
// Notes for callers:
//  - `nblks` and `nthrs` appear twice in the expansion (once in the guard and
//    once in the launch), so avoid passing expressions with side effects.
//  - The expansion is a plain brace-enclosed block, not a do/while(0).
#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...)    \
  {                                                                   \
    if (!(graphbolt::cuda::is_zero((nblks)) ||                        \
          graphbolt::cuda::is_zero((nthrs)))) {                       \
      (kernel)<<<(nblks), (nthrs), (shmem), (stream)>>>(__VA_ARGS__); \
      C10_CUDA_KERNEL_LAUNCH_CHECK();                                 \
    }                                                                 \
  }

93
94
95
96
97
98
99
100
101
102
// Case list for a dtype dispatch switch. It is the union of whatever
// AT_DISPATCH_CASE_ALL_TYPES covers plus explicit cases for Half, BFloat16 and
// Bool, so that all integer, float and boolean types are included.
// The variadic arguments are forwarded unchanged to every AT_DISPATCH_CASE*.
// Meant to be used as the last argument of AT_DISPATCH_SWITCH (see
// GRAPHBOLT_DISPATCH_ALL_TYPES).
#define GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(...)            \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)                 \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Bool, __VA_ARGS__)

// Dispatches on the scalar type TYPE using the case list from
// GRAPHBOLT_DISPATCH_CASE_ALL_TYPES. NAME and the trailing arguments are
// handed to AT_DISPATCH_SWITCH / the case macros as-is; NAME is presumably
// the label used in ATen's error message for unsupported types -- confirm in
// ATen/Dispatch.h.
#define GRAPHBOLT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))

103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Dispatches on a runtime element size given in bytes.
//
// The trailing argument must be a callable taking no arguments (typically a
// lambda). It is invoked inside an immediately-called `[&]` lambda in which
// the type alias `element_size_t` names a type of the requested width:
//   1 -> uint8_t, 2 -> uint16_t, 4 -> uint32_t, 8 -> uint64_t, 16 -> float4.
// The macro expression evaluates to whatever the callable returns, so the
// callable must have the same return type for every width.
//
// Any other size is rejected through TORCH_CHECK with a message built from
// `name` and the offending size. `element_size` is evaluated exactly once.
//
// The local `graphbolt_dispatched_element_size_` is visible to the callable's
// body; do not use that identifier in caller code.
#define GRAPHBOLT_DISPATCH_ELEMENT_SIZES(element_size, name, ...)              \
  [&] {                                                                        \
    /* Evaluate once so the failure path can report the actual value. */       \
    const auto graphbolt_dispatched_element_size_ = (element_size);            \
    switch (graphbolt_dispatched_element_size_) {                              \
      case 1: {                                                                \
        using element_size_t = uint8_t;                                        \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case 2: {                                                                \
        using element_size_t = uint16_t;                                       \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case 4: {                                                                \
        using element_size_t = uint32_t;                                       \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case 8: {                                                                \
        using element_size_t = uint64_t;                                       \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case 16: {                                                               \
        using element_size_t = float4;                                         \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      default:                                                                 \
        /* Cast so that narrow integer types print as numbers, not chars. */   \
        TORCH_CHECK(                                                           \
            false, name, " with element_size ",                                \
            static_cast<long long>(graphbolt_dispatched_element_size_),        \
            " is not supported!");                                             \
        /* TORCH_CHECK(false, ...) is expected not to return; the lines */     \
        /* below give this branch the same return type as the others.   */     \
        using element_size_t = uint8_t;                                        \
        return __VA_ARGS__();                                                  \
    }                                                                          \
  }()

133
134
135
}  // namespace cuda
}  // namespace graphbolt
#endif  // GRAPHBOLT_CUDA_COMMON_H_